Conversation

@sywangyi
Contributor

use the enable_gqa param in torch.nn.functional.scaled_dot_product_attention
GQA can be accelerated in torch.nn.functional.scaled_dot_product_attention: this PyTorch API offers an enable_gqa parameter. See https://docs.pytorch.org/docs/2.7/generated/torch.nn.functional.scaled_dot_product_attention.html#torch-nn-functional-scaled-dot-product-attention
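For context, a minimal sketch of the kwarg in isolation (shapes are illustrative; assumes a PyTorch build recent enough to expose enable_gqa):

```python
import torch
import torch.nn.functional as F

# 32 query heads share 8 key/value heads (GQA ratio 4); shapes are illustrative.
query = torch.randn(2, 32, 128, 64)
key = torch.randn(2, 8, 128, 64)
value = torch.randn(2, 8, 128, 64)

# Without enable_gqa=True, key/value would first have to be repeated to 32 heads
# before calling SDPA.
out = F.scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
print(out.shape)  # torch.Size([2, 32, 128, 64])
```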

Signed-off-by: Wang, Yi A <[email protected]>
@liangan1

@LuFinch please help review this PR.

@sywangyi
Contributor Author

FAILED tests/models/nougat/test_image_processing_nougat.py::NougatImageProcessingTest::test_slow_fast_equivalence_batched - AssertionError: 0.005013074725866318 not less than or equal to 0.005. This failure has nothing to do with the PR; the failing test does not call SDPA attention.

@vasqu
Contributor

vasqu commented Jul 15, 2025

Please see #35235 (comment)

The enable_gqa kwarg is pretty restrictive and would need proper checks around it (version, mask) to ensure we do not fall back to the math kernel / use unsupported features of older torch.

Signed-off-by: Wang, Yi A <[email protected]>
@sywangyi
Contributor Author

sywangyi commented Jul 15, 2025

> Please see #35235 (comment)
>
> The enable_gqa kwarg is pretty restrictive and would need proper checks around it (version, mask) to ensure we do not fall back to the math kernel / use unsupported features of older torch.

Thanks for the review. I added a check so that GQA in SDPA is only enabled for non-CUDA devices.
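Roughly, the kind of guard being discussed looks like this; an illustrative sketch with assumed helper and flag names, not the exact code in the PR:

```python
import torch
from packaging import version

# Assumed version flag; transformers keeps similar private flags internally.
_TORCH_GREATER_OR_EQUAL_2_5 = version.parse(torch.__version__) >= version.parse("2.5")

def use_gqa_in_sdpa(attention_mask, key) -> bool:
    # Require a recent torch, and skip under torch.fx tracing (symbolic proxies).
    if not _TORCH_GREATER_OR_EQUAL_2_5 or isinstance(key, torch.fx.Proxy):
        return False
    # Conservative: only non-CUDA devices, and only when no explicit mask is passed,
    # since the CUDA flash kernel rejects a mask and SDPA would fall back to math.
    return key.device.type != "cuda" and attention_mask is None
```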

Signed-off-by: Wang, Yi A <[email protected]>
Contributor

@vasqu vasqu left a comment

Added some comments on the checks; we have to be a bit strict here, especially to avoid performance issues.

It would be very helpful to have a test that checks that we don't fall back to the math backend when calling SDPA (with grouped query).
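One way such a test could look, as a sketch only (assumes a CUDA device and the torch.nn.attention.sdpa_kernel context manager; this is not the test that was actually added):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def test_sdpa_gqa_does_not_need_math_backend():
    q = torch.randn(1, 16, 64, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)
    # Only the flash backend is allowed inside this context; if SDPA needed the
    # math backend for grouped-query attention, the call would raise instead of
    # silently falling back.
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
    assert out.shape == q.shape
```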

@vasqu
Contributor

vasqu commented Jul 16, 2025

And don't worry about the CI, the tests that failed were/are either flaky ones or not related to this PR.

@vasqu vasqu mentioned this pull request Jul 17, 2025
Signed-off-by: Wang, Yi A <[email protected]>
Contributor

@vasqu vasqu left a comment

Thank you, more or less nits left

cc @ArthurZucker @Cyrilvallez for core maintainer + attention

Signed-off-by: Wang, Yi A <[email protected]>
Member

@Cyrilvallez Cyrilvallez left a comment

Nice! Happy to use the native sdpa option whenever possible! Just a final comment

Member

@Cyrilvallez Cyrilvallez left a comment

Perfect! Thanks a lot for enabling this! 🤗

@Cyrilvallez Cyrilvallez merged commit 9323d08 into huggingface:main Jul 21, 2025
23 checks passed
@LuFinch

LuFinch commented Jul 22, 2025

It is true that the flash attention backend of PyTorch SDPA doesn't support an attention mask and that it falls back to the math backend when the user passes an attn_mask. Ref: https://github.com/pytorch/pytorch/blob/d984143a74e5e726e2be35f6531582aab45bcf4c/aten/src/ATen/native/transformers/sdp_utils_cpp.h#L259

However, this check should not be a condition for GQA. An attn_mask together with enable_gqa=True can still dispatch to optimized kernels in other SDPA backends, such as cudnn_attention, efficient_attention and the overrideable backends.

Could you remove the attention_mask is None condition from use_gqa_in_sdpa, or refine the code further?

zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jul 22, 2025
use the enable_gqa param in torch.nn.functional.scaled_dot_product_attention (huggingface#39412)

* use the enable_gqa param in torch.nn.functional.scaled_dot_product_attention

Signed-off-by: Wang, Yi A <[email protected]>

* ci failure fix

Signed-off-by: Wang, Yi A <[email protected]>

* add check

Signed-off-by: Wang, Yi A <[email protected]>

* fix ci failure

Signed-off-by: Wang, Yi A <[email protected]>

* refine code, extend to cuda

Signed-off-by: Wang, Yi A <[email protected]>

* refine code

Signed-off-by: Wang, Yi A <[email protected]>

* fix review comments

Signed-off-by: Wang, Yi A <[email protected]>

* refine the PR

Signed-off-by: Wang, Yi A <[email protected]>

---------

Signed-off-by: Wang, Yi A <[email protected]>
Co-authored-by: Cyril Vallez <[email protected]>
@vasqu
Contributor

vasqu commented Jul 22, 2025

According to the docs (https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)

"Grouped Query Attention (GQA) is an experimental feature. It currently works only for Flash_attention and math kernel on CUDA tensor, and does not support Nested tensor."

That means either the docs are wrong or we only have the fa and math backend available when using the gqa kwargs. The memory efficient variant is not mentioned so it would need some code sample imo to show that it can work alongside. For cudnn, that's also a rather new and experimental backend so I'm not sure what the conditions are there; iirc it was also not turned on by default as backend option.
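A probe along these lines could serve as that sample (assumes a CUDA build; it only reports whether the memory-efficient backend accepts the combination rather than asserting that it does):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 16, 64, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)
mask = torch.ones(1, 1, 64, 64, device="cuda", dtype=torch.bool).tril()

# Restrict SDPA to the memory-efficient backend only; if it cannot take an
# attn_mask together with enable_gqa=True, the call raises instead of silently
# falling back to the math backend.
try:
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=True)
    print("efficient_attention handled attn_mask + enable_gqa:", out.shape)
except RuntimeError as err:
    print("efficient_attention rejected attn_mask + enable_gqa:", err)
```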

@liangan1

> According to the docs (https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
>
> "Grouped Query Attention (GQA) is an experimental feature. It currently works only for Flash_attention and math kernel on CUDA tensor, and does not support Nested tensor."
>
> That means either the docs are wrong or we only have the fa and math backend available when using the gqa kwargs. The memory efficient variant is not mentioned so it would need some code sample imo to show that it can work alongside. For cudnn, that's also a rather new and experimental backend so I'm not sure what the conditions are there; iirc it was also not turned on by default as backend option.

Yes, you are right, and the doc is also right for CUDA. The GQA optimization for XPU has been enabled since torch 2.8; maybe we can add a different path in the use_gqa_in_sdpa function for different devices? The recommended condition for XPU would be _is_torch_greater_or_equal_than_2_8 and not isinstance(key, torch.fx.Proxy).
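A sketch of what that per-device branching might look like; the XPU condition mirrors the one quoted above, while the torch 2.5 flag and the exact structure are assumptions, not the code that was merged:

```python
import torch
from packaging import version

_torch_version = version.parse(torch.__version__)
_is_torch_greater_or_equal_than_2_8 = _torch_version >= version.parse("2.8")
_is_torch_greater_or_equal_than_2_5 = _torch_version >= version.parse("2.5")

def use_gqa_in_sdpa(attention_mask, key) -> bool:
    # enable_gqa is not usable under torch.fx tracing.
    if isinstance(key, torch.fx.Proxy):
        return False
    if key.device.type == "xpu":
        # GQA in SDPA is optimized on XPU starting with torch 2.8.
        return _is_torch_greater_or_equal_than_2_8
    # Elsewhere, keep the stricter condition (recent torch, no explicit mask)
    # until the non-flash backends are verified to handle a mask with enable_gqa.
    return _is_torch_greater_or_equal_than_2_5 and attention_mask is None
```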

@vasqu
Contributor

vasqu commented Jul 22, 2025

Interestingly enough, I just checked the cudnn backend and it seems to work with the mask + gqa on first glance - would need to look deeper into this

> Yes, you are right, and the doc is also right for CUDA. The GQA optimization for XPU has been enabled since torch 2.8; maybe we can add a different path in the use_gqa_in_sdpa function for different devices? The recommended condition for XPU would be _is_torch_greater_or_equal_than_2_8 and not isinstance(key, torch.fx.Proxy).

That sounds good to me, can you open a PR? It's hard to keep up what's possible where :D
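For reference, the cudnn check mentioned above can be reproduced with a probe like this (a sketch assuming a CUDA build where the cudnn backend is available):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 16, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.bfloat16)
mask = torch.ones(1, 1, 128, 128, device="cuda", dtype=torch.bool).tril()

# Allow only the cudnn backend; success means no math fallback was needed for
# an explicit mask combined with enable_gqa=True.
try:
    with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=True)
    print("cudnn_attention handled attn_mask + enable_gqa:", out.shape)
except RuntimeError as err:
    print("cudnn_attention rejected attn_mask + enable_gqa:", err)
```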

@liangan1

Thanks. I suppose @sywangyi, who is the owner of this part, will submit a PR to fix it.

zaristei pushed commits to zaristei/transformers that referenced this pull request Sep 9, 2025
use the enable_gqa param in torch.nn.functional.scaled_dot_product_attention (huggingface#39412)
@sywangyi sywangyi deleted the yi_sdpa branch November 19, 2025 04:42