Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@
GaudiMllamaTextCrossAttention,
GaudiMllamaTextModel,
GaudiMllamaTextSelfAttention,
GaudiMllamaVisionEncoder,
GaudiMllamaVisionEncoderLayer,
GaudiMllamaVisionModel,
GaudiMllamaVisionSdpaAttention,
GaudiMptAttention,
GaudiMptBlock,
GaudiMptForCausalLM,
Expand Down Expand Up @@ -661,6 +664,9 @@ def adapt_transformers_to_gaudi():
transformers.models.mllama.modeling_mllama.MllamaForConditionalGeneration = GaudiMllamaForConditionalGeneration
transformers.models.mllama.modeling_mllama.MllamaTextModel = GaudiMllamaTextModel
transformers.models.mllama.modeling_mllama.MllamaVisionModel = GaudiMllamaVisionModel
transformers.models.mllama.modeling_mllama.MllamaVisionEncoder = GaudiMllamaVisionEncoder
transformers.models.mllama.modeling_mllama.MllamaVisionEncoderLayer = GaudiMllamaVisionEncoderLayer
transformers.models.mllama.modeling_mllama.MllamaVisionSdpaAttention = GaudiMllamaVisionSdpaAttention

transformers.AutoConfig.register("deci", DeciLMConfig)
transformers.AutoModelForCausalLM.register(DeciLMConfig, DeciLMForCausalLM)
Expand Down
3 changes: 3 additions & 0 deletions optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,10 @@
GaudiMllamaTextCrossAttention,
GaudiMllamaTextModel,
GaudiMllamaTextSelfAttention,
GaudiMllamaVisionEncoder,
GaudiMllamaVisionEncoderLayer,
GaudiMllamaVisionModel,
GaudiMllamaVisionSdpaAttention,
)
from .modeling_all_models import (
gaudi_check_and_enable_sdpa,
Expand Down
3 changes: 3 additions & 0 deletions optimum/habana/transformers/models/mllama/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,8 @@
GaudiMllamaTextCrossAttention,
GaudiMllamaTextModel,
GaudiMllamaTextSelfAttention,
GaudiMllamaVisionEncoder,
GaudiMllamaVisionEncoderLayer,
GaudiMllamaVisionModel,
GaudiMllamaVisionSdpaAttention,
)
166 changes: 166 additions & 0 deletions optimum/habana/transformers/models/mllama/modeling_mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
from typing import List, Optional, Tuple, Union

import habana_frameworks.torch.core as htcore
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
Expand All @@ -35,6 +36,10 @@
MllamaTextCrossAttention,
MllamaTextModel,
MllamaTextSelfAttention,
MllamaVisionAttention,
MllamaVisionConfig,
MllamaVisionEncoder,
MllamaVisionEncoderLayer,
MllamaVisionModel,
_prepare_4d_causal_attention_mask_with_cache_position,
_prepare_aspect_ratio_attention_mask,
Expand Down Expand Up @@ -107,6 +112,163 @@ def _prepare_cross_attention_mask(
return cross_attention_mask, full_text_row_masked_out_mask


class GaudiMllamaVisionSdpaAttention(MllamaVisionAttention):
    """Mllama vision self-attention with an optional fused SDPA path.

    Behaves like the SDPA variant of the upstream vision attention, with one
    extra ``use_flash_attention`` switch on ``forward`` that selects the
    ``ModuleFusedSDPA`` wrapper when the fused kernel is available.
    """

    def __init__(self, config: MllamaVisionConfig):
        super().__init__(config)
        # Only wrap the fused kernel when it could be imported; otherwise keep
        # an explicit ``None`` so the attribute always exists.
        if FusedSDPA:
            self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA)
        else:
            self.fused_scaled_dot_product_attention = None

    # Adapted from MllamaVisionAttention
    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = None,
        use_flash_attention: Optional[bool] = False,
    ) -> torch.Tensor:
        """
        Copied from MllamaVisionSdpaAttention::forward:https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L283
        The only differences are:
        - add use_flash_attention

        Args:
            hidden_state: input fed to the q/k/v projections.
            attention_mask: optional mask handed unchanged to the attention kernel.
            output_attentions: when truthy, the call is delegated to the parent
                class' ``forward`` after logging a one-time warning.
            use_flash_attention: when truthy and ``FusedSDPA`` is available, the
                fused kernel wrapper is used instead of
                ``F.scaled_dot_product_attention``.

        Returns:
            On the SDPA paths, a 2-tuple ``(output, None)``; on the
            ``output_attentions`` fallback, whatever the parent ``forward`` returns.
        """
        # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
        if output_attentions:
            logger.warning_once(
                "MllamaModel is using MllamaVisionSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_state=hidden_state,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
            )

        def _split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            bsz, seq_len, _ = projected.shape
            return projected.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q_projected = self.q_proj(hidden_state)
        k_projected = self.k_proj(hidden_state)
        v_projected = self.v_proj(hidden_state)

        # Remember the query's leading dims for the final merge of the heads.
        bsz, q_len, _ = q_projected.shape

        q_heads = _split_heads(q_projected)
        k_heads = _split_heads(k_projected)
        v_heads = _split_heads(v_projected)

        if use_flash_attention and FusedSDPA:
            # Positional arguments after the mask are passed exactly as before:
            # 0.0, False, None.
            context = self.fused_scaled_dot_product_attention(
                q_heads, k_heads, v_heads, attention_mask, 0.0, False, None
            )
        else:
            context = F.scaled_dot_product_attention(q_heads, k_heads, v_heads, attn_mask=attention_mask)

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        merged = context.transpose(1, 2).contiguous().reshape(bsz, q_len, -1)

        # Attention weights are not available on the SDPA paths.
        return self.o_proj(merged), None


