diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 937dc7ec23..039ad42918 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -101,7 +101,10 @@ GaudiMllamaTextCrossAttention, GaudiMllamaTextModel, GaudiMllamaTextSelfAttention, + GaudiMllamaVisionEncoder, + GaudiMllamaVisionEncoderLayer, GaudiMllamaVisionModel, + GaudiMllamaVisionSdpaAttention, GaudiMptAttention, GaudiMptBlock, GaudiMptForCausalLM, @@ -661,6 +664,9 @@ def adapt_transformers_to_gaudi(): transformers.models.mllama.modeling_mllama.MllamaForConditionalGeneration = GaudiMllamaForConditionalGeneration transformers.models.mllama.modeling_mllama.MllamaTextModel = GaudiMllamaTextModel transformers.models.mllama.modeling_mllama.MllamaVisionModel = GaudiMllamaVisionModel + transformers.models.mllama.modeling_mllama.MllamaVisionEncoder = GaudiMllamaVisionEncoder + transformers.models.mllama.modeling_mllama.MllamaVisionEncoderLayer = GaudiMllamaVisionEncoderLayer + transformers.models.mllama.modeling_mllama.MllamaVisionSdpaAttention = GaudiMllamaVisionSdpaAttention transformers.AutoConfig.register("deci", DeciLMConfig) transformers.AutoModelForCausalLM.register(DeciLMConfig, DeciLMForCausalLM) diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 5e8ffb0b07..ba19be3a56 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -166,7 +166,10 @@ GaudiMllamaTextCrossAttention, GaudiMllamaTextModel, GaudiMllamaTextSelfAttention, + GaudiMllamaVisionEncoder, + GaudiMllamaVisionEncoderLayer, GaudiMllamaVisionModel, + GaudiMllamaVisionSdpaAttention, ) from .modeling_all_models import ( gaudi_check_and_enable_sdpa, diff --git a/optimum/habana/transformers/models/mllama/__init__.py b/optimum/habana/transformers/models/mllama/__init__.py index 198f1cc2aa..7ff31e3220 100644 --- a/optimum/habana/transformers/models/mllama/__init__.py +++ b/optimum/habana/transformers/models/mllama/__init__.py @@ -6,5 +6,8 @@ GaudiMllamaTextCrossAttention, GaudiMllamaTextModel, GaudiMllamaTextSelfAttention, + GaudiMllamaVisionEncoder, + GaudiMllamaVisionEncoderLayer, GaudiMllamaVisionModel, + GaudiMllamaVisionSdpaAttention, ) diff --git a/optimum/habana/transformers/models/mllama/modeling_mllama.py b/optimum/habana/transformers/models/mllama/modeling_mllama.py index e5c7ced0d4..7e73868249 100644 --- a/optimum/habana/transformers/models/mllama/modeling_mllama.py +++ b/optimum/habana/transformers/models/mllama/modeling_mllama.py @@ -18,6 +18,7 @@ import os from typing import List, Optional, Tuple, Union +import habana_frameworks.torch.core as htcore import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -35,6 +36,10 @@ MllamaTextCrossAttention, MllamaTextModel, MllamaTextSelfAttention, + MllamaVisionAttention, + MllamaVisionConfig, + MllamaVisionEncoder, + MllamaVisionEncoderLayer, MllamaVisionModel, _prepare_4d_causal_attention_mask_with_cache_position, _prepare_aspect_ratio_attention_mask, @@ -107,6 +112,163 @@ def _prepare_cross_attention_mask( return cross_attention_mask, full_text_row_masked_out_mask +class GaudiMllamaVisionSdpaAttention(MllamaVisionAttention): + def __init__(self, config: MllamaVisionConfig): + super().__init__(config) + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None + + # Adapted from MllamaVisionAttention + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = None, + use_flash_attention: Optional[bool] = False, + ) -> torch.Tensor: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + """ + Copied from MllamaVisionSdpaAttention::forward:https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L283 + The only differences are: + - add use_flash_attention + """ + if output_attentions: + logger.warning_once( + "MllamaModel is using MllamaVisionSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_state=hidden_state, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + query = self.q_proj(hidden_state) + key = self.k_proj(hidden_state) + value = self.v_proj(hidden_state) + + batch_size, q_seq_len, _ = query.shape + _, kv_seq_len, _ = key.shape + + query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim) + key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) + value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + if use_flash_attention and FusedSDPA: + attn_output = self.fused_scaled_dot_product_attention(query, key, value, attention_mask, 0.0, False, None) + else: + attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_seq_len, -1) + + output = self.o_proj(attn_output) + + return output, None + + +class GaudiMllamaVisionEncoderLayer(MllamaVisionEncoderLayer): + def __init__(self, config: MllamaVisionConfig, is_gated: bool = False): + super(GaudiMllamaVisionEncoderLayer, self).__init__(config=config, is_gated=is_gated) + self.self_attn = GaudiMllamaVisionSdpaAttention(config) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = None, + use_flash_attention: Optional[bool] = False, + ): + """ + Copied from MllamaVisionEncoderLayer::forward:https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L348 + The only differences are: + - add use_flash_attention + """ + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state, attn_weights = self.self_attn( + hidden_state, attention_mask=attention_mask, use_flash_attention=use_flash_attention + ) + if self.is_gated: + hidden_state = self.gate_attn.tanh() * hidden_state + hidden_state = residual + hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + if self.is_gated: + hidden_state = self.gate_ffn.tanh() * hidden_state + hidden_state = residual + hidden_state + + outputs = (hidden_state,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class GaudiMllamaVisionEncoder(MllamaVisionEncoder): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_flash_attention: Optional[bool] = False, + ) -> Union[Tuple, BaseModelOutput]: + """ + Copied from MllamaVisionEncoder::forward:https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L394 + The only differences are: + - add use_flash_attention + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_state=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + use_flash_attention=use_flash_attention, + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + htcore.mark_step() + hidden_states = layer_outputs[0] + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + class GaudiMllamaTextCrossAttention(MllamaTextCrossAttention): def __init__(self, config: Optional[MllamaTextConfig] = None, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) @@ -842,6 +1004,7 @@ def forward( output_hidden_states=output_hidden_states, output_attentions=output_attentions, return_dict=return_dict, + use_flash_attention=use_flash_attention, ) cross_attention_states = vision_outputs[0] cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( @@ -1020,6 +1183,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + use_flash_attention: Optional[bool] = False, ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: """ Copied from MllamaVisionModel::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1425 @@ -1081,6 +1245,7 @@ def forward( attention_mask=attention_mask, output_hidden_states=True, output_attentions=output_attentions, + use_flash_attention=use_flash_attention, ) hidden_state = output[0] @@ -1099,6 +1264,7 @@ def forward( attention_mask=attention_mask, output_hidden_states=output_hidden_states, output_attentions=output_attentions, + use_flash_attention=use_flash_attention, ) hidden_state = global_output[0]