Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
gaudi_bloom_convert_to_bloom_cache,
gaudi_bloom_convert_to_standard_cache,
gaudi_bloom_model_forward,
gaudi_check_and_enable_sdpa,
gaudi_codegen_block_forward,
gaudi_codegen_model_forward,
gaudi_conv1d_forward,
Expand Down Expand Up @@ -215,6 +216,10 @@ def adapt_transformers_to_gaudi():
# so that Torch Autocast is disabled for specific parts of the code
transformers.modeling_utils.ModuleUtilsMixin.invert_attention_mask = gaudi_invert_attention_mask
transformers.modeling_utils.ModuleUtilsMixin.get_extended_attention_mask = gaudi_get_extended_attention_mask

# Override sdpa check on Gaudi
transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa = gaudi_check_and_enable_sdpa

# AlbertModel.forward does not rely on get_extended_attention_mask so it also needs to be replaced
transformers.models.albert.modeling_albert.AlbertModel.forward = gaudi_albert_forward

Expand Down
7 changes: 6 additions & 1 deletion optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@
gaudi_mixtral_model_forward,
gaudi_mixtral_rmsnorm_forward,
)
from .modeling_all_models import gaudi_conv1d_forward, gaudi_get_extended_attention_mask, gaudi_invert_attention_mask
from .modeling_all_models import (
gaudi_check_and_enable_sdpa,
gaudi_conv1d_forward,
gaudi_get_extended_attention_mask,
gaudi_invert_attention_mask,
)
from .mpt import (
GaudiMptForCausalLM,
GaudiMptModel,
Expand Down
38 changes: 37 additions & 1 deletion optimum/habana/transformers/models/modeling_all_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from typing import Tuple

import torch
from transformers.modeling_utils import ModuleUtilsMixin
from transformers.modeling_utils import ModuleUtilsMixin, PretrainedConfig
from transformers.utils.import_utils import is_torch_sdpa_available


def gaudi_invert_attention_mask(self, encoder_attention_mask: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -113,6 +114,41 @@ def gaudi_conv1d_forward(self, x):
return x


# Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
@classmethod
def gaudi_check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
# This model doesn't support SDPA in Gaudi yet, fallback to original code.
MODELS_ATTN_IMPLEMENTATION_EAGER = ["bart", "gpt_bigcode", "mistral", "mixtral"]

if config.model_type in MODELS_ATTN_IMPLEMENTATION_EAGER:
config._attn_implementation = "eager"
return config

# Otherwise, fallback to original implementation
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/modeling_utils.py#L1542
if hard_check_only:
if not cls._supports_sdpa:
raise ValueError(
f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
)
if not is_torch_sdpa_available():
raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1.")

if not is_torch_sdpa_available() or not cls._supports_sdpa:
return config

_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
return config

if not hard_check_only:
config._attn_implementation = "sdpa"

return config


# Splitting DeepSpeed LinearAllReduce to three parts to avoid redundant memory consumption
class ScopedLinearAllReduce(torch.nn.Module):
def __init__(self, mod, *args, **kwargs):
Expand Down