Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions examples/modular-transformers/modeling_dummy_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ def forward(
{"cache_position": cache_position},
)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down Expand Up @@ -246,9 +246,9 @@ def forward(
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
past_key_values.is_updated[self.layer_idx] = True

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def forward(
keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
6 changes: 3 additions & 3 deletions examples/modular-transformers/modeling_global_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
6 changes: 3 additions & 3 deletions examples/modular-transformers/modeling_multimodal2.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ def forward(
keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
6 changes: 3 additions & 3 deletions examples/modular-transformers/modeling_my_new_model2.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
12 changes: 6 additions & 6 deletions examples/modular-transformers/modeling_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ def forward(
{"cache_position": cache_position},
)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down Expand Up @@ -249,9 +249,9 @@ def forward(
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
past_key_values.is_updated[self.layer_idx] = True

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
6 changes: 3 additions & 3 deletions examples/modular-transformers/modeling_super.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,9 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
6 changes: 3 additions & 3 deletions examples/modular-transformers/modeling_switch_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
6 changes: 3 additions & 3 deletions examples/modular-transformers/modeling_test_suffix.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
26 changes: 21 additions & 5 deletions src/transformers/integrations/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Callable
from functools import wraps

from ..utils import logging
from ..utils.generic import GeneralInterface
from ..utils.import_utils import is_torch_available


if is_torch_available():
import torch

logger = logging.get_logger(__name__)

# Examples of experts class with its eager mm implementation
# class Experts(nn.Module):
# """Collection of expert weights stored as 3D tensors."""
Expand Down Expand Up @@ -265,6 +269,20 @@ class ExpertsInterface(GeneralInterface):
"grouped_mm": grouped_mm_experts_forward,
}

def get_interface(self, experts_implementation: "str | None", default: Callable) -> Callable:
    """Return the experts-forward callable registered under `experts_implementation`.

    Strictly validates the requested implementation and raises if it is unknown.
    `None` and `"eager"` both fall back to `default` (`"eager"` is typically not
    a key in the registry, so the final `get` returns the default for it).

    Args:
        experts_implementation: Key of a registered experts implementation, or
            `None` (expected when an Expert module is used standalone — emits a
            one-time warning instead of raising).
        default: Callable returned when no registered implementation applies
            (the eager experts forward).

    Returns:
        The registered experts-forward callable, or `default`.

    Raises:
        KeyError: If `experts_implementation` is neither `None`, `"eager"`, nor
            a key registered in this interface.
    """
    if experts_implementation is None:
        # Standalone Expert modules legitimately have no dispatched implementation;
        # anything else reaching here indicates a dispatch bug upstream.
        logger.warning_once(
            "You tried to access the `ExpertsInterface` with a `config._experts_implementation` set to `None`. This "
            "is expected if you use an Expert Module as a standalone Module. If this is not the case, something went "
            "wrong with the dispatch of `config._experts_implementation`"
        )
    elif experts_implementation != "eager" and experts_implementation not in self:
        raise KeyError(
            f"`{experts_implementation}` is not a valid experts implementation registered in the `ExpertsInterface`"
        )
    return super().get(experts_implementation, default)


ALL_EXPERTS_FUNCTIONS = ExpertsInterface()

Expand Down Expand Up @@ -313,11 +331,9 @@ def __init__(self, config, *args, **kwargs):

@wraps(original_forward)
def forward(self, *args, **kwargs):
experts_forward = original_forward

if self.config._experts_implementation != "eager":
experts_forward = ALL_EXPERTS_FUNCTIONS[self.config._experts_implementation]

experts_forward = ALL_EXPERTS_FUNCTIONS.get_interface(
self.config._experts_implementation, original_forward
)
return experts_forward(self, *args, **kwargs)

if not hasattr(experts_class, "_apply_gate"):
Expand Down
19 changes: 15 additions & 4 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1957,10 +1957,7 @@ def _can_set_attn_implementation(cls) -> bool:
code = f.read()
# heuristic -> if we find those patterns, the model uses the correct interface
if re.search(r"class \w+Attention\(nn.Module\)", code):
return (
"eager_attention_forward" in code
and "ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]" in code
)
return "eager_attention_forward" in code and "ALL_ATTENTION_FUNCTIONS.get_interface(" in code
else:
# If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models
return True
Expand Down Expand Up @@ -4782,6 +4779,20 @@ class AttentionInterface(GeneralInterface):
"paged|eager": eager_paged_attention_forward,
}

def get_interface(self, attn_implementation: "str | None", default: Callable) -> Callable:
    """Return the attention callable registered under `attn_implementation`.

    Strictly validates the requested implementation and raises if it is unknown.
    `None` and `"eager"` both fall back to `default` (`"eager"` is typically not
    a key in the registry, so the final `get` returns the default for it).

    Args:
        attn_implementation: Key of a registered attention implementation, or
            `None` (expected when an Attention module is used standalone — emits
            a one-time warning instead of raising).
        default: Callable returned when no registered implementation applies
            (the eager attention forward).

    Returns:
        The registered attention callable, or `default`.

    Raises:
        KeyError: If `attn_implementation` is neither `None`, `"eager"`, nor a
            key registered in this interface.
    """
    if attn_implementation is None:
        # Standalone Attention modules legitimately have no dispatched implementation;
        # anything else reaching here indicates a dispatch bug upstream.
        logger.warning_once(
            "You tried to access the `AttentionInterface` with a `config._attn_implementation` set to `None`. This "
            "is expected if you use an Attention Module as a standalone Module. If this is not the case, something went "
            "wrong with the dispatch of `config._attn_implementation`"
        )
    elif attn_implementation != "eager" and attn_implementation not in self:
        raise KeyError(
            f"`{attn_implementation}` is not a valid attention implementation registered in the `AttentionInterface`"
        )
    return super().get(attn_implementation, default)


# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/afmoe/modeling_afmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,9 @@ def forward(
cache_kwargs = {"cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

output, attn_weights = attention_interface(
self,
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/afmoe/modular_afmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,9 @@ def forward(
cache_kwargs = {"cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

output, attn_weights = attention_interface(
self,
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/aimv2/modeling_aimv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,9 @@ def forward(
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/albert/modeling_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ def forward(
key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/align/modeling_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,9 +628,9 @@ def forward(
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/altclip/modeling_altclip.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,9 +505,9 @@ def forward(
else:
self.is_causal = causal_attention_mask is not None

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/apertus/modeling_apertus.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,9 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/apertus/modular_apertus.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,9 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/arcee/modeling_arcee.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,9 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/aria/modeling_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,9 +497,9 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

context_layer, attention_probs = attention_interface(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ def forward(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
Expand Down
Loading