diff --git a/optimum/habana/transformers/modeling_attn_mask_utils.py b/optimum/habana/transformers/modeling_attn_mask_utils.py index 859292c0a4..4fe6217099 100755 --- a/optimum/habana/transformers/modeling_attn_mask_utils.py +++ b/optimum/habana/transformers/modeling_attn_mask_utils.py @@ -104,91 +104,3 @@ def _gaudi_prepare_4d_causal_attention_mask( ) return attention_mask - - -def _gaudi_prepare_4d_causal_attention_mask_for_sdpa( - attention_mask: Optional[torch.Tensor], - input_shape: Union[torch.Size, Tuple, List], - inputs_embeds: torch.Tensor, - past_key_values_length: int, - sliding_window: Optional[int] = None, -): - """ - Adapted from: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/modeling_attn_mask_utils.py#L331 - - Differences: - - `torch.all(attention_mask == 1)` was removed here: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/modeling_attn_mask_utils.py#L371 - for performance reasons - """ - attn_mask_converter = GaudiAttentionMaskConverter(is_causal=True, sliding_window=sliding_window) - - key_value_length = input_shape[-1] + past_key_values_length - batch_size, query_length = input_shape - - # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` - # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. - # TODO: Fix this as well when using torchdynamo with fullgraph=True. - is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) - - if attention_mask is not None: - # 4d mask is passed through - if len(attention_mask.shape) == 4: - expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) - if tuple(attention_mask.shape) != expected_shape: - raise ValueError( - f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." - ) - else: - # if the 4D mask has correct shape - invert it and fill with negative infinity - inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype) - attention_mask = inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min - ) - return attention_mask - - elif not is_tracing: - if query_length == 1: - # For query_length == 1, causal attention and bi-directional attention are the same. - attention_mask = None - elif key_value_length == query_length: - attention_mask = None - else: - # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation - # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. - # Reference: https://github.com/pytorch/pytorch/issues/108108 - pass - elif query_length > 1 and key_value_length != query_length: - # See the comment above (https://github.com/pytorch/pytorch/issues/108108). - # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`. - attention_mask = True - elif is_tracing: - raise ValueError( - 'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.' - ) - - if attention_mask is None: - expanded_4d_mask = None - elif attention_mask is True: - expanded_4d_mask = attn_mask_converter.to_causal_4d( - input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) - else: - expanded_4d_mask = attn_mask_converter.to_4d( - attention_mask, - input_shape[-1], - dtype=inputs_embeds.dtype, - key_value_length=key_value_length, - ) - - # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend - # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 - # - # This fix is not applied in case we are tracing with torch.jit.trace or symbolic_trace, as _unmask_unattended has a data-dependent - # controlflow that can not be captured properly. - # TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case. - if query_length > 1 and not is_tracing: - expanded_4d_mask = GaudiAttentionMaskConverter._unmask_unattended( - expanded_4d_mask, attention_mask, unmasked_value=0.0 - ) - - return expanded_4d_mask diff --git a/optimum/habana/transformers/models/bart/modeling_bart.py b/optimum/habana/transformers/models/bart/modeling_bart.py index 14c80f98b9..f551fe0641 100644 --- a/optimum/habana/transformers/models/bart/modeling_bart.py +++ b/optimum/habana/transformers/models/bart/modeling_bart.py @@ -23,6 +23,7 @@ from transformers.modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, ) from transformers.modeling_outputs import ( BaseModelOutput, @@ -35,7 +36,6 @@ from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, - _gaudi_prepare_4d_causal_attention_mask_for_sdpa, ) @@ -465,7 +465,7 @@ def gaudi_BartDecoder_forward( if self._use_sdpa and not output_attentions and cross_attn_head_mask is None: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _gaudi_prepare_4d_causal_attention_mask_for_sdpa( + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, input_shape, inputs_embeds, diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index dbbe3c364c..9c853dfb2a 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -29,6 +29,7 @@ import habana_frameworks.torch.core as htcore from torch.nn import CrossEntropyLoss from torch.nn import functional as F +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -45,7 +46,6 @@ from ...modeling_attn_mask_utils import ( GaudiAttentionMaskConverter, _gaudi_prepare_4d_causal_attention_mask, - _gaudi_prepare_4d_causal_attention_mask_for_sdpa, ) @@ -446,11 +446,14 @@ def forward( ) position_ids = position_ids.unsqueeze(0) - if self._use_sdpa and not output_attentions: + # TODO: Due to perf degradation, disable spda_attn_mask + use_sdpa_attn_mask = False + + if self._use_sdpa and not output_attentions and use_sdpa_attn_mask: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. if alibi is None: - attention_mask = _gaudi_prepare_4d_causal_attention_mask_for_sdpa( + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), inputs_embeds, diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 0645bfb9dc..2dfae57b6f 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -5,6 +5,7 @@ import torch import torch.nn.functional as F from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import ( @@ -20,7 +21,6 @@ from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, - _gaudi_prepare_4d_causal_attention_mask_for_sdpa, ) @@ -629,7 +629,7 @@ def forward( if self._use_sdpa and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _gaudi_prepare_4d_causal_attention_mask_for_sdpa( + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), inputs_embeds, diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 4e110e8f98..c1802c7b71 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -27,13 +27,13 @@ from torch import nn from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.mistral.modeling_mistral import MistralForCausalLM, apply_rotary_pos_emb, repeat_kv from transformers.utils import logging from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, - _gaudi_prepare_4d_causal_attention_mask_for_sdpa, ) @@ -266,7 +266,7 @@ def gaudi_mistral_model_forward( if self._attn_implementation == "sdpa" and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _gaudi_prepare_4d_causal_attention_mask_for_sdpa( + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), inputs_embeds, diff --git a/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py index adf566979c..6bb188156a 100644 --- a/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -340,6 +340,10 @@ def check_ctc_loss(self, config, input_values, *args): input_values = input_values[:3] attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long) + # TODO: due to limitation of index op, add mark_step + if torch_device == "hpu": + import habana_frameworks.torch.core as htcore + htcore.mark_step() input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))