class GaudiMllamaVisionEncoderLayer(MllamaVisionEncoderLayer):
    """Mllama vision encoder layer whose self-attention is the Gaudi SDPA variant.

    The parent constructor builds the layer; ``self_attn`` is then replaced by a
    ``GaudiMllamaVisionSdpaAttention`` so ``forward`` can pass ``use_flash_attention``.
    """

    def __init__(self, config: MllamaVisionConfig, is_gated: bool = False):
        super().__init__(config=config, is_gated=is_gated)
        self.self_attn = GaudiMllamaVisionSdpaAttention(config)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        use_flash_attention: Optional[bool] = False,
    ):
        """
        Copied from MllamaVisionEncoderLayer::forward:https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L348
        The only differences are:
        - add use_flash_attention
        - forward output_attentions to the self-attention module

        Args:
            hidden_state: layer input.
            attention_mask: optional mask passed to the self-attention module.
            output_attentions: when truthy, the attention weights returned by the
                self-attention module are appended to the output tuple.
            use_flash_attention: passed through to the self-attention module.

        Returns:
            ``(hidden_state,)``, or ``(hidden_state, attn_weights)`` when
            ``output_attentions`` is truthy.
        """
        # Self Attention
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        # `output_attentions` must reach the attention module: without it the
        # module never computes weights and `attn_weights` is always None, even
        # though the caller explicitly asked for them.
        hidden_state, attn_weights = self.self_attn(
            hidden_state,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            use_flash_attention=use_flash_attention,
        )
        if self.is_gated:
            hidden_state = self.gate_attn.tanh() * hidden_state
        hidden_state = residual + hidden_state

        # Feed forward
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        if self.is_gated:
            hidden_state = self.gate_ffn.tanh() * hidden_state
        hidden_state = residual + hidden_state

        outputs = (hidden_state,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class GaudiMllamaVisionEncoder(MllamaVisionEncoder):
    """Mllama vision encoder that threads ``use_flash_attention`` to every layer."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_flash_attention: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutput]:
        """
        Copied from MllamaVisionEncoder::forward:https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L394
        The only differences are:
        - add use_flash_attention

        Args:
            hidden_states: input to the first encoder layer.
            attention_mask: optional mask passed unchanged to every layer.
            output_attentions: collect per-layer attention outputs; defaults to
                ``self.config.output_attentions`` when ``None``.
            output_hidden_states: collect the input of every layer plus the final
                output; defaults to ``self.config.output_hidden_states`` when ``None``.
            return_dict: return a ``BaseModelOutput`` instead of a tuple; defaults
                to ``self.config.use_return_dict`` when ``None``.
            use_flash_attention: passed to every encoder layer, on both the
                regular and the gradient-checkpointing path.

        Returns:
            A ``BaseModelOutput``, or a tuple of the non-``None`` values among
            ``(hidden_states, encoder_states, all_attentions)`` when
            ``return_dict`` is falsy.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Arguments are positional here, so the order must match the
                # layer's forward signature. `use_flash_attention` is included so
                # the checkpointed path uses the same attention kernel as the
                # regular path instead of silently falling back to the default.
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    output_attentions,
                    use_flash_attention,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_state=hidden_states,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                    use_flash_attention=use_flash_attention,
                )

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
            # NOTE(review): called once per layer, presumably to cut the
            # lazy-mode graph at layer boundaries; confirm against the Habana docs.
            htcore.mark_step()
            hidden_states = layer_outputs[0]

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


class GaudiMllamaTextCrossAttention(MllamaTextCrossAttention):
def __init__(self, config: Optional[MllamaTextConfig] = None, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
Expand Down Expand Up @@ -842,6 +1004,7 @@ def forward(
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
use_flash_attention=use_flash_attention,
)
cross_attention_states = vision_outputs[0]
cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
Expand Down Expand Up @@ -1020,6 +1183,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_flash_attention: Optional[bool] = False,
) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
"""
Copied from MllamaVisionModel::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1425
Expand Down Expand Up @@ -1081,6 +1245,7 @@ def forward(
attention_mask=attention_mask,
output_hidden_states=True,
output_attentions=output_attentions,
use_flash_attention=use_flash_attention,
)
hidden_state = output[0]

Expand All @@ -1099,6 +1264,7 @@ def forward(
attention_mask=attention_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
use_flash_attention=use_flash_attention,
)
hidden_state = global_output[0]

Expand Down