Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 0 additions & 88 deletions optimum/habana/transformers/modeling_attn_mask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,91 +104,3 @@ def _gaudi_prepare_4d_causal_attention_mask(
)

return attention_mask


def _gaudi_prepare_4d_causal_attention_mask_for_sdpa(
attention_mask: Optional[torch.Tensor],
input_shape: Union[torch.Size, Tuple, List],
inputs_embeds: torch.Tensor,
past_key_values_length: int,
sliding_window: Optional[int] = None,
):
"""
Adapted from: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/modeling_attn_mask_utils.py#L331

Differences:
- `torch.all(attention_mask == 1)` was removed here: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/modeling_attn_mask_utils.py#L371
for performance reasons
"""
attn_mask_converter = GaudiAttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

key_value_length = input_shape[-1] + past_key_values_length
batch_size, query_length = input_shape

# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
# TODO: Fix this as well when using torchdynamo with fullgraph=True.
is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy)

if attention_mask is not None:
# 4d mask is passed through
if len(attention_mask.shape) == 4:
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
if tuple(attention_mask.shape) != expected_shape:
raise ValueError(
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
)
else:
# if the 4D mask has correct shape - invert it and fill with negative infinity
inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
attention_mask = inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
)
return attention_mask

elif not is_tracing:
if query_length == 1:
# For query_length == 1, causal attention and bi-directional attention are the same.
attention_mask = None
elif key_value_length == query_length:
attention_mask = None
else:
# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
# Reference: https://github.com/pytorch/pytorch/issues/108108
pass
elif query_length > 1 and key_value_length != query_length:
# See the comment above (https://github.com/pytorch/pytorch/issues/108108).
# Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
attention_mask = True
elif is_tracing:
raise ValueError(
'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
)

if attention_mask is None:
expanded_4d_mask = None
elif attention_mask is True:
expanded_4d_mask = attn_mask_converter.to_causal_4d(
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
else:
expanded_4d_mask = attn_mask_converter.to_4d(
attention_mask,
input_shape[-1],
dtype=inputs_embeds.dtype,
key_value_length=key_value_length,
)

# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
#
# This fix is not applied in case we are tracing with torch.jit.trace or symbolic_trace, as _unmask_unattended has a data-dependent
# controlflow that can not be captured properly.
# TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case.
if query_length > 1 and not is_tracing:
expanded_4d_mask = GaudiAttentionMaskConverter._unmask_unattended(
expanded_4d_mask, attention_mask, unmasked_value=0.0
)

return expanded_4d_mask
4 changes: 2 additions & 2 deletions optimum/habana/transformers/models/bart/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from transformers.modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
BaseModelOutput,
Expand All @@ -35,7 +36,6 @@

from ...modeling_attn_mask_utils import (
_gaudi_prepare_4d_causal_attention_mask,
_gaudi_prepare_4d_causal_attention_mask_for_sdpa,
)


Expand Down Expand Up @@ -465,7 +465,7 @@ def gaudi_BartDecoder_forward(
if self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _gaudi_prepare_4d_causal_attention_mask_for_sdpa(
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
inputs_embeds,
Expand Down
9 changes: 6 additions & 3 deletions optimum/habana/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import habana_frameworks.torch.core as htcore
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
Expand All @@ -45,7 +46,6 @@
from ...modeling_attn_mask_utils import (
GaudiAttentionMaskConverter,
_gaudi_prepare_4d_causal_attention_mask,
_gaudi_prepare_4d_causal_attention_mask_for_sdpa,
)


Expand Down Expand Up @@ -446,11 +446,14 @@ def forward(
)
position_ids = position_ids.unsqueeze(0)

if self._use_sdpa and not output_attentions:
# TODO: Due to perf degradation, disable spda_attn_mask
use_sdpa_attn_mask = False

if self._use_sdpa and not output_attentions and use_sdpa_attn_mask:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
if alibi is None:
attention_mask = _gaudi_prepare_4d_causal_attention_mask_for_sdpa(
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
Expand Down
4 changes: 2 additions & 2 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
import torch.nn.functional as F
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
Expand All @@ -20,7 +21,6 @@

from ...modeling_attn_mask_utils import (
_gaudi_prepare_4d_causal_attention_mask,
_gaudi_prepare_4d_causal_attention_mask_for_sdpa,
)


Expand Down Expand Up @@ -629,7 +629,7 @@ def forward(
if self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _gaudi_prepare_4d_causal_attention_mask_for_sdpa(
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.mistral.modeling_mistral import MistralForCausalLM, apply_rotary_pos_emb, repeat_kv
from transformers.utils import logging

from ...modeling_attn_mask_utils import (
_gaudi_prepare_4d_causal_attention_mask,
_gaudi_prepare_4d_causal_attention_mask_for_sdpa,
)


Expand Down Expand Up @@ -266,7 +266,7 @@ def gaudi_mistral_model_forward(
if self._attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _gaudi_prepare_4d_causal_attention_mask_for_sdpa(
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,10 @@ def check_ctc_loss(self, config, input_values, *args):

input_values = input_values[:3]
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
# TODO: due to limitation of index op, add mark_step
if torch_device == "hpu":
import habana_frameworks.torch.core as htcore
htcore.mark_step()

input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
Expand Down