diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py
index 3667832061df..195cb447c44b 100644
--- a/src/transformers/integrations/sdpa_attention.py
+++ b/src/transformers/integrations/sdpa_attention.py
@@ -3,11 +3,15 @@
 import torch
 
 from ..utils import logging
+from ..utils.import_utils import is_torch_greater_or_equal
 
 
 logger = logging.get_logger(__name__)
 
 
+_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)
+
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -20,6 +24,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
+def use_gqa_in_sdpa(attention_mask: Optional[torch.Tensor], key: torch.Tensor) -> bool:
+    # GQA can only be used under the following conditions
+    # 1. torch version >= 2.5
+    # 2. attention_mask is None (otherwise it will fall back to the math kernel)
+    # 3. key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
+    return _is_torch_greater_or_equal_than_2_5 and attention_mask is None and not isinstance(key, torch.fx.Proxy)
+
+
 def sdpa_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,
@@ -36,10 +48,13 @@ def sdpa_attention_forward(
             "`sdpa` attention does not support `output_attentions=True` or `head_mask`."
             " Please set your attention to `eager` if you want any of these features."
         )
-
+    sdpa_kwargs = {}
     if hasattr(module, "num_key_value_groups"):
-        key = repeat_kv(key, module.num_key_value_groups)
-        value = repeat_kv(value, module.num_key_value_groups)
+        if not use_gqa_in_sdpa(attention_mask, key):
+            key = repeat_kv(key, module.num_key_value_groups)
+            value = repeat_kv(value, module.num_key_value_groups)
+        else:
+            sdpa_kwargs = {"enable_gqa": True}
 
     if attention_mask is not None and attention_mask.ndim == 4:
         attention_mask = attention_mask[:, :, :, : key.shape[-2]]
@@ -71,6 +86,7 @@ def sdpa_attention_forward(
         dropout_p=dropout,
         scale=scaling,
         is_causal=is_causal,
+        **sdpa_kwargs,
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
 
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index 26f9f56996a6..bb73052b2519 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -588,12 +588,12 @@ def test_dynamic_cache_exportability(self):
             past_key_values=past_key_values_eager,
             use_cache=True,
         )
-        self.assertTrue(torch.allclose(res.logits, res_eager.logits))
+        self.assertTrue(torch.allclose(res.logits, res_eager.logits, atol=1e-5))
         for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
-            self.assertTrue(torch.allclose(k1, k2))
+            self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
 
         for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
-            self.assertTrue(torch.allclose(v1, v2))
+            self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
 
     def test_dynamic_cache_exportability_multiple_run(self):
         # When exporting with DynamicCache, you should export two graphs:
@@ -686,10 +686,10 @@ def test_dynamic_cache_exportability_multiple_run(self):
         )
 
         for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache):
-            self.assertTrue(torch.allclose(k1, k2))
+            self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
 
         for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache):
-            self.assertTrue(torch.allclose(v1, v2))
+            self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
 
     @unittest.skip("Runs on my machine locally, passed, no idea why it does not online")
     def test_static_cache_exportability(self):
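
Background for the patch above: PyTorch 2.5 added an `enable_gqa` flag to `torch.nn.functional.scaled_dot_product_attention` that broadcasts the grouped key/value heads inside the kernel, which is what allows the explicit `repeat_kv` expansion to be skipped. The sketch below illustrates that equivalence; it is not part of the diff, the tensor shapes and tolerances are made up for illustration, and it assumes torch >= 2.5.

    import torch
    import torch.nn.functional as F

    # Illustrative GQA shapes: 8 query heads sharing 2 key/value heads (n_rep = 4).
    batch, n_heads, n_kv_heads, seq, head_dim = 2, 8, 2, 16, 64
    n_rep = n_heads // n_kv_heads
    q = torch.randn(batch, n_heads, seq, head_dim)
    k = torch.randn(batch, n_kv_heads, seq, head_dim)
    v = torch.randn(batch, n_kv_heads, seq, head_dim)

    # Path 1: expand key/value explicitly, as repeat_kv does.
    k_rep = k[:, :, None, :, :].expand(batch, n_kv_heads, n_rep, seq, head_dim).reshape(batch, n_heads, seq, head_dim)
    v_rep = v[:, :, None, :, :].expand(batch, n_kv_heads, n_rep, seq, head_dim).reshape(batch, n_heads, seq, head_dim)
    out_repeat = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)

    # Path 2: let SDPA broadcast the grouped heads itself (torch >= 2.5).
    out_gqa = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

    # Illustrative tolerance; different kernels can differ in low-order bits.
    torch.testing.assert_close(out_repeat, out_gqa, atol=1e-5, rtol=1e-4)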