From 7c2b03a160162a90e829f29b6ba2ee0fc2e4c977 Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Fri, 9 Feb 2024 17:17:20 +0000 Subject: [PATCH 1/4] As torch.all inside _prepare_4d_causal_attention_mask_for_sdpa causes d2h copy and impact falcon performance, we disable the falcon model to use _prepare_4d_causal_attention_mask_for_sdpa. Instead passing normal constructed 4D to sdpa. --- optimum/habana/transformers/modeling_attn_mask_utils.py | 6 +++--- .../habana/transformers/models/falcon/modeling_falcon.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/optimum/habana/transformers/modeling_attn_mask_utils.py b/optimum/habana/transformers/modeling_attn_mask_utils.py index 859292c0a4..016b557d5b 100755 --- a/optimum/habana/transformers/modeling_attn_mask_utils.py +++ b/optimum/habana/transformers/modeling_attn_mask_utils.py @@ -117,8 +117,8 @@ def _gaudi_prepare_4d_causal_attention_mask_for_sdpa( Adapted from: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/modeling_attn_mask_utils.py#L331 Differences: - - `torch.all(attention_mask == 1)` was removed here: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/modeling_attn_mask_utils.py#L371 - for performance reasons + - No difference with : https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/modeling_attn_mask_utils.py#L331 + keep for potential performance improvement """ attn_mask_converter = GaudiAttentionMaskConverter(is_causal=True, sliding_window=sliding_window) @@ -146,7 +146,7 @@ def _gaudi_prepare_4d_causal_attention_mask_for_sdpa( ) return attention_mask - elif not is_tracing: + elif not is_tracing and torch.all(attention_mask == 1): if query_length == 1: # For query_length == 1, causal attention and bi-directional attention are the same. attention_mask = None diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index dbbe3c364c..2d015d9f29 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -449,7 +449,8 @@ def forward( if self._use_sdpa and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. - if alibi is None: + use_sdpa_attn_mask = False + if alibi is None and use_sdpa_attn_mask: attention_mask = _gaudi_prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), From 059afed28c07db178099e2fc3a90c9bf61a92b80 Mon Sep 17 00:00:00 2001 From: Jimin Ha Date: Fri, 9 Feb 2024 10:07:23 -0800 Subject: [PATCH 2/4] Remove gaudi attention_mask_for_sdpa --- .../transformers/modeling_attn_mask_utils.py | 88 ------------------- .../transformers/models/bart/modeling_bart.py | 4 +- .../models/falcon/modeling_falcon.py | 10 ++- .../models/llama/modeling_llama.py | 4 +- .../models/mistral/modeling_mistral.py | 4 +- 5 files changed, 12 insertions(+), 98 deletions(-) diff --git a/optimum/habana/transformers/modeling_attn_mask_utils.py b/optimum/habana/transformers/modeling_attn_mask_utils.py index 016b557d5b..4fe6217099 100755 --- a/optimum/habana/transformers/modeling_attn_mask_utils.py +++ b/optimum/habana/transformers/modeling_attn_mask_utils.py @@ -104,91 +104,3 @@ def _gaudi_prepare_4d_causal_attention_mask( ) return attention_mask - - -def _gaudi_prepare_4d_causal_attention_mask_for_sdpa( - attention_mask: Optional[torch.Tensor], - input_shape: Union[torch.Size, Tuple, List], - inputs_embeds: torch.Tensor, - past_key_values_length: int, - sliding_window: Optional[int] = None, -): - """ - Adapted from: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/modeling_attn_mask_utils.py#L331 - - Differences: - - No difference with : https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/modeling_attn_mask_utils.py#L331 - keep for potential performance improvement - """ - attn_mask_converter = GaudiAttentionMaskConverter(is_causal=True, sliding_window=sliding_window) - - key_value_length = input_shape[-1] + past_key_values_length - batch_size, query_length = input_shape - - # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` - # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. - # TODO: Fix this as well when using torchdynamo with fullgraph=True. - is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) - - if attention_mask is not None: - # 4d mask is passed through - if len(attention_mask.shape) == 4: - expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) - if tuple(attention_mask.shape) != expected_shape: - raise ValueError( - f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." - ) - else: - # if the 4D mask has correct shape - invert it and fill with negative infinity - inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype) - attention_mask = inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min - ) - return attention_mask - - elif not is_tracing and torch.all(attention_mask == 1): - if query_length == 1: - # For query_length == 1, causal attention and bi-directional attention are the same. - attention_mask = None - elif key_value_length == query_length: - attention_mask = None - else: - # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation - # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. - # Reference: https://github.com/pytorch/pytorch/issues/108108 - pass - elif query_length > 1 and key_value_length != query_length: - # See the comment above (https://github.com/pytorch/pytorch/issues/108108). - # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`. - attention_mask = True - elif is_tracing: - raise ValueError( - 'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.' - ) - - if attention_mask is None: - expanded_4d_mask = None - elif attention_mask is True: - expanded_4d_mask = attn_mask_converter.to_causal_4d( - input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) - else: - expanded_4d_mask = attn_mask_converter.to_4d( - attention_mask, - input_shape[-1], - dtype=inputs_embeds.dtype, - key_value_length=key_value_length, - ) - - # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend - # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 - # - # This fix is not applied in case we are tracing with torch.jit.trace or symbolic_trace, as _unmask_unattended has a data-dependent - # controlflow that can not be captured properly. - # TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case. - if query_length > 1 and not is_tracing: - expanded_4d_mask = GaudiAttentionMaskConverter._unmask_unattended( - expanded_4d_mask, attention_mask, unmasked_value=0.0 - ) - - return expanded_4d_mask diff --git a/optimum/habana/transformers/models/bart/modeling_bart.py b/optimum/habana/transformers/models/bart/modeling_bart.py index 14c80f98b9..0b4f8de563 100644 --- a/optimum/habana/transformers/models/bart/modeling_bart.py +++ b/optimum/habana/transformers/models/bart/modeling_bart.py @@ -35,7 +35,7 @@ from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, - _gaudi_prepare_4d_causal_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, ) @@ -465,7 +465,7 @@ def gaudi_BartDecoder_forward( if self._use_sdpa and not output_attentions and cross_attn_head_mask is None: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _gaudi_prepare_4d_causal_attention_mask_for_sdpa( + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, input_shape, inputs_embeds, diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index 2d015d9f29..31d9b588b5 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -45,7 +45,7 @@ from ...modeling_attn_mask_utils import ( GaudiAttentionMaskConverter, _gaudi_prepare_4d_causal_attention_mask, - _gaudi_prepare_4d_causal_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, ) @@ -446,12 +446,14 @@ def forward( ) position_ids = position_ids.unsqueeze(0) - if self._use_sdpa and not output_attentions: + # TODO: Due to perf degradation, disable spda_attn_mask + use_sdpa_attn_mask = False + + if self._use_sdpa and not output_attentions and use_sdpa_attn_mask: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. - use_sdpa_attn_mask = False if alibi is None and use_sdpa_attn_mask: - attention_mask = _gaudi_prepare_4d_causal_attention_mask_for_sdpa( + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), inputs_embeds, diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 0645bfb9dc..e981323a1d 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -20,7 +20,7 @@ from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, - _gaudi_prepare_4d_causal_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, ) @@ -629,7 +629,7 @@ def forward( if self._use_sdpa and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _gaudi_prepare_4d_causal_attention_mask_for_sdpa( + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), inputs_embeds, diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 4e110e8f98..a5db5cfefb 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -33,7 +33,7 @@ from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, - _gaudi_prepare_4d_causal_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, ) @@ -266,7 +266,7 @@ def gaudi_mistral_model_forward( if self._attn_implementation == "sdpa" and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _gaudi_prepare_4d_causal_attention_mask_for_sdpa( + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), inputs_embeds, From 7e07376897910c6d0b2db97d913a1a3c36b35ec6 Mon Sep 17 00:00:00 2001 From: Jimin Ha Date: Fri, 9 Feb 2024 10:19:38 -0800 Subject: [PATCH 3/4] Fix import error --- optimum/habana/transformers/models/bart/modeling_bart.py | 2 +- optimum/habana/transformers/models/falcon/modeling_falcon.py | 4 ++-- optimum/habana/transformers/models/llama/modeling_llama.py | 2 +- .../habana/transformers/models/mistral/modeling_mistral.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/optimum/habana/transformers/models/bart/modeling_bart.py b/optimum/habana/transformers/models/bart/modeling_bart.py index 0b4f8de563..f551fe0641 100644 --- a/optimum/habana/transformers/models/bart/modeling_bart.py +++ b/optimum/habana/transformers/models/bart/modeling_bart.py @@ -23,6 +23,7 @@ from transformers.modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, ) from transformers.modeling_outputs import ( BaseModelOutput, @@ -35,7 +36,6 @@ from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, ) diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index 31d9b588b5..9c853dfb2a 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -29,6 +29,7 @@ import habana_frameworks.torch.core as htcore from torch.nn import CrossEntropyLoss from torch.nn import functional as F +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -45,7 +46,6 @@ from ...modeling_attn_mask_utils import ( GaudiAttentionMaskConverter, _gaudi_prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, ) @@ -452,7 +452,7 @@ def forward( if self._use_sdpa and not output_attentions and use_sdpa_attn_mask: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. - if alibi is None and use_sdpa_attn_mask: + if alibi is None: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index e981323a1d..2dfae57b6f 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -5,6 +5,7 @@ import torch import torch.nn.functional as F from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import ( @@ -20,7 +21,6 @@ from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, ) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index a5db5cfefb..c1802c7b71 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -27,13 +27,13 @@ from torch import nn from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.mistral.modeling_mistral import MistralForCausalLM, apply_rotary_pos_emb, repeat_kv from transformers.utils import logging from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, ) From 4ec6d86acaa3a215edd65c70808732ed88831bd2 Mon Sep 17 00:00:00 2001 From: Jimin Ha Date: Fri, 9 Feb 2024 17:35:18 -0800 Subject: [PATCH 4/4] Add mark_step before index op for UT failure --- .../tests/models/wav2vec2/test_modeling_wav2vec2.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py index adf566979c..6bb188156a 100644 --- a/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -340,6 +340,10 @@ def check_ctc_loss(self, config, input_values, *args): input_values = input_values[:3] attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long) + # TODO: due to limitation of index op, add mark_step + if torch_device == "hpu": + import habana_frameworks.torch.core as htcore + htcore.mark_step() input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))