diff --git a/examples/modular-transformers/modeling_dummy_bert.py b/examples/modular-transformers/modeling_dummy_bert.py index 0c8cdc16a7da..bd4322607c92 100644 --- a/examples/modular-transformers/modeling_dummy_bert.py +++ b/examples/modular-transformers/modeling_dummy_bert.py @@ -169,9 +169,9 @@ def forward( {"cache_position": cache_position}, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -246,9 +246,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/examples/modular-transformers/modeling_from_uppercase_model.py b/examples/modular-transformers/modeling_from_uppercase_model.py index dc4e41f14951..af4ea303306f 100644 --- a/examples/modular-transformers/modeling_from_uppercase_model.py +++ b/examples/modular-transformers/modeling_from_uppercase_model.py @@ -80,9 +80,9 @@ def forward( keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/examples/modular-transformers/modeling_global_indexing.py b/examples/modular-transformers/modeling_global_indexing.py index 83b214485c02..3a4bdcbb6add 100644 --- a/examples/modular-transformers/modeling_global_indexing.py +++ b/examples/modular-transformers/modeling_global_indexing.py @@ -150,9 +150,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/examples/modular-transformers/modeling_multimodal2.py b/examples/modular-transformers/modeling_multimodal2.py index 71a9cd217ec0..772aa9697ab5 100644 --- a/examples/modular-transformers/modeling_multimodal2.py +++ b/examples/modular-transformers/modeling_multimodal2.py @@ -84,9 +84,9 @@ def forward( keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 2c9fb1485b83..0fc26b94fdb7 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -178,9 +178,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/examples/modular-transformers/modeling_roberta.py b/examples/modular-transformers/modeling_roberta.py index f5754ffaa62d..1b871f2f3586 100644 --- a/examples/modular-transformers/modeling_roberta.py +++ b/examples/modular-transformers/modeling_roberta.py @@ -172,9 +172,9 @@ def forward( {"cache_position": cache_position}, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -249,9 +249,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 350422b8197e..ac5b01549114 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -248,9 +248,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/examples/modular-transformers/modeling_switch_function.py b/examples/modular-transformers/modeling_switch_function.py index 89ebc74eabc9..5a978fb1b059 100644 --- a/examples/modular-transformers/modeling_switch_function.py +++ b/examples/modular-transformers/modeling_switch_function.py @@ -141,9 +141,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/examples/modular-transformers/modeling_test_suffix.py b/examples/modular-transformers/modeling_test_suffix.py index aba241dd97c7..029e9d3a74b9 100644 --- a/examples/modular-transformers/modeling_test_suffix.py +++ b/examples/modular-transformers/modeling_test_suffix.py @@ -182,9 +182,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index a6f8f9ebbef7..af9e06281d26 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable from functools import wraps +from ..utils import logging from ..utils.generic import GeneralInterface from ..utils.import_utils import is_torch_available @@ -21,6 +23,8 @@ if is_torch_available(): import torch +logger = logging.get_logger(__name__) + # Examples of experts class with its eager mm implementation # class Experts(nn.Module): # """Collection of expert weights stored as 3D tensors.""" @@ -265,6 +269,20 @@ class ExpertsInterface(GeneralInterface): "grouped_mm": grouped_mm_experts_forward, } + def get_interface(self, experts_implementation: str, default: Callable) -> Callable: + """Return the requested `experts_implementation`. Also strictly check its validity, and raise if invalid.""" + if experts_implementation is None: + logger.warning_once( + "You tried to access the `ExpertsInterface` with a `config._experts_implementation` set to `None`. This " + "is expected if you use an Expert Module as a standalone Module. If this is not the case, something went " + "wrong with the dispatch of `config._experts_implementation`" + ) + elif experts_implementation != "eager" and experts_implementation not in self: + raise KeyError( + f"`{experts_implementation}` is not a valid experts implementation registered in the `ExpertsInterface`" + ) + return super().get(experts_implementation, default) + ALL_EXPERTS_FUNCTIONS = ExpertsInterface() @@ -313,11 +331,9 @@ def __init__(self, config, *args, **kwargs): @wraps(original_forward) def forward(self, *args, **kwargs): - experts_forward = original_forward - - if self.config._experts_implementation != "eager": - experts_forward = ALL_EXPERTS_FUNCTIONS[self.config._experts_implementation] - + experts_forward = ALL_EXPERTS_FUNCTIONS.get_interface( + self.config._experts_implementation, original_forward + ) return experts_forward(self, *args, **kwargs) if not hasattr(experts_class, "_apply_gate"): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d07ff8618c10..898d44ea7639 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1957,10 +1957,7 @@ def _can_set_attn_implementation(cls) -> bool: code = f.read() # heuristic -> if we find those patterns, the model uses the correct interface if re.search(r"class \w+Attention\(nn.Module\)", code): - return ( - "eager_attention_forward" in code - and "ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]" in code - ) + return "eager_attention_forward" in code and "ALL_ATTENTION_FUNCTIONS.get_interface(" in code else: # If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models return True @@ -4782,6 +4779,20 @@ class AttentionInterface(GeneralInterface): "paged|eager": eager_paged_attention_forward, } + def get_interface(self, attn_implementation: str, default: Callable) -> Callable: + """Return the requested `attn_implementation`. Also strictly check its validity, and raise if invalid.""" + if attn_implementation is None: + logger.warning_once( + "You tried to access the `AttentionInterface` with a `config._attn_implementation` set to `None`. This " + "is expected if you use an Attention Module as a standalone Module. If this is not the case, something went " + "wrong with the dispatch of `config._attn_implementation`" + ) + elif attn_implementation != "eager" and attn_implementation not in self: + raise KeyError( + f"`{attn_implementation}` is not a valid attention implementation registered in the `AttentionInterface`" + ) + return super().get(attn_implementation, default) + # Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface() diff --git a/src/transformers/models/afmoe/modeling_afmoe.py b/src/transformers/models/afmoe/modeling_afmoe.py index 15e88fc1f00b..f99e439f88e6 100644 --- a/src/transformers/models/afmoe/modeling_afmoe.py +++ b/src/transformers/models/afmoe/modeling_afmoe.py @@ -405,9 +405,9 @@ def forward( cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/afmoe/modular_afmoe.py b/src/transformers/models/afmoe/modular_afmoe.py index d81a659e905b..3bd29f904f2e 100644 --- a/src/transformers/models/afmoe/modular_afmoe.py +++ b/src/transformers/models/afmoe/modular_afmoe.py @@ -226,9 +226,9 @@ def forward( cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/aimv2/modeling_aimv2.py b/src/transformers/models/aimv2/modeling_aimv2.py index 8f0a86ed3da3..5d4839d4f4f3 100644 --- a/src/transformers/models/aimv2/modeling_aimv2.py +++ b/src/transformers/models/aimv2/modeling_aimv2.py @@ -270,9 +270,9 @@ def forward( keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 70c179cfd6a3..910ac97a146f 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -181,9 +181,9 @@ def forward( key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2) value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index de30af314e40..91f96d288c70 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -628,9 +628,9 @@ def forward( key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 247b81023758..d6bc5965779b 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -505,9 +505,9 @@ def forward( else: self.is_causal = causal_attention_mask is not None - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index 26520c3a1653..3a2fb0482bc0 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -266,9 +266,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/apertus/modular_apertus.py b/src/transformers/models/apertus/modular_apertus.py index f24b9ebd79b3..0aded80ff4a9 100644 --- a/src/transformers/models/apertus/modular_apertus.py +++ b/src/transformers/models/apertus/modular_apertus.py @@ -235,9 +235,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index 8a692d516b19..f2fd072aff20 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -268,9 +268,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 0ecb99c898f0..02b0f14affcd 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -497,9 +497,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index fe12caa0e6df..76033f7cdd23 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -155,9 +155,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, diff --git a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py index f88a19796f34..0cbcabe88a3e 100644 --- a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py @@ -171,9 +171,9 @@ def forward( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index be8f046940ce..6dd6d9e735b7 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -380,9 +380,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index fac9a775fb56..8774853ec6ba 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -235,9 +235,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 3cc03c9c0f2c..d33f6356e4ce 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -195,9 +195,9 @@ def forward( {"cache_position": cache_position}, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -272,9 +272,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index bccb17124aa9..0a84c415939b 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -141,9 +141,9 @@ def forward( {"cache_position": cache_position}, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -219,9 +219,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index fef98fb6cdf2..8ec1a5ca230b 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1286,9 +1286,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 547040ed1873..370de417dbfa 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -219,9 +219,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index f230fd27e355..48d8162b50fa 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -200,10 +200,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/bitnet/modular_bitnet.py b/src/transformers/models/bitnet/modular_bitnet.py index d3aae65a8ad4..e00d14fbc897 100644 --- a/src/transformers/models/bitnet/modular_bitnet.py +++ b/src/transformers/models/bitnet/modular_bitnet.py @@ -81,10 +81,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 08acd429d790..e6e3e90a372d 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -233,9 +233,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index e2e8adb0b752..58d3012487f2 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -220,9 +220,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 6ac6c4b48b7a..1c0a36d36be2 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -331,10 +331,9 @@ def forward( ) query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] - attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 0dea1ef44a67..8cc3e28bbdf0 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -341,10 +341,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -405,9 +404,9 @@ def forward( key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index e27a24ee59d7..22fef55c96d2 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -327,9 +327,9 @@ def forward( key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 8e9741dc76e3..ba6aaa0b471a 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -487,9 +487,9 @@ def forward( {"cache_position": cache_position}, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -565,9 +565,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 7bdcb24be61e..b67960f32a43 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -242,9 +242,9 @@ def forward( {"cache_position": cache_position}, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -319,9 +319,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index f80a47df5ded..cfe557d6afc5 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -335,9 +335,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 9ec883d64952..10425aa5dca8 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -296,9 +296,9 @@ def forward( key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -390,9 +390,9 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 64223c23e8c1..b462a1793421 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -1113,9 +1113,9 @@ def forward( key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 268661968b51..d47c3978d21b 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -323,9 +323,9 @@ def forward( keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 9c93a5764b99..ed2b5cdb5b68 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -332,9 +332,9 @@ def forward( else: self.is_causal = causal_attention_mask is not None - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index cf29d8596a9e..e19242a7359c 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -288,9 +288,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/cohere/modular_cohere.py b/src/transformers/models/cohere/modular_cohere.py index 89e7256035ec..3c215a77780f 100644 --- a/src/transformers/models/cohere/modular_cohere.py +++ b/src/transformers/models/cohere/modular_cohere.py @@ -174,9 +174,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 559e2b0e8dd7..5aa00891dcfe 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -248,9 +248,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index ca3a8e14af89..231203c8ac4d 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -285,9 +285,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 77399420a9e9..ad125f0237b7 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -321,9 +321,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/cwm/modeling_cwm.py b/src/transformers/models/cwm/modeling_cwm.py index 2b36d702132f..a5c601610698 100644 --- a/src/transformers/models/cwm/modeling_cwm.py +++ b/src/transformers/models/cwm/modeling_cwm.py @@ -221,9 +221,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 5c1225e93db8..028fd106cc23 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -265,9 +265,9 @@ def forward( key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 93177572caea..d43cc5da4d62 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -241,9 +241,9 @@ def forward( {"cache_position": cache_position}, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -318,9 +318,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 835c975cac4e..6033d75a0f5e 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -244,9 +244,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/dbrx/modular_dbrx.py b/src/transformers/models/dbrx/modular_dbrx.py index 3f3201614652..bcd17481f784 100644 --- a/src/transformers/models/dbrx/modular_dbrx.py +++ b/src/transformers/models/dbrx/modular_dbrx.py @@ -113,9 +113,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 6bc055b62d80..afe70ede6667 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -234,9 +234,9 @@ def forward( past_key_values.is_updated[self.layer_idx] = True using_eager = self.config._attn_implementation == "eager" - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if using_eager and self.reorder_and_upcast_attn: attn_output, attn_weights = self._upcast_and_reordered_attn( diff --git a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py index 6e183dc173b8..9b4dedcaefae 100644 --- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -374,9 +374,9 @@ def forward( if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/deepseek_v2/modular_deepseek_v2.py b/src/transformers/models/deepseek_v2/modular_deepseek_v2.py index bbb6f0321c59..6509f5bb2c7a 100644 --- a/src/transformers/models/deepseek_v2/modular_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modular_deepseek_v2.py @@ -399,9 +399,9 @@ def forward( if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index f8183b24d8b4..5503a9986847 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -459,9 +459,9 @@ def forward( if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index 76e5a0060f1c..935ae2f8f59a 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -264,9 +264,9 @@ def forward( if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index b4254274ee34..ee455cb19d29 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -220,9 +220,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index b2b7b34dd923..35eeb795e7b9 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -317,9 +317,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -392,9 +392,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py index d265bea97c51..ca4119b85056 100644 --- a/src/transformers/models/dia/modular_dia.py +++ b/src/transformers/models/dia/modular_dia.py @@ -182,9 +182,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 1661211c1d2e..62c6561c7c6a 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -208,9 +208,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, diff --git a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py index 30c19a9adf19..1c6d199f221b 100644 --- a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py @@ -228,9 +228,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, diff --git a/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py b/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py index 045b12a3bf58..ddc6feefb657 100644 --- a/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py +++ b/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py @@ -295,9 +295,9 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/dinov3_vit/modular_dinov3_vit.py b/src/transformers/models/dinov3_vit/modular_dinov3_vit.py index 0f667822c29f..bc1378c14bc0 100644 --- a/src/transformers/models/dinov3_vit/modular_dinov3_vit.py +++ b/src/transformers/models/dinov3_vit/modular_dinov3_vit.py @@ -252,9 +252,9 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index bd8c75b2a855..f0a6fd929fa4 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -186,9 +186,9 @@ def forward( key_layer = self.k_lin(hidden_states).view(*hidden_shape).transpose(1, 2) value_layer = self.v_lin(hidden_states).view(*hidden_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py index 9899c5688dc3..c7801b144362 100644 --- a/src/transformers/models/doge/modeling_doge.py +++ b/src/transformers/models/doge/modeling_doge.py @@ -324,9 +324,9 @@ def forward( ) attn_mask = repeat_kv(attn_mask, self.num_key_value_groups) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/doge/modular_doge.py b/src/transformers/models/doge/modular_doge.py index 10bf076b4a7d..1a29061f8063 100644 --- a/src/transformers/models/doge/modular_doge.py +++ b/src/transformers/models/doge/modular_doge.py @@ -360,9 +360,9 @@ def forward( ) attn_mask = repeat_kv(attn_mask, self.num_key_value_groups) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index efa1af5ef85b..0026fa4d3c24 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -258,9 +258,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index c112d6099c63..46b5523de14a 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -327,9 +327,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, diff --git a/src/transformers/models/edgetam/modeling_edgetam.py b/src/transformers/models/edgetam/modeling_edgetam.py index 261d14289d6a..d01e855aaab7 100644 --- a/src/transformers/models/edgetam/modeling_edgetam.py +++ b/src/transformers/models/edgetam/modeling_edgetam.py @@ -163,9 +163,9 @@ def forward( key = self.k_proj(key).view(*new_shape).transpose(1, 2) value = self.v_proj(value).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config) and attention_similarity is not None: # Target guided masks are represented as float masks and are incompatible with Flash Attention diff --git a/src/transformers/models/edgetam_video/modeling_edgetam_video.py b/src/transformers/models/edgetam_video/modeling_edgetam_video.py index dda7d1a4ab4c..8f2789c8c2e4 100644 --- a/src/transformers/models/edgetam_video/modeling_edgetam_video.py +++ b/src/transformers/models/edgetam_video/modeling_edgetam_video.py @@ -245,9 +245,9 @@ def forward( key = self.k_proj(key).view(*new_shape).transpose(1, 2) value = self.v_proj(value).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config) and attention_similarity is not None: # Target guided masks are represented as float masks and are incompatible with Flash Attention @@ -364,9 +364,9 @@ def forward( # Apply rotary position encoding for self-attention query, key = apply_rotary_pos_emb_2d_self_attn(query, key, cos=cos, sin=sin) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -515,9 +515,9 @@ def forward( num_k_exclude_rope=num_k_exclude_rope, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -1331,9 +1331,9 @@ def forward( value = value + pos_encoding # Apply attention - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, _ = attention_interface( self, diff --git a/src/transformers/models/edgetam_video/modular_edgetam_video.py b/src/transformers/models/edgetam_video/modular_edgetam_video.py index e35e5f7fdc36..c810215f9dbe 100644 --- a/src/transformers/models/edgetam_video/modular_edgetam_video.py +++ b/src/transformers/models/edgetam_video/modular_edgetam_video.py @@ -536,9 +536,9 @@ def forward( # Apply rotary position encoding for self-attention query, key = apply_rotary_pos_emb_2d_self_attn(query, key, cos=cos, sin=sin) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -612,9 +612,9 @@ def forward( num_k_exclude_rope=num_k_exclude_rope, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -854,9 +854,9 @@ def forward( value = value + pos_encoding # Apply attention - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, _ = attention_interface( self, diff --git a/src/transformers/models/efficientloftr/modeling_efficientloftr.py b/src/transformers/models/efficientloftr/modeling_efficientloftr.py index b90eb7f233d2..f0b79e001c52 100644 --- a/src/transformers/models/efficientloftr/modeling_efficientloftr.py +++ b/src/transformers/models/efficientloftr/modeling_efficientloftr.py @@ -418,9 +418,9 @@ def forward( query_states = query_states.view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2) key_states = key_states.view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 2fe4c2bea9b9..e0141e8862b9 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -202,9 +202,9 @@ def forward( {"cache_position": cache_position}, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -280,9 +280,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index fc73095839fb..aa1588256d36 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -178,9 +178,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -600,9 +600,9 @@ def forward( keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/eomt/modeling_eomt.py b/src/transformers/models/eomt/modeling_eomt.py index ca5b60669348..d8dddbb3d497 100644 --- a/src/transformers/models/eomt/modeling_eomt.py +++ b/src/transformers/models/eomt/modeling_eomt.py @@ -774,9 +774,9 @@ def forward( keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 672af9d80a2d..7d9f38ef1f36 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -214,9 +214,9 @@ def forward( {"cache_position": cache_position}, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -291,9 +291,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/ernie4_5/modeling_ernie4_5.py b/src/transformers/models/ernie4_5/modeling_ernie4_5.py index 985e07e7c0e8..3f3b0901239e 100644 --- a/src/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/src/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -245,9 +245,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py index 0878e028bb3d..4215fc15dccb 100644 --- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py @@ -268,9 +268,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index 3b02f84c8d84..3b07af7cca2c 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -259,9 +259,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -606,9 +606,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config): # Flash Attention: Use cu_seqlens for variable length attention diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 6a95eefa494d..d56274402694 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -343,9 +343,9 @@ def forward( if self.position_embedding_type == "rotary": query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py index 7961ea75cd3f..534b089c13b6 100644 --- a/src/transformers/models/evolla/modeling_evolla.py +++ b/src/transformers/models/evolla/modeling_evolla.py @@ -308,9 +308,9 @@ def forward( if self.position_embedding_type == "rotary": query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -1147,9 +1147,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/exaone4/modeling_exaone4.py b/src/transformers/models/exaone4/modeling_exaone4.py index 7e87fbf5a337..d2c1f7c9a385 100644 --- a/src/transformers/models/exaone4/modeling_exaone4.py +++ b/src/transformers/models/exaone4/modeling_exaone4.py @@ -261,9 +261,9 @@ def forward( } key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/exaone4/modular_exaone4.py b/src/transformers/models/exaone4/modular_exaone4.py index 95f6e2128c94..616c7faa795d 100644 --- a/src/transformers/models/exaone4/modular_exaone4.py +++ b/src/transformers/models/exaone4/modular_exaone4.py @@ -287,9 +287,9 @@ def forward( } key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 8c5017621e3b..41f3a51fa00d 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -411,9 +411,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index 2a44f8220876..27a835f5362f 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -230,9 +230,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/flex_olmo/modeling_flex_olmo.py b/src/transformers/models/flex_olmo/modeling_flex_olmo.py index cba565099af8..6e2b0c773fd3 100644 --- a/src/transformers/models/flex_olmo/modeling_flex_olmo.py +++ b/src/transformers/models/flex_olmo/modeling_flex_olmo.py @@ -271,9 +271,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/florence2/modeling_florence2.py b/src/transformers/models/florence2/modeling_florence2.py index 2d317c14e5c8..aa5633a98fdb 100644 --- a/src/transformers/models/florence2/modeling_florence2.py +++ b/src/transformers/models/florence2/modeling_florence2.py @@ -245,9 +245,9 @@ def forward(self, hidden_states: torch.Tensor): scale = num_tokens**-0.5 # Channel-to-channel attention within groups: - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) hidden_states, _ = attention_interface( self, query, @@ -370,9 +370,9 @@ def forward(self, hidden_states: torch.Tensor): qkv = qkv.permute(2, 0, 3, 1, 4) query, key, value = qkv.unbind(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) windowed_hidden_states, _ = attention_interface( self, diff --git a/src/transformers/models/florence2/modular_florence2.py b/src/transformers/models/florence2/modular_florence2.py index 055f7685803c..e5277ff058fc 100644 --- a/src/transformers/models/florence2/modular_florence2.py +++ b/src/transformers/models/florence2/modular_florence2.py @@ -1095,9 +1095,9 @@ def forward(self, hidden_states: torch.Tensor): scale = num_tokens**-0.5 # Channel-to-channel attention within groups: - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) hidden_states, _ = attention_interface( self, query, @@ -1220,9 +1220,9 @@ def forward(self, hidden_states: torch.Tensor): qkv = qkv.permute(2, 0, 3, 1, 4) query, key, value = qkv.unbind(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) windowed_hidden_states, _ = attention_interface( self, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index ecd12e1a45a5..fd9cdde7af50 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -268,9 +268,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index f1a5989ee208..7257febc275a 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -280,9 +280,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 8f2b4c5f635e..9408dce128ec 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -329,9 +329,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index c5bbf253521e..ca71a59844eb 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -369,9 +369,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 4be5ef654521..09e96d7d53aa 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -491,9 +491,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index a146f357ca59..e3bef554b714 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1343,9 +1343,9 @@ def forward( past_key_values.shared_layers = {} past_key_values.shared_layers[self.layer_idx] = key_states, value_states - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index 0fca0518b668..ec3cc4bef06e 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -1812,9 +1812,9 @@ def forward( past_key_values.shared_layers = {} past_key_values.shared_layers[self.layer_idx] = key_states, value_states - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 5cb977b4512d..1baea7a449a8 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -679,9 +679,9 @@ def forward( else: self.is_causal = causal_attention_mask is not None - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 020ecd502fdd..f48fda56ebae 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -262,9 +262,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index dd11e0d4dfe0..1770bb34e2b2 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -245,9 +245,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py b/src/transformers/models/glm4_moe/modeling_glm4_moe.py index 81931551cec4..c667127884cd 100644 --- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py +++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -252,9 +252,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py b/src/transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py index b0b108080e20..ae3bc8c4b0b9 100644 --- a/src/transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +++ b/src/transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py @@ -324,9 +324,9 @@ def forward( if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 57940388dfb1..e57642a0a01e 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -305,9 +305,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config): # Flash Attention: Use cu_seqlens for variable length attention @@ -558,9 +558,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 28e30c32ddc6..d4001318518e 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -627,9 +627,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 620069fec656..375c174fd773 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -186,9 +186,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -685,9 +685,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config): # Flash Attention: Use cu_seqlens for variable length attention diff --git a/src/transformers/models/glm4v_moe/modular_glm4v_moe.py b/src/transformers/models/glm4v_moe/modular_glm4v_moe.py index dd59ae412916..49c8adfea602 100644 --- a/src/transformers/models/glm4v_moe/modular_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modular_glm4v_moe.py @@ -321,9 +321,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index 8cfd6853e28f..33939387c124 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -126,9 +126,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if "flash" in self.config._attn_implementation: # Flash Attention: Use cu_seqlens for variable length attention @@ -402,9 +402,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index 9eaea4692fca..030987f7a81b 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -361,9 +361,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if "flash" in self.config._attn_implementation: # Flash Attention: Use cu_seqlens for variable length attention diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index 17653f18988b..e5ded6ca88e5 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -222,9 +222,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -437,9 +437,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config): # Flash Attention: Use cu_seqlens for variable length attention diff --git a/src/transformers/models/glm_ocr/modular_glm_ocr.py b/src/transformers/models/glm_ocr/modular_glm_ocr.py index 92ca8ccd1177..7659b35deaa1 100644 --- a/src/transformers/models/glm_ocr/modular_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modular_glm_ocr.py @@ -290,9 +290,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config): # Flash Attention: Use cu_seqlens for variable length attention diff --git a/src/transformers/models/glmasr/modeling_glmasr.py b/src/transformers/models/glmasr/modeling_glmasr.py index 9a08c0ec1adb..b2446238da12 100644 --- a/src/transformers/models/glmasr/modeling_glmasr.py +++ b/src/transformers/models/glmasr/modeling_glmasr.py @@ -205,9 +205,9 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/glmasr/modular_glmasr.py b/src/transformers/models/glmasr/modular_glmasr.py index a81a6ed7748d..0bbffa63f042 100644 --- a/src/transformers/models/glmasr/modular_glmasr.py +++ b/src/transformers/models/glmasr/modular_glmasr.py @@ -227,9 +227,9 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index db529e8aadf7..a55437c0ae5a 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -243,9 +243,9 @@ def forward( past_key_values.is_updated[self.layer_idx] = True using_eager = self.config._attn_implementation == "eager" - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if using_eager and self.reorder_and_upcast_attn: attn_output, attn_weights = self._upcast_and_reordered_attn( diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 4b89bcc9e140..f72b2e08eb8d 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -220,9 +220,9 @@ def forward( if self.is_cross_attention: layer_past.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 2d9e16419e81..2c717c67119a 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -229,9 +229,9 @@ def forward( } key_states, value_states = layer_past.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) # Compute attention attn_output, attn_weights = attention_interface( diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index b157af6a1bd2..22d140c1c781 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -181,9 +181,9 @@ def forward( } key_states, value_states = layer_past.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) # Compute attention attn_output, attn_weights = attention_interface( diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 56e894119b33..be0a69391f5b 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -325,9 +325,9 @@ def forward( cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index 267d42bd9e20..3f80a6504dd5 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -250,9 +250,9 @@ def forward( cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 5cbeb3b41ee7..46fa0224ec25 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -165,9 +165,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 527b5251d3be..fe7b0ca1f616 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -387,9 +387,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 85bcbb89f28a..9dcc81541067 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -172,9 +172,9 @@ def forward( cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py index 8c890be042f3..60216adda9bb 100644 --- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -74,9 +74,9 @@ def forward( # FIME: @ARTHUR this forward is also classic: attention nope cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 3f177aa2475c..e4c78cebfc1f 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -375,9 +375,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index dc5a1a34bd94..576131d6fc8d 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -267,9 +267,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index adee070f3a20..7ea2717dd585 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -324,9 +324,9 @@ def forward( key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py index 579a1aedbaef..4b41acdc4d67 100644 --- a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +++ b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py @@ -205,9 +205,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py index b96f430b6064..4bbe01f51e2d 100644 --- a/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +++ b/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py @@ -91,9 +91,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py index 20390d8c88f0..e9507729b4b4 100644 --- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py @@ -209,9 +209,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py index 645d257dbcdd..7d752660ac74 100644 --- a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py @@ -88,9 +88,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index fa23e47bf4d6..4d4fabaccb2d 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -607,10 +607,9 @@ def forward( query_states = self.q_layer_norm(query_states) key_states = self.k_layer_norm(key_states) - attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index 91bdb78c3bed..adcf7f2167da 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -240,9 +240,9 @@ def forward( else: self.is_causal = causal_attention_mask is not None - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 5e8917097ab2..1c8ed9cee817 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -250,9 +250,9 @@ def forward( keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -603,9 +603,9 @@ def forward( if past_key_values is not None: keys, values = past_key_values.update(keys, values, self.layer_idx) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index d9536b8d6860..4f9ed64a8874 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -245,9 +245,9 @@ def forward( keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 0918989a08ba..def8b44deaec 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -199,9 +199,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index a7aa50dd2016..34b7fe5ee668 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -388,9 +388,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 648e52f2efd5..be96f1536eec 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -249,10 +249,9 @@ def forward( ) query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] - attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 121db617af5c..5905f78dbe5f 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -291,10 +291,9 @@ def forward( ) query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] - attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 498caae1044e..846b12872be4 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -139,9 +139,9 @@ def forward( key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py index d5ec73e498da..bbc5a6aecfe8 100644 --- a/src/transformers/models/internvl/modular_internvl.py +++ b/src/transformers/models/internvl/modular_internvl.py @@ -104,9 +104,9 @@ def forward( key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/jais2/modeling_jais2.py b/src/transformers/models/jais2/modeling_jais2.py index b8f691c363c5..6a4fb8af5bde 100644 --- a/src/transformers/models/jais2/modeling_jais2.py +++ b/src/transformers/models/jais2/modeling_jais2.py @@ -176,9 +176,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 540e3e672f8f..27250918b7b7 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -276,9 +276,9 @@ def forward( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py index 4d4e98a4a5b2..3f7ea1c29e01 100644 --- a/src/transformers/models/jamba/modular_jamba.py +++ b/src/transformers/models/jamba/modular_jamba.py @@ -168,9 +168,9 @@ def forward( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 7ef79803672f..0c61a42ee790 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -320,9 +320,9 @@ def forward( key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index c636b69b47a4..67b2f6c7b8a8 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -492,9 +492,9 @@ def forward( key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 599a18598d6e..421d4122aa37 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -494,9 +494,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) # This is different from other models where we repeat k/v heads # instead of repeat interleaving them diff --git a/src/transformers/models/jetmoe/modular_jetmoe.py b/src/transformers/models/jetmoe/modular_jetmoe.py index f620b9e8ee04..ee2357beaa15 100644 --- a/src/transformers/models/jetmoe/modular_jetmoe.py +++ b/src/transformers/models/jetmoe/modular_jetmoe.py @@ -349,9 +349,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) # This is different from other models where we repeat k/v heads # instead of repeat interleaving them diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index d42b15d9d15c..95a5618d935c 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -335,9 +335,9 @@ def forward( else: self.is_causal = causal_attention_mask is not None - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -771,10 +771,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py b/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py index 44c8ca6241ac..ffc43a557d3d 100644 --- a/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py +++ b/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py @@ -488,15 +488,9 @@ def forward( key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = getattr( + ALL_ATTENTION_FUNCTIONS, self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -812,15 +806,9 @@ def forward( cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/lasr/modeling_lasr.py b/src/transformers/models/lasr/modeling_lasr.py index 8b819f1e0356..30837cbb0aa3 100644 --- a/src/transformers/models/lasr/modeling_lasr.py +++ b/src/transformers/models/lasr/modeling_lasr.py @@ -246,9 +246,9 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/lasr/modular_lasr.py b/src/transformers/models/lasr/modular_lasr.py index eeac0fb29225..b3c3e69e71f3 100644 --- a/src/transformers/models/lasr/modular_lasr.py +++ b/src/transformers/models/lasr/modular_lasr.py @@ -348,9 +348,9 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 677f6eb75859..d206be4301e7 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -179,9 +179,9 @@ def forward( key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index 515553ef3e96..b6d7885d841e 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -398,9 +398,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index c8b74a709619..2067e7e8bd5c 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -249,9 +249,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py index 23923ec4f87c..28b7735c8264 100644 --- a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py @@ -474,9 +474,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/lightglue/modeling_lightglue.py b/src/transformers/models/lightglue/modeling_lightglue.py index df038b4ac73a..e1c5793bb92a 100644 --- a/src/transformers/models/lightglue/modeling_lightglue.py +++ b/src/transformers/models/lightglue/modeling_lightglue.py @@ -225,9 +225,9 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/lightglue/modular_lightglue.py b/src/transformers/models/lightglue/modular_lightglue.py index ab2ebbc78812..2325d63c9959 100644 --- a/src/transformers/models/lightglue/modular_lightglue.py +++ b/src/transformers/models/lightglue/modular_lightglue.py @@ -286,9 +286,9 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index cbbdaf0ff10b..f05ba492d8db 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -272,9 +272,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index b8f98b385b4b..aec0fab4344d 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -393,9 +393,9 @@ def forward( cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, query_states, @@ -828,10 +828,9 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attention_interface: Callable = vision_eager_attention_forward - # flex disable because breaks on TP 8, embed is 88 not power of 2 - if self.config._attn_implementation not in ["eager", "flex_attention"]: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, vision_eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/longcat_flash/modeling_longcat_flash.py b/src/transformers/models/longcat_flash/modeling_longcat_flash.py index 4a8a941d7456..42d9546e5634 100644 --- a/src/transformers/models/longcat_flash/modeling_longcat_flash.py +++ b/src/transformers/models/longcat_flash/modeling_longcat_flash.py @@ -433,9 +433,9 @@ def forward( if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/longcat_flash/modular_longcat_flash.py b/src/transformers/models/longcat_flash/modular_longcat_flash.py index 793d0a5e3b37..a54296465a5f 100644 --- a/src/transformers/models/longcat_flash/modular_longcat_flash.py +++ b/src/transformers/models/longcat_flash/modular_longcat_flash.py @@ -217,9 +217,9 @@ def forward( if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/lw_detr/modeling_lw_detr.py b/src/transformers/models/lw_detr/modeling_lw_detr.py index 318940ff14a2..7656cfab349e 100644 --- a/src/transformers/models/lw_detr/modeling_lw_detr.py +++ b/src/transformers/models/lw_detr/modeling_lw_detr.py @@ -114,9 +114,9 @@ def forward( value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, @@ -692,9 +692,9 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states_original).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/lw_detr/modular_lw_detr.py b/src/transformers/models/lw_detr/modular_lw_detr.py index 978d9e0dd853..d4c88ae5660a 100644 --- a/src/transformers/models/lw_detr/modular_lw_detr.py +++ b/src/transformers/models/lw_detr/modular_lw_detr.py @@ -420,9 +420,9 @@ def forward( value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, @@ -851,9 +851,9 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states_original).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index ad4e341098ae..595cdef8f3d7 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -313,9 +313,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index cad56a9a09cb..b43409762dad 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -234,9 +234,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 0a01c4b556f4..978b186f9544 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -373,9 +373,9 @@ def forward( key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 65a0aec397f3..ec635b3dbd09 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -254,9 +254,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/metaclip_2/modeling_metaclip_2.py b/src/transformers/models/metaclip_2/modeling_metaclip_2.py index fb7a9f87e0e5..cab24f20ac4d 100644 --- a/src/transformers/models/metaclip_2/modeling_metaclip_2.py +++ b/src/transformers/models/metaclip_2/modeling_metaclip_2.py @@ -208,9 +208,9 @@ def forward( keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index a2dff7e9401b..72f3b3abb2e0 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -436,9 +436,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/minimax_m2/modeling_minimax_m2.py b/src/transformers/models/minimax_m2/modeling_minimax_m2.py index d5137fbb9523..c374f50efcc1 100644 --- a/src/transformers/models/minimax_m2/modeling_minimax_m2.py +++ b/src/transformers/models/minimax_m2/modeling_minimax_m2.py @@ -340,9 +340,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/ministral/modeling_ministral.py b/src/transformers/models/ministral/modeling_ministral.py index 379ae1b33532..91fba171c713 100644 --- a/src/transformers/models/ministral/modeling_ministral.py +++ b/src/transformers/models/ministral/modeling_ministral.py @@ -163,9 +163,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/ministral3/modeling_ministral3.py b/src/transformers/models/ministral3/modeling_ministral3.py index 1f19ddf1a03a..40cb87cf70fe 100644 --- a/src/transformers/models/ministral3/modeling_ministral3.py +++ b/src/transformers/models/ministral3/modeling_ministral3.py @@ -156,9 +156,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/ministral3/modular_ministral3.py b/src/transformers/models/ministral3/modular_ministral3.py index 19f976f19299..a12e131d0226 100644 --- a/src/transformers/models/ministral3/modular_ministral3.py +++ b/src/transformers/models/ministral3/modular_ministral3.py @@ -61,9 +61,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index e3fab2930415..483b6f813ed6 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -162,9 +162,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index c8c4d84adf70..3af330cca0a7 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -74,9 +74,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 2a22dbdd8d1d..26cac8697520 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -335,9 +335,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/mlcd/modeling_mlcd.py b/src/transformers/models/mlcd/modeling_mlcd.py index 4d5d9ea534b4..53dd3ef503d2 100644 --- a/src/transformers/models/mlcd/modeling_mlcd.py +++ b/src/transformers/models/mlcd/modeling_mlcd.py @@ -280,9 +280,9 @@ def forward( key_states = key_states.permute(0, 2, 1, 3).contiguous() value_states = value_states.permute(0, 2, 1, 3).contiguous() - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/mlcd/modular_mlcd.py b/src/transformers/models/mlcd/modular_mlcd.py index 1cffd604ac59..f57546ffd6d7 100644 --- a/src/transformers/models/mlcd/modular_mlcd.py +++ b/src/transformers/models/mlcd/modular_mlcd.py @@ -224,9 +224,9 @@ def forward( key_states = key_states.permute(0, 2, 1, 3).contiguous() value_states = value_states.permute(0, 2, 1, 3).contiguous() - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index ac4a6b8a80fd..075c1b65ab29 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -252,10 +252,9 @@ def forward( key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2) value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -451,10 +450,9 @@ def forward( "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" ) - attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -554,10 +552,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index fe25e0b139c8..638a1e99d56b 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -204,9 +204,9 @@ def forward( key_layer = self.key(key_tensor).view(*hidden_shape).transpose(1, 2) value_layer = self.value(value_tensor).view(*hidden_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py index 558db015bd9a..19f4fef65af0 100644 --- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py @@ -289,9 +289,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py index a90935e28617..93c05dbd4bb6 100644 --- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py @@ -357,9 +357,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 7d8731e0ddc0..1e61d3eda8c7 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -331,9 +331,9 @@ def forward( key_states, value_states, self.layer_idx, cache_kwargs ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) is_causal = self.is_causal and attention_mask is None and q_len > 1 diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index b8a209036e7e..4c3e1e64c260 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -328,9 +328,9 @@ def forward( key_states, value_states, self.layer_idx, cache_kwargs ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) is_causal = self.is_causal and attention_mask is None and q_len > 1 diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index ef4b4aaae935..9bfe09749fa8 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -280,9 +280,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 59838c08c689..ba68b800f517 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -286,9 +286,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/nanochat/modeling_nanochat.py b/src/transformers/models/nanochat/modeling_nanochat.py index aabd31742cec..b8f144ef5059 100644 --- a/src/transformers/models/nanochat/modeling_nanochat.py +++ b/src/transformers/models/nanochat/modeling_nanochat.py @@ -250,9 +250,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/nanochat/modular_nanochat.py b/src/transformers/models/nanochat/modular_nanochat.py index 379434bae879..468abf1ab017 100644 --- a/src/transformers/models/nanochat/modular_nanochat.py +++ b/src/transformers/models/nanochat/modular_nanochat.py @@ -92,9 +92,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 2cc7aba473c6..14b3ca23b38e 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -495,9 +495,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 9101c271b23f..5f959c765558 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -270,9 +270,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index df56ca86d9ad..8a1294b32a73 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -144,9 +144,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 2ea3248af0e6..e8d909e09bee 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -259,9 +259,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 00598ac623b4..4142f12d9477 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -238,9 +238,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/olmo3/modeling_olmo3.py b/src/transformers/models/olmo3/modeling_olmo3.py index a20676e2c91f..15a85ec59a36 100644 --- a/src/transformers/models/olmo3/modeling_olmo3.py +++ b/src/transformers/models/olmo3/modeling_olmo3.py @@ -193,9 +193,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/olmo3/modular_olmo3.py b/src/transformers/models/olmo3/modular_olmo3.py index 76b6bd352c0a..46efa752aa39 100644 --- a/src/transformers/models/olmo3/modular_olmo3.py +++ b/src/transformers/models/olmo3/modular_olmo3.py @@ -234,9 +234,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 1acc5be9b4a4..da03616f97bc 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -280,9 +280,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/olmoe/modular_olmoe.py b/src/transformers/models/olmoe/modular_olmoe.py index 8fc711739ebf..69e1b13546ea 100644 --- a/src/transformers/models/olmoe/modular_olmoe.py +++ b/src/transformers/models/olmoe/modular_olmoe.py @@ -95,9 +95,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 771ea6b0efe8..c8636b012e33 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -169,10 +169,9 @@ def forward( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/ovis2/modeling_ovis2.py b/src/transformers/models/ovis2/modeling_ovis2.py index eb93e5dbab82..2b0722f5c9af 100644 --- a/src/transformers/models/ovis2/modeling_ovis2.py +++ b/src/transformers/models/ovis2/modeling_ovis2.py @@ -235,9 +235,9 @@ def forward( keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index f813ac0c10ea..c9e8c8618f5e 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -340,9 +340,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -695,9 +695,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config): # Flash Attention 2: Use cu_seqlens for variable length attention diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 2e27be594216..78c198cddc2b 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -303,9 +303,9 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) query_states_with_bias_u = query_states + self.bias_u.view( 1, self.config.num_attention_heads, 1, self.head_dim diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index cf5ad1be8dc8..1b3bb2a5b2c7 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -141,9 +141,9 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) query_states_with_bias_u = query_states + self.bias_u.view( 1, self.config.num_attention_heads, 1, self.head_dim diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index b5dcd379b3b7..a6c891c21e63 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -332,9 +332,9 @@ def forward( key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index 0efcc573a1a5..1876522795a1 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -130,9 +130,9 @@ def forward( key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/pe_audio/modeling_pe_audio.py b/src/transformers/models/pe_audio/modeling_pe_audio.py index 948cd6e1fd16..f7f165a0d4d0 100644 --- a/src/transformers/models/pe_audio/modeling_pe_audio.py +++ b/src/transformers/models/pe_audio/modeling_pe_audio.py @@ -412,9 +412,9 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/pe_audio_video/modeling_pe_audio_video.py b/src/transformers/models/pe_audio_video/modeling_pe_audio_video.py index 0fb693d67941..012919f5fb60 100644 --- a/src/transformers/models/pe_audio_video/modeling_pe_audio_video.py +++ b/src/transformers/models/pe_audio_video/modeling_pe_audio_video.py @@ -346,9 +346,9 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/pe_audio_video/modular_pe_audio_video.py b/src/transformers/models/pe_audio_video/modular_pe_audio_video.py index 78bd0a044259..ccdedf8fce8a 100644 --- a/src/transformers/models/pe_audio_video/modular_pe_audio_video.py +++ b/src/transformers/models/pe_audio_video/modular_pe_audio_video.py @@ -254,9 +254,9 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/pe_video/modeling_pe_video.py b/src/transformers/models/pe_video/modeling_pe_video.py index a94e53b77dc4..4c32deecb300 100644 --- a/src/transformers/models/pe_video/modeling_pe_video.py +++ b/src/transformers/models/pe_video/modeling_pe_video.py @@ -316,9 +316,9 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 4ea56ce2d928..d1de8f3670e2 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -236,9 +236,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 286510a7917e..10b5a1e71ef7 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -253,9 +253,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 5456c7371f42..b2d2718bb25a 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -310,9 +310,9 @@ def forward( } key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index fbdd8ad96f9c..d8c6f8a50daa 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -239,9 +239,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index e32894f71786..45db5b04d841 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -127,9 +127,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 23e06d120e7a..c007d1a8c804 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -255,9 +255,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/phi3/modular_phi3.py b/src/transformers/models/phi3/modular_phi3.py index 614f2900213c..7cb5c1730034 100644 --- a/src/transformers/models/phi3/modular_phi3.py +++ b/src/transformers/models/phi3/modular_phi3.py @@ -145,9 +145,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index 38995c5f167d..5b6b4148260a 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -119,9 +119,9 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = simple_eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, simple_eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -679,9 +679,9 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = simple_eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, simple_eager_attention_forward + ) attn_output, _ = attention_interface( self, @@ -1307,9 +1307,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py index 23054869aeb6..d947e2d1b9f5 100644 --- a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py @@ -491,9 +491,9 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = simple_eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, simple_eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -920,9 +920,9 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = simple_eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, simple_eager_attention_forward + ) attn_output, _ = attention_interface( self, diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index dc6ec1b1a586..446b17072ba4 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -242,9 +242,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/pixio/modeling_pixio.py b/src/transformers/models/pixio/modeling_pixio.py index 327a77482d79..55122fd32c2f 100644 --- a/src/transformers/models/pixio/modeling_pixio.py +++ b/src/transformers/models/pixio/modeling_pixio.py @@ -204,9 +204,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index e6006d6dae5f..423cb6acc186 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -236,9 +236,9 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 4d5b31a520d4..2b464cc381ec 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -237,9 +237,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 34f6f85f5979..8e890ba7921b 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -227,9 +227,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index c72a8d2eeaa6..38d86176488e 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -82,9 +82,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 8b0b28522594..67ea2b7e0d74 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -673,9 +673,9 @@ def forward( value_states = value_states.transpose(0, 1).unsqueeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, _ = attention_interface( self, @@ -997,9 +997,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config): # Flash Attention 2: Use cu_seqlens for variable length attention @@ -1494,9 +1494,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 340e4f9fdb47..691df1a8a77e 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1576,9 +1576,9 @@ def forward( value_states = value_states.transpose(0, 1).unsqueeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, _ = attention_interface( self, @@ -1876,9 +1876,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config): # Flash Attention 2: Use cu_seqlens for variable length attention diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 29f66d6cd204..c5cc23371cba 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -214,9 +214,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config): # Flash Attention: Use cu_seqlens for variable length attention @@ -707,9 +707,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index c0ea10482235..2e88b49a2de0 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -162,9 +162,9 @@ def forward( key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index be9722105274..9ea02c1e8327 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -274,9 +274,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 7eb1829d17c4..d445835f5ba8 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -381,9 +381,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config): # Flash Attention: Use cu_seqlens for variable length attention @@ -547,9 +547,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 5ab3eccd63eb..c5fe0b5b10e0 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -273,9 +273,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index 247a29a2b13e..cd5727b5c322 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -89,9 +89,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 2f098969a6ad..e76da1d43a70 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -177,9 +177,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 624a580a5d88..a21877bc9ee7 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -401,9 +401,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index 09d736944e57..8457e5a3bcd1 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -256,9 +256,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 8305f8ae9381..8f7a6a963081 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -595,9 +595,9 @@ def forward( value_states = value_states.transpose(0, 1).unsqueeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, _ = attention_interface( self, @@ -925,9 +925,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config): # Flash Attention: Use cu_seqlens for variable length attention @@ -1573,9 +1573,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -2468,9 +2468,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -3509,9 +3509,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index a06ab3e2ca7c..2ed44f0c9678 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -211,9 +211,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config): # Flash Attention: Use cu_seqlens for variable length attention @@ -480,9 +480,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 05eeec541845..64719b300929 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -441,9 +441,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index ca2c30e8ea35..cd91241e1167 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -277,9 +277,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -462,9 +462,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config): # Flash Attention: Use cu_seqlens for variable length attention diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index f05ef121efc8..5848f1a81a28 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -242,9 +242,9 @@ def forward( {"cache_position": cache_position}, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -319,9 +319,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index eaa1a4561f6c..ffeb47f6b9cd 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -240,9 +240,9 @@ def forward( {"cache_position": cache_position}, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -318,9 +318,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index d7cbe566c977..f7eb73785e7c 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -259,9 +259,9 @@ def forward( {"cache_position": cache_position}, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -337,9 +337,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index 52e1489f9a99..bc9b1611acbe 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -248,9 +248,9 @@ def forward( value = self._separate_heads(value, self.num_attention_heads) # SamAttention - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/sam2/modeling_sam2.py b/src/transformers/models/sam2/modeling_sam2.py index 092389ebcb65..c5c3d9d21280 100644 --- a/src/transformers/models/sam2/modeling_sam2.py +++ b/src/transformers/models/sam2/modeling_sam2.py @@ -326,9 +326,9 @@ def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: key = key.transpose(1, 2) value = value.transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, _ = attention_interface( self, query, @@ -881,9 +881,9 @@ def forward( key = self.k_proj(key).view(*new_shape).transpose(1, 2) value = self.v_proj(value).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config) and attention_similarity is not None: # Target guided masks are represented as float masks and are incompatible with Flash Attention diff --git a/src/transformers/models/sam2/modular_sam2.py b/src/transformers/models/sam2/modular_sam2.py index 83d9514edd9d..eb90c0ee5061 100644 --- a/src/transformers/models/sam2/modular_sam2.py +++ b/src/transformers/models/sam2/modular_sam2.py @@ -494,9 +494,9 @@ def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: key = key.transpose(1, 2) value = value.transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, _ = attention_interface( self, query, @@ -921,9 +921,9 @@ def forward( key = self.k_proj(key).view(*new_shape).transpose(1, 2) value = self.v_proj(value).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config) and attention_similarity is not None: # Target guided masks are represented as float masks and are incompatible with Flash Attention diff --git a/src/transformers/models/sam2_video/modeling_sam2_video.py b/src/transformers/models/sam2_video/modeling_sam2_video.py index 3f6d411be7a1..fdc865abd61f 100644 --- a/src/transformers/models/sam2_video/modeling_sam2_video.py +++ b/src/transformers/models/sam2_video/modeling_sam2_video.py @@ -473,9 +473,9 @@ def forward( key = self.k_proj(key).view(*new_shape).transpose(1, 2) value = self.v_proj(value).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config) and attention_similarity is not None: # Target guided masks are represented as float masks and are incompatible with Flash Attention @@ -866,9 +866,9 @@ def forward( query, key, cos, sin, repeat_freqs_k=self.rope_k_repeat, num_k_exclude_rope=num_k_exclude_rope ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/sam2_video/modular_sam2_video.py b/src/transformers/models/sam2_video/modular_sam2_video.py index 52f63e9d802a..a3abc1e7b202 100644 --- a/src/transformers/models/sam2_video/modular_sam2_video.py +++ b/src/transformers/models/sam2_video/modular_sam2_video.py @@ -1162,9 +1162,9 @@ def forward( query, key, cos, sin, repeat_freqs_k=self.rope_k_repeat, num_k_exclude_rope=num_k_exclude_rope ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/sam3/modeling_sam3.py b/src/transformers/models/sam3/modeling_sam3.py index 5a0f9093d251..8f4f576cd022 100644 --- a/src/transformers/models/sam3/modeling_sam3.py +++ b/src/transformers/models/sam3/modeling_sam3.py @@ -361,9 +361,9 @@ def forward( key = self.k_proj(key).view(batch_size, key_len, self.num_attention_heads, self.head_dim).transpose(1, 2) value = self.v_proj(value).view(batch_size, key_len, self.num_attention_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if ( is_flash_attention_requested(self.config) @@ -509,9 +509,9 @@ def forward( cos, sin = position_embeddings query, key = apply_rotary_pos_emb_2d(query, key, cos=cos, sin=sin) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/sam3_tracker/modeling_sam3_tracker.py b/src/transformers/models/sam3_tracker/modeling_sam3_tracker.py index a12042dd37de..f0d83a3016d0 100644 --- a/src/transformers/models/sam3_tracker/modeling_sam3_tracker.py +++ b/src/transformers/models/sam3_tracker/modeling_sam3_tracker.py @@ -340,9 +340,9 @@ def forward( key = self.k_proj(key).view(*new_shape).transpose(1, 2) value = self.v_proj(value).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config) and attention_similarity is not None: # Target guided masks are represented as float masks and are incompatible with Flash Attention diff --git a/src/transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py b/src/transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py index 0f8fb9ee9858..a09b9289dc05 100644 --- a/src/transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +++ b/src/transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py @@ -478,9 +478,9 @@ def forward( key = self.k_proj(key).view(*new_shape).transpose(1, 2) value = self.v_proj(value).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config) and attention_similarity is not None: # Target guided masks are represented as float masks and are incompatible with Flash Attention @@ -871,9 +871,9 @@ def forward( query, key, cos, sin, repeat_freqs_k=self.rope_k_repeat, num_k_exclude_rope=num_k_exclude_rope ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/sam_hq/modeling_sam_hq.py b/src/transformers/models/sam_hq/modeling_sam_hq.py index 0d182bfe8552..339fa7dc0a08 100644 --- a/src/transformers/models/sam_hq/modeling_sam_hq.py +++ b/src/transformers/models/sam_hq/modeling_sam_hq.py @@ -685,9 +685,9 @@ def forward( value = self._separate_heads(value, self.num_attention_heads) # SamHQAttention - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/seed_oss/modeling_seed_oss.py b/src/transformers/models/seed_oss/modeling_seed_oss.py index 568b5fac2c24..f3485de00a16 100644 --- a/src/transformers/models/seed_oss/modeling_seed_oss.py +++ b/src/transformers/models/seed_oss/modeling_seed_oss.py @@ -206,9 +206,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/seed_oss/modular_seed_oss.py b/src/transformers/models/seed_oss/modular_seed_oss.py index 3250dee9eedf..305e36549ee0 100644 --- a/src/transformers/models/seed_oss/modular_seed_oss.py +++ b/src/transformers/models/seed_oss/modular_seed_oss.py @@ -118,9 +118,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 195098c3d835..8f2d9b7b168d 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -322,9 +322,9 @@ def forward( key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index bf91bfe8b56b..ca37fb6ea894 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -324,9 +324,9 @@ def forward( keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/siglip2/modeling_siglip2.py b/src/transformers/models/siglip2/modeling_siglip2.py index f39a1de270f4..b5653d2e1474 100644 --- a/src/transformers/models/siglip2/modeling_siglip2.py +++ b/src/transformers/models/siglip2/modeling_siglip2.py @@ -319,9 +319,9 @@ def forward( keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index 96f875bc734e..69cb646d9370 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -239,9 +239,9 @@ def forward( cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/smollm3/modular_smollm3.py b/src/transformers/models/smollm3/modular_smollm3.py index b7aa95a78219..3a4c6e0a01d5 100644 --- a/src/transformers/models/smollm3/modular_smollm3.py +++ b/src/transformers/models/smollm3/modular_smollm3.py @@ -259,9 +259,9 @@ def forward( cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index ef7c4cfb3495..0ca10fd4e900 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -205,9 +205,9 @@ def forward( keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/solar_open/modeling_solar_open.py b/src/transformers/models/solar_open/modeling_solar_open.py index c5f2c7ace1bf..1d0edc7a12d2 100644 --- a/src/transformers/models/solar_open/modeling_solar_open.py +++ b/src/transformers/models/solar_open/modeling_solar_open.py @@ -342,9 +342,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 89ec461c02b6..5401f3d69a6f 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -293,9 +293,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 8e930c0b604e..48f39c614dc7 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -149,9 +149,9 @@ def forward( key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index c5a6e85a8f05..dcdf29e1b2c0 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -327,9 +327,9 @@ def forward( } key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 8613fef45695..bc5195a99778 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -182,9 +182,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index f6fd0841b217..9f8b963fb379 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -99,9 +99,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index f6bc2ca10eae..7c504639ddeb 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -290,9 +290,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -376,9 +376,9 @@ def forward( key_states = curr_past_key_values.layers[self.layer_idx].keys value_states = curr_past_key_values.layers[self.layer_idx].values - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py index c55d20ba7b66..50de054e1f97 100644 --- a/src/transformers/models/t5gemma/modular_t5gemma.py +++ b/src/transformers/models/t5gemma/modular_t5gemma.py @@ -393,9 +393,9 @@ def forward( key_states = curr_past_key_values.layers[self.layer_idx].keys value_states = curr_past_key_values.layers[self.layer_idx].values - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/t5gemma2/modeling_t5gemma2.py b/src/transformers/models/t5gemma2/modeling_t5gemma2.py index da8bdc4905bc..ca975fa79245 100644 --- a/src/transformers/models/t5gemma2/modeling_t5gemma2.py +++ b/src/transformers/models/t5gemma2/modeling_t5gemma2.py @@ -312,9 +312,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -432,9 +432,9 @@ def forward( key_states = torch.cat([key_states, cross_key_states], dim=2) value_states = torch.cat([value_states, cross_value_states], dim=2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/t5gemma2/modular_t5gemma2.py b/src/transformers/models/t5gemma2/modular_t5gemma2.py index 3245a041d20b..03b9333282c3 100644 --- a/src/transformers/models/t5gemma2/modular_t5gemma2.py +++ b/src/transformers/models/t5gemma2/modular_t5gemma2.py @@ -602,9 +602,9 @@ def forward( key_states = torch.cat([key_states, cross_key_states], dim=2) value_states = torch.cat([value_states, cross_value_states], dim=2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index d6a61f301bc9..18354efdb601 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -402,9 +402,9 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index c0d547deeedb..632e748b1494 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -245,9 +245,9 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = simple_eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, simple_eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index de714c21c6ac..7ac5123018a2 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -201,9 +201,9 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - attention_interface: Callable = simple_eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, simple_eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 2dc68c7f1656..144af5f4ca21 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -356,9 +356,9 @@ def forward( key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 03bb2ce75e6a..d818a1332a78 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -362,9 +362,9 @@ def forward( key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/vaultgemma/modeling_vaultgemma.py b/src/transformers/models/vaultgemma/modeling_vaultgemma.py index 7c583658fbbf..e972eb346fe9 100644 --- a/src/transformers/models/vaultgemma/modeling_vaultgemma.py +++ b/src/transformers/models/vaultgemma/modeling_vaultgemma.py @@ -212,9 +212,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index d19162af17a3..2e9f142ca38f 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -232,9 +232,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config): # Flash Attention 2: Use cu_seqlens for variable length attention diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index caeb8483afdd..f9b8984e44ac 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -312,9 +312,9 @@ def forward( key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) if is_flash_attention_requested(self.config): # Flash Attention 2: Use cu_seqlens for variable length attention diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 96ff0e16435e..02cd9fb65fa5 100755 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -245,9 +245,9 @@ def forward(self, hidden_states: torch.Tensor | None = None) -> tuple[torch.Tens value_layer = values.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) query_layer = queries.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 7dd6debd5e54..02acbc2c2857 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -225,9 +225,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 568de3c991c4..86bbb88c39f6 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -387,9 +387,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index c7b666702a39..8a68d52b6c48 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -222,9 +222,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, diff --git a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py index b5da48a61e4c..ca341e85c8b2 100644 --- a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +++ b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py @@ -155,9 +155,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index 3e0f18b40437..fa3524b5cd9c 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -215,9 +215,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, diff --git a/src/transformers/models/vjepa2/modeling_vjepa2.py b/src/transformers/models/vjepa2/modeling_vjepa2.py index 332e02108056..71896012e183 100644 --- a/src/transformers/models/vjepa2/modeling_vjepa2.py +++ b/src/transformers/models/vjepa2/modeling_vjepa2.py @@ -317,9 +317,9 @@ def forward( key_layer = self.apply_rotary_embeddings(key_layer, pos_ids) query_layer = self.apply_rotary_embeddings(query_layer, pos_ids) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, @@ -726,9 +726,9 @@ def forward( keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -795,9 +795,9 @@ def forward( keys = keys.view(batch_size, kv_seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, kv_seq_length, self.num_heads, self.head_dim).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/voxtral/modeling_voxtral.py b/src/transformers/models/voxtral/modeling_voxtral.py index 594ca09efe55..88b9d03aab97 100644 --- a/src/transformers/models/voxtral/modeling_voxtral.py +++ b/src/transformers/models/voxtral/modeling_voxtral.py @@ -133,9 +133,9 @@ def forward( key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 395817ded96c..63285fac216e 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -528,9 +528,9 @@ def forward( key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 4516a3da8bff..5648e019d3a9 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -338,9 +338,9 @@ def forward( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index f08c4a695e2b..72bc8700c10a 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -299,9 +299,9 @@ def forward( else: self.is_causal = causal_attention_mask is not None - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index e8006b07e22d..c8f385654bec 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -242,9 +242,9 @@ def forward( {"cache_position": cache_position}, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -319,9 +319,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index c765000ca274..47fab400f387 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -244,9 +244,9 @@ def forward( {"cache_position": cache_position}, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -321,9 +321,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index c59bbc4bb022..5176a611e6f7 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -239,9 +239,9 @@ def forward( {"cache_position": cache_position}, ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -317,9 +317,9 @@ def forward( # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls past_key_values.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 7990ece334b1..adf96283141b 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -269,9 +269,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) context_layer, attention_probs = attention_interface( self, diff --git a/src/transformers/models/youtu/modeling_youtu.py b/src/transformers/models/youtu/modeling_youtu.py index bc31edbe01fe..5731e2086891 100644 --- a/src/transformers/models/youtu/modeling_youtu.py +++ b/src/transformers/models/youtu/modeling_youtu.py @@ -363,9 +363,9 @@ def forward( if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 5c4676a4e567..b8451ae6bff3 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -265,9 +265,9 @@ def forward( if past_key_values is not None: key_states, value_states = past_key_values.update(key_states, value_states, layer_idx) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index fe73644620ad..78297b44188a 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -447,9 +447,9 @@ def forward( if past_key_values is not None: key_states, value_states = past_key_values.update(key_states, value_states, layer_idx) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 2c3a534c2fdf..af7152943adf 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -255,9 +255,9 @@ def forward( if past_key_values is not None: key_states, value_states = past_key_values.update(key_states, value_states, layer_idx) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index ff0e6c32a2fb..5a7a638dac6f 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -3084,6 +3084,51 @@ def test_not_available_kernels(self): self.assertTrue("`kernels` is either not installed or uses an incompatible version." in str(cm.exception)) + def test_attention_and_experts_modules_can_be_used_standalone(self): + """Test that both Attention and Expert modules can be used on their own, instantiated from a config without the + respective `_xxx_implementation` attr set. Also checks that it correctly raises a warning""" + from transformers.models.mixtral.configuration_mixtral import MixtralConfig + from transformers.models.mixtral.modeling_mixtral import ( + MixtralAttention, + MixtralExperts, + MixtralRotaryEmbedding, + ) + + hidden_size = 32 + seq_len = 10 + config = MixtralConfig(hidden_size=32, intermediate_size=16, num_hidden_layers=2) + experts_module = MixtralExperts(config) + attn_module = MixtralAttention(config, layer_idx=0) + + hidden_states = torch.randn(1, seq_len, hidden_size) + + # Try the Attention (check it works + raises the warning) + dummy_ids = torch.arange(seq_len).unsqueeze(0) + dummy_embeddings = MixtralRotaryEmbedding(config)(hidden_states, dummy_ids) + with CaptureLogger(logging.get_logger("transformers.modeling_utils")) as cl: + _ = attn_module(hidden_states, dummy_embeddings, None) + self.assertIn( + "You tried to access the `AttentionInterface` with a `config._attn_implementation` set to `None`.", cl.out + ) + # With a wrong _attn_implementation, it should raise a proper exception + attn_module.config._attn_implementation = "foobar" + with self.assertRaisesRegex(KeyError, "`foobar` is not a valid attention implementation registered"): + _ = attn_module(hidden_states, dummy_embeddings, None) + + # Try the Experts (check it works + raises the warning) + hidden_states = hidden_states.reshape(-1, hidden_size) + dummy_scores = torch.randn(seq_len, config.num_experts_per_tok) + dummy_indices = torch.randint(0, config.num_local_experts, (seq_len, config.num_experts_per_tok)) + with CaptureLogger(logging.get_logger("transformers.integrations.moe")) as cl: + _ = experts_module(hidden_states, dummy_indices, dummy_scores) + self.assertIn( + "You tried to access the `ExpertsInterface` with a `config._experts_implementation` set to `None`.", cl.out + ) + # With a wrong _experts_implementation, it should raise a proper exception + experts_module.config._experts_implementation = "foobar" + with self.assertRaisesRegex(KeyError, "`foobar` is not a valid experts implementation registered"): + _ = experts_module(hidden_states, dummy_indices, dummy_scores) + @require_torch class TestTensorSharing(TestCasePlus):