diff --git a/tests/models/test_intern_vit.py b/tests/models/test_intern_vit.py index e980446ff3570..816f846f69bae 100644 --- a/tests/models/test_intern_vit.py +++ b/tests/models/test_intern_vit.py @@ -6,8 +6,6 @@ from huggingface_hub import snapshot_download from transformers import AutoConfig, AutoModel, CLIPImageProcessor -from vllm.model_executor.models.intern_vit import InternVisionModel - from ..conftest import _ImageAssets, cleanup pytestmark = pytest.mark.vlm @@ -49,6 +47,7 @@ def run_intern_vit_test( for pixel_value in pixel_values ] + from vllm.model_executor.models.intern_vit import InternVisionModel vllm_model = InternVisionModel(config) vllm_model.load_weights(hf_model.state_dict().items()) diff --git a/tests/models/test_internvl.py b/tests/models/test_internvl.py index 243bc857c88de..42732cebc6567 100644 --- a/tests/models/test_internvl.py +++ b/tests/models/test_internvl.py @@ -6,9 +6,6 @@ from PIL.Image import Image from transformers import AutoConfig -from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END, - IMG_START, - image_to_pixel_values) from vllm.multimodal.utils import rescale_image_size from vllm.utils import is_cpu @@ -33,35 +30,6 @@ ] -class InternVLProcessor: - """A simple processor for InternVL2 HF model which misses a processor.""" - - def __init__(self, hf_runner: HfRunner): - self.num_image_token = hf_runner.model.num_image_token - self.tokenizer = hf_runner.tokenizer - self.dtype = hf_runner.model.dtype - - self.config = AutoConfig.from_pretrained(hf_runner.model_name) - self.vision_config = self.config.vision_config - self.use_thumbnail = self.config.use_thumbnail - self.min_num = self.config.min_dynamic_patch - self.max_num = self.config.max_dynamic_patch - self.image_size = self.vision_config.image_size - - def __call__(self, text: str, images: Image, **kwargs): - pixel_values = image_to_pixel_values(images, self.image_size, - self.min_num, self.max_num, - self.use_thumbnail).to(self.dtype) - num_patches_list = [pixel_values.shape[0]] - for num_patches in num_patches_list: - context_tokens = IMG_CONTEXT * self.num_image_token * num_patches - image_tokens = IMG_START + context_tokens + IMG_END - text = text.replace('', image_tokens, 1) - prompt = self.tokenizer(text, return_tensors="pt") - prompt.update({"pixel_values": pixel_values}) - return prompt - - # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py def generate( self, @@ -127,6 +95,37 @@ def run_test( # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). + class InternVLProcessor: + """A simple processor for InternVL2 which misses a processor.""" + + def __init__(self, hf_runner: HfRunner): + self.num_image_token = hf_runner.model.num_image_token + self.tokenizer = hf_runner.tokenizer + self.dtype = hf_runner.model.dtype + + self.config = AutoConfig.from_pretrained(hf_runner.model_name) + self.vision_config = self.config.vision_config + self.use_thumbnail = self.config.use_thumbnail + self.min_num = self.config.min_dynamic_patch + self.max_num = self.config.max_dynamic_patch + self.image_size = self.vision_config.image_size + + def __call__(self, text: str, images: Image, **kwargs): + from vllm.model_executor.models.internvl import ( + IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values) + pixel_values = image_to_pixel_values( + images, self.image_size, self.min_num, self.max_num, + self.use_thumbnail).to(self.dtype) + num_patches_list = [pixel_values.shape[0]] + for num_patches in num_patches_list: + context_tokens = IMG_CONTEXT * self.num_image_token \ + * num_patches + image_tokens = IMG_START + context_tokens + IMG_END + text = text.replace('', image_tokens, 1) + prompt = self.tokenizer(text, return_tensors="pt") + prompt.update({"pixel_values": pixel_values}) + return prompt + # max_model_len should be greater than image_feature_size with vllm_runner(model, max_model_len=4096, diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 830680fd990bf..e6acf8cd5d5bb 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -7,12 +7,14 @@ import torch.nn as nn from PIL import Image from transformers import Blip2VisionConfig, BlipVisionConfig -from transformers.models.blip.modeling_blip import BlipAttention +from xformers import ops as xops from vllm.config import ModelConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal.utils import (cached_get_tokenizer, @@ -154,6 +156,77 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return embeddings +class BlipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + ) + self.projection = RowParallelLinear( + self.embed_dim, + self.embed_dim, + quant_config=quant_config, + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + bsz, tgt_len, _ = hidden_states.size() + + qkv_states, _ = self.qkv(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + query_states = query_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + out = out.view(bsz, tgt_len, -1) + attn_output, _ = self.projection(out) + + return attn_output + + class BlipMLP(nn.Module): def __init__(self, @@ -188,7 +261,7 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.self_attn = BlipAttention(config) + self.self_attn = BlipAttention(config, quant_config=quant_config) self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = BlipMLP(config, quant_config=quant_config) @@ -199,7 +272,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states residual = hidden_states diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 0ed46f39cacd9..39f2b2d853a6b 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -714,8 +714,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): use_default_weight_loading = False if "vision" in name: if self.vision_model is not None: - # We only do sharding for language model and - # not vision model for now. + # BlipVisionModel does not need sharding use_default_weight_loading = True else: for (param_name, weight_name, diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 69bb9f6f3afee..ddfec91d6cab2 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -7,12 +7,14 @@ import torch.nn as nn from PIL import Image from transformers import CLIPVisionConfig -from transformers.models.clip.modeling_clip import CLIPAttention +from xformers import ops as xops from vllm.config import ModelConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -160,6 +162,78 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return embeddings +class CLIPAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + quant_config=quant_config, + ) + + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + bsz, tgt_len, _ = hidden_states.size() + + qkv_states, _ = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + + query_states = query_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + out = out.view(bsz, tgt_len, -1) + attn_output, _ = self.out_proj(out) + + return attn_output + + class CLIPMLP(nn.Module): def __init__(self, @@ -192,7 +266,7 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.self_attn = CLIPAttention(config) + self.self_attn = CLIPAttention(config, quant_config=quant_config) self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = CLIPMLP(config, quant_config=quant_config) @@ -204,7 +278,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states residual = hidden_states @@ -304,7 +378,15 @@ def forward(self, pixel_values: Optional[torch.Tensor] = None): def device(self): return next(self.parameters()).device + # (TODO) Add prefix argument for filtering out weights to be loaded + # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] params_dict = dict(self.named_parameters()) layer_count = len(self.vision_model.encoder.layers) @@ -318,7 +400,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if layer_idx >= layer_count: continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 54c933e3e4959..ad5919150cad8 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -10,10 +10,13 @@ import torch.nn as nn import torch.nn.functional as F from transformers import PretrainedConfig +from xformers import ops as xops +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -81,7 +84,11 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: class InternAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: PretrainedConfig): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -94,9 +101,13 @@ def __init__(self, config: PretrainedConfig): f' {self.num_heads}).') self.scale = self.head_dim**-0.5 - self.qkv = nn.Linear(self.embed_dim, - 3 * self.embed_dim, - bias=config.qkv_bias) + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + ) self.qk_normalization = config.qk_normalization @@ -104,25 +115,40 @@ def __init__(self, config: PretrainedConfig): self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps) self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps) - self.proj = nn.Linear(self.embed_dim, self.embed_dim) + self.proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + quant_config=quant_config, + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) def forward(self, x): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, - C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) - - if self.qk_normalization: - B_, H_, N_, D_ = q.shape - q = self.q_norm.forward_native(q.transpose(1, 2).flatten( - -2, -1)).view(B_, N_, H_, D_).transpose(1, 2) - k = self.k_norm.forward_native(k.transpose(1, 2).flatten( - -2, -1)).view(B_, N_, H_, D_).transpose(1, 2) + qkv, _ = self.qkv(x) + q, k, v = qkv.chunk(3, dim=-1) - x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) - x = x.transpose(1, 2).reshape(B, N, C) + q = q.view(B, N, self.num_heads_per_partition, self.head_dim) + k = k.view(B, N, self.num_heads_per_partition, self.head_dim) + v = v.view(B, N, self.num_heads_per_partition, self.head_dim) - x = self.proj(x) + if self.qk_normalization: + B_, N_, H_, D_ = q.shape + q = self.q_norm.forward_native(q.flatten(-2, + -1)).view(B_, N_, H_, D_) + k = self.k_norm.forward_native(k.flatten(-2, + -1)).view(B_, N_, H_, D_) + + x = xops.memory_efficient_attention_forward( + q, + k, + v, + scale=self.scale, + ) + x = x.view(B, N, -1) + + x, _ = self.proj(x) return x @@ -161,7 +187,7 @@ def __init__(self, self.intermediate_size = config.intermediate_size self.norm_type = config.norm_type - self.attn = InternAttention(config) + self.attn = InternAttention(config, quant_config=quant_config) self.mlp = InternMLP(config, quant_config=quant_config) self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 104b89e06fa5f..9b29ff69808a6 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -145,7 +145,6 @@ def __init__(self, self.config = config self.multimodal_config = multimodal_config - # TODO(ywang96): Port over SiglipVisionModel & TP self.vision_tower = SiglipVisionModel(config.vision_config) self.multi_modal_projector = PaliGemmaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, @@ -308,34 +307,27 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if key_to_modify in name: name = name.replace(key_to_modify, new_key) use_default_weight_loading = False - if "vision" in name: - if self.vision_tower is not None: - # We only do sharding for language model and - # not vision model for now. - use_default_weight_loading = True + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break else: - for (param_name, shard_name, - shard_id) in stacked_params_mapping: - if shard_name not in name: - continue - name = name.replace(shard_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # lm_head is not used in vllm as it is tied with - # embed_token. To prevent errors, skip loading - # lm_head.weight. - if "lm_head.weight" in name: - continue - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - use_default_weight_loading = True + # lm_head is not used in vllm as it is tied with + # embed_token. To prevent errors, skip loading + # lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + use_default_weight_loading = True if use_default_weight_loading: param = params_dict[name] diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 2fad3ec3e5651..c449e0fc759a3 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -71,6 +71,23 @@ projection_dim=768) +def _init_img_processor(hf_config: PretrainedConfig): + clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG + layer_idx = hf_config.img_processor.get('layer_idx', -2) + + # Initialize the CLIP only up to the required feature layer + if layer_idx < 0: + num_hidden_layers = clip_config.num_hidden_layers + \ + layer_idx + 1 + else: + num_hidden_layers = layer_idx + 1 + + img_processor = CLIPVisionModel( + clip_config, num_hidden_layers_override=num_hidden_layers) + + return img_processor + + class Phi3VImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: Union[torch.Tensor, List[torch.Tensor]] @@ -139,18 +156,8 @@ def __init__(self, config: PretrainedConfig) -> None: hidden_size = config.n_embd if hasattr( config, 'n_embd') else config.hidden_size - clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG - self.layer_idx = config.img_processor.get('layer_idx', -2) - - # Initialize the CLIP only up to the required feature layer - if self.layer_idx < 0: - num_hidden_layers = clip_config.num_hidden_layers + \ - self.layer_idx + 1 - else: - num_hidden_layers = self.layer_idx + 1 + self.img_processor = _init_img_processor(config) - self.img_processor = CLIPVisionModel( - clip_config, num_hidden_layers_override=num_hidden_layers) image_dim_out = config.img_processor['image_dim_out'] self.num_img_tokens = config.img_processor['num_img_tokens'] @@ -656,23 +663,27 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] + + # TODO(ChristopherCho): This is a temporary fix to load + # the vision weights with CLIPVisionModel.load_weights() + vision_weights = [] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - # post_layernorm is not needed in CLIPVisionModel - if "vision_model.post_layernorm" in name: + # Skip loading the img_processor weights since they are + # loaded separately. + if "vision_embed_tokens.img_processor" in name: + vision_weights.append((name, loaded_weight)) continue + for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): if key_to_modify in name: name = name.replace(key_to_modify, new_key) for (param_name, weight_name, shard_id) in stacked_params_mapping: - # We only do sharding for language model - # and not vision model for now. - if "vision_embed_tokens" in name and self.vision_embed_tokens: - continue if weight_name not in name: continue + param = params_dict[name.replace(weight_name, param_name)] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -686,3 +697,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + # We use regex to extract the sub-module name + # from "model.vision_embed_tokens.img_processor.*" + vision_weights = [ + (re.search(r"vision_embed_tokens\.img_processor\.(.*)", + n).group(1), w) for n, w in vision_weights + ] + self.vision_embed_tokens.img_processor.load_weights(vision_weights) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 073f60bb3a056..e6f95af0ff49f 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -9,12 +9,10 @@ from PIL import Image from torch import nn from transformers import SiglipVisionConfig -from transformers.models.siglip.modeling_siglip import SiglipAttention -from vllm_flash_attn import flash_attn_func -from xformers.ops import memory_efficient_attention +from xformers import ops as xops from vllm.config import ModelConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -221,9 +219,7 @@ def forward(self, return embeddings -# NOTE: Not used - kept for later when we TP the ViT -# TODO(ChristopherCho): Implement TP version of Attention -class SiglipTPAttention(nn.Module): +class SiglipAttention(nn.Module): def __init__( self, @@ -233,38 +229,30 @@ def __init__( super().__init__() self.config = config self.embed_dim = config.hidden_size - - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = config.num_attention_heads - if self.total_num_heads % tp_size != 0: - raise ValueError( - f"Number of attention heads ({self.total_num_heads}) " - "must be divisible by the tensor model parallel size" - f" ({tp_size}).") - - self.num_heads = self.total_num_heads // tp_size - self.head_dim = self.embed_dim // self.total_num_heads - if self.head_dim * self.total_num_heads != self.embed_dim: + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: raise ValueError(f"embed_dim must be divisible by num_heads (got " "`embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).") - self.qkv_size = self.num_heads * self.head_dim + self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout - self.qkv_proj = QKVParallelLinear( hidden_size=self.embed_dim, head_size=self.head_dim, - total_num_heads=self.total_num_heads, + total_num_heads=self.num_heads, quant_config=quant_config, ) + self.out_proj = RowParallelLinear( input_size=self.embed_dim, output_size=self.embed_dim, quant_config=quant_config, ) - self.attn_fn = self._basic_attention_forward + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) def forward( self, @@ -274,163 +262,29 @@ def forward( batch_size, q_len, _ = hidden_states.size() qkv_states, _ = self.qkv_proj(hidden_states) - query_states, key_states, value_states = qkv_states.split( - [self.qkv_size] * 3, dim=-1) - - attn_output = self.attn_fn( - q=query_states, - k=key_states, - v=value_states, - batch_size=batch_size, - q_len=q_len, - ) - - attn_output, _ = self.out_proj(attn_output) - return attn_output - - def _basic_attention_forward(self, q, k, v, batch_size, q_len): - q = q.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - k = k.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - v = v.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - - k_v_seq_len = k.shape[-2] - attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale - - if attn_weights.size() != ( - batch_size, - self.num_heads, - q_len, - k_v_seq_len, - ): - raise ValueError( - "Attention weights should be of size " - f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" - f" {attn_weights.size()}") - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, - dim=-1, - dtype=torch.float32).to(q.dtype) - attn_weights = nn.functional.dropout(attn_weights, - p=self.dropout, - training=self.training) - attn_output = torch.matmul(attn_weights, v) - - if attn_output.size() != ( - batch_size, - self.num_heads, - q_len, - self.head_dim, - ): - raise ValueError( - "`attn_output` should be of size " - f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}") - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) - - return attn_output - - -# NOTE: Not used - kept for later when we TP the ViT -# TODO(ChristopherCho): flash_attn_func is not working properly. -# It constantly throws a CUDA error. -class SiglipFlashAttention2(SiglipTPAttention): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.attn_fn = self._flash_attention_forward - - # Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449 - # and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133 - def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args, - **kwargs): - """Implements the multihead softmax attention. - Arguments - --------- - q, k, v: The tensor containing the - query, key, and value. (B, S, H, D) - """ - - q = q.view(batch_size, q_len, self.num_heads, self.head_dim) - k = k.view(batch_size, q_len, self.num_heads, self.head_dim) - v = v.view(batch_size, q_len, self.num_heads, self.head_dim) - - attn_output = flash_attn_func( - q, - k, - v, - dropout_p=self.dropout, - causal=False, - ) - - attn_output = attn_output.reshape(batch_size, q_len, - self.embed_dim).contiguous() + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + + query_states = query_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + out = out.view(batch_size, q_len, -1) + attn_output, _ = self.out_proj(out) return attn_output -# NOTE: Not used - kept for later when we TP the ViT -class SiglipSdpaAttention(SiglipTPAttention): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_causal = False - self.attn_fn = self._sdpa_attention_forward - - def _sdpa_attention_forward(self, q, k, v, batch_size, q_len): - q = q.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - k = k.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - v = v.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - - attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(batch_size, q_len, self.embed_dim) - - return attn_output - - -# NOTE: Not used - kept for later when we TP the ViT -class SiglipxFormersAttention(SiglipTPAttention): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.attn_fn = self._xformers_attention_forward - - def _xformers_attention_forward(self, q, k, v, batch_size, q_len): - q = q.view(batch_size, q_len, self.num_heads, self.head_dim) - k = k.view(batch_size, q_len, self.num_heads, self.head_dim) - v = v.view(batch_size, q_len, self.num_heads, self.head_dim) - - attn_output = memory_efficient_attention(q, - k, - v, - p=0.0, - scale=self.scale) - attn_output = attn_output.reshape(batch_size, q_len, - self.embed_dim).contiguous() - - return attn_output - - -# NOTE: Not used - kept for later when we TP the ViT -SIGLIP_ATTENTION_CLASSES = { - "eager": SiglipTPAttention, - "flash_attention_2": SiglipFlashAttention2, - "sdpa": SiglipSdpaAttention, - "xformers": SiglipxFormersAttention, -} - - class SiglipMLP(nn.Module): def __init__( @@ -473,8 +327,7 @@ def __init__( super().__init__() self.embed_dim = config.hidden_size - # TODO(ChristopherCho): use TP'ed Attention block - self.self_attn = SiglipAttention(config) + self.self_attn = SiglipAttention(config, quant_config=quant_config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( @@ -491,7 +344,7 @@ def forward( residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states residual = hidden_states