src/transformers/integrations/sdpa_attention.py (22 changes: 19 additions & 3 deletions)
@@ -3,11 +3,15 @@
 import torch
 
 from ..utils import logging
+from ..utils.import_utils import is_torch_greater_or_equal
 
 
 logger = logging.get_logger(__name__)
 
 
+_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)
+
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -20,6 +24,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
+def use_gqa_in_sdpa(attention_mask: Optional[torch.Tensor], key: torch.Tensor) -> bool:
+    # GQA can only be used under the following conditions
+    # 1. torch version >= 2.5
+    # 2. attention_mask is None (otherwise it will fall back to the math kernel)
+    # 3. key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
+    return _is_torch_greater_or_equal_than_2_5 and attention_mask is None and not isinstance(key, torch.fx.Proxy)
+
+
 def sdpa_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,
@@ -36,10 +48,13 @@ def sdpa_attention_forward(
             "`sdpa` attention does not support `output_attentions=True` or `head_mask`."
             " Please set your attention to `eager` if you want any of these features."
         )
-
+    sdpa_kwargs = {}
     if hasattr(module, "num_key_value_groups"):
-        key = repeat_kv(key, module.num_key_value_groups)
-        value = repeat_kv(value, module.num_key_value_groups)
+        if not use_gqa_in_sdpa(attention_mask, key):
+            key = repeat_kv(key, module.num_key_value_groups)
+            value = repeat_kv(value, module.num_key_value_groups)
+        else:
+            sdpa_kwargs = {"enable_gqa": True}
 
     if attention_mask is not None and attention_mask.ndim == 4:
         attention_mask = attention_mask[:, :, :, : key.shape[-2]]
@@ -71,6 +86,7 @@
         dropout_p=dropout,
         scale=scaling,
         is_causal=is_causal,
+        **sdpa_kwargs,
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
 
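Note (not part of the PR, a minimal sketch): the new else branch relies on the enable_gqa flag that torch.nn.functional.scaled_dot_product_attention gained in PyTorch 2.5, which lets the kernel broadcast a smaller number of key/value heads across the query heads instead of materializing the repeated tensors that repeat_kv produces. The shapes and head counts below are illustrative only, and the link to the relaxed atol=1e-5 in the test file below is an assumption (tiny floating-point differences between the two paths).

import torch
import torch.nn.functional as F

batch, num_heads, num_kv_heads, seq_len, head_dim = 2, 8, 2, 16, 64
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_kv_heads, seq_len, head_dim)
value = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# Fallback path: expand key/value to the query head count, as repeat_kv does.
n_rep = num_heads // num_kv_heads
out_repeated = F.scaled_dot_product_attention(
    query,
    key.repeat_interleave(n_rep, dim=1),
    value.repeat_interleave(n_rep, dim=1),
    is_causal=True,
)

# GQA path (PyTorch >= 2.5): pass the un-expanded key/value and let SDPA group them.
out_gqa = F.scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# The two paths should agree up to small floating-point differences.
print(torch.allclose(out_repeated, out_gqa, atol=1e-5))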
tests/utils/test_cache_utils.py (10 changes: 5 additions & 5 deletions)
@@ -588,12 +588,12 @@ def test_dynamic_cache_exportability(self):
             past_key_values=past_key_values_eager,
             use_cache=True,
         )
-        self.assertTrue(torch.allclose(res.logits, res_eager.logits))
+        self.assertTrue(torch.allclose(res.logits, res_eager.logits, atol=1e-5))
         for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
-            self.assertTrue(torch.allclose(k1, k2))
+            self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
 
         for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
-            self.assertTrue(torch.allclose(v1, v2))
+            self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
 
     def test_dynamic_cache_exportability_multiple_run(self):
         # When exporting with DynamicCache, you should export two graphs:
@@ -686,10 +686,10 @@ def test_dynamic_cache_exportability_multiple_run(self):
         )
 
         for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache):
-            self.assertTrue(torch.allclose(k1, k2))
+            self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
 
         for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache):
-            self.assertTrue(torch.allclose(v1, v2))
+            self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
 
     @unittest.skip("Runs on my machine locally, passed, no idea why it does not online")
     def test_static_cache_exportability(self):