use the enable_gqa param in torch.nn.functional.scaled_dot_product_attention #39412
Conversation
use the enable_gqa param in torch.nn.functional.scaled_dot_product_attention
Signed-off-by: Wang, Yi A <[email protected]>
@LuFinch please help to review this PR.
FAILED tests/models/nougat/test_image_processing_nougat.py::NougatImageProcessingTest::test_slow_fast_equivalence_batched - AssertionError: 0.005013074725866318 not less than or equal to 0.005
This failure has nothing to do with the PR; the failing test does not call SDPA attention.
Please see #35235 (comment). The …
Thanks for the review. I added a check to only enable GQA in SDPA for non-CUDA devices.
vasqu left a comment
Added some comments on the checks; we have to be a bit strict here, especially to avoid performance issues.
It would be very helpful to have a test that checks that we don't fall back to the math backend when calling SDPA (with grouped query).
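For illustration, a minimal sketch of how such a check could look (this is not the test added in this PR; the shapes, dtype, and CUDA-only guard are assumptions): restrict SDPA to the fused backends so that a fallback to the math backend surfaces as an error instead of happening silently.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def check_gqa_avoids_math_fallback():
    if not torch.cuda.is_available():
        return  # the sketch assumes a CUDA device and PyTorch >= 2.5 (where enable_gqa exists)
    q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)  # 8 query heads
    k = torch.randn(2, 2, 128, 64, device="cuda", dtype=torch.float16)  # 2 key/value heads
    v = torch.randn_like(k)
    # With the math backend excluded, SDPA raises a RuntimeError if no fused
    # kernel can serve the call, so reaching the assert means no math fallback happened.
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
    assert out.shape == (2, 8, 128, 64)
```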
And don't worry about the CI, the failing tests are either flaky or unrelated to this PR.
vasqu left a comment
Thank you, more or less only nits left.
cc @ArthurZucker @Cyrilvallez for core maintainer + attention
Cyrilvallez left a comment
Nice! Happy to use the native SDPA option whenever possible! Just a final comment.
Cyrilvallez left a comment
Perfect! Thanks a lot for enabling this! 🤗
It is true that the FlashAttention backend of PyTorch SDPA doesn't support an attention mask and falls back to the Math backend when the user passes an attn_mask; see https://github.com/pytorch/pytorch/blob/d984143a74e5e726e2be35f6531582aab45bcf4c/aten/src/ATen/native/transformers/sdp_utils_cpp.h#L259. However, this check should not be a condition for GQA. The … Could you remove the …
use the enable_gqa param in torch.nn.functional.scaled_dot_product_attention (huggingface#39412)

* use the enable_gqa param in torch.nn.functional.scaled_dot_product_attention
* ci failure fix
* add check
* fix ci failure
* refine code, extend to cuda
* refine code
* fix review comments
* refine the PR

Signed-off-by: Wang, Yi A <[email protected]>
Co-authored-by: Cyril Vallez <[email protected]>
According to the docs (https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html), "Grouped Query Attention (GQA) is an experimental feature. It currently works only for Flash_attention and math kernel on CUDA tensor, and does not support Nested tensor." That means either the docs are wrong or we only have the FA and math backends available when using the enable_gqa kwarg. The memory-efficient variant is not mentioned, so it would need some code sample imo to show that it can work alongside. For cuDNN, that's also a rather new and experimental backend, so I'm not sure what the conditions are there; iirc it was also not turned on by default as a backend option.
Yes, you are right, and the doc for CUDA is also right. The GQA optimization for XPU has been enabled since torch 2.8; maybe we can add a different path in the "use_gqa_in_sdpa" function for each device? The recommended condition for XPU would be _is_torch_greater_or_equal_than_2_8 and not isinstance(key, torch.fx.Proxy).
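For reference, a rough sketch of the per-device gating discussed here (not the exact Transformers implementation). The version flags are recomputed locally so the snippet is self-contained, and the non-XPU branch is an assumption based on when enable_gqa was introduced.

```python
import torch
from packaging import version

_torch_version = version.parse(torch.__version__.split("+")[0])
_is_torch_greater_or_equal_than_2_5 = _torch_version >= version.parse("2.5")
_is_torch_greater_or_equal_than_2_8 = _torch_version >= version.parse("2.8")

def use_gqa_in_sdpa(attention_mask, key) -> bool:
    """Decide whether to pass enable_gqa=True to F.scaled_dot_product_attention."""
    # torch.fx tracing hands us Proxy objects, which cannot take the enable_gqa fast path.
    if isinstance(key, torch.fx.Proxy):
        return False
    if key.device.type == "xpu":
        # Condition suggested in this thread: XPU gained the GQA optimization in torch 2.8.
        return _is_torch_greater_or_equal_than_2_8
    # Assumed default for other devices: enable_gqa itself requires torch >= 2.5;
    # additional backend-specific checks (e.g. around attention_mask) may still apply.
    return _is_torch_greater_or_equal_than_2_5
```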
Interestingly enough, I just checked the cuDNN backend and it seems to work with the mask + GQA at first glance; I would need to look deeper into this.
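A quick probe one could run for this observation (an illustrative sketch, not part of the PR): force the cuDNN backend and see whether it accepts an explicit mask together with enable_gqa instead of erroring out.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

if torch.cuda.is_available():
    q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)   # 8 query heads
    k = torch.randn(1, 2, 128, 64, device="cuda", dtype=torch.float16)   # 2 key/value heads
    v = torch.randn_like(k)
    mask = torch.ones(1, 1, 128, 128, device="cuda", dtype=torch.bool).tril()  # causal boolean mask
    try:
        # Restricting to cuDNN means SDPA errors out if that backend rejects this call.
        with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=True)
        print("cuDNN handled mask + GQA:", out.shape)
    except RuntimeError as e:
        print("cuDNN refused mask + GQA:", e)
```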
That sounds good to me, can you open a PR? It's hard to keep up with what's possible where :D
Thanks. I suppose @sywangyi, who owns this part, will submit a PR to fix it.
GQA can be accelerated in torch.nn.functional.scaled_dot_product_attention: this PyTorch API offers a parameter, enable_gqa, to enable it. See https://docs.pytorch.org/docs/2.7/generated/torch.nn.functional.scaled_dot_product_attention.html#torch-nn-functional-scaled-dot-product-attention
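A minimal usage sketch of the enable_gqa parameter (shapes and device handling below are illustrative, not code from this PR):

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"  # CPU/XPU support depends on the torch version
dtype = torch.float16 if device == "cuda" else torch.float32

query = torch.randn(1, 32, 128, 64, device=device, dtype=dtype)  # (batch, num_attention_heads, seq_len, head_dim)
key = torch.randn(1, 8, 128, 64, device=device, dtype=dtype)     # (batch, num_key_value_heads, seq_len, head_dim)
value = torch.randn_like(key)

# With enable_gqa=True (PyTorch >= 2.5), SDPA broadcasts the 8 KV heads across the
# 32 query heads internally, so the caller no longer needs to repeat_interleave
# the key/value tensors before the attention call.
out = F.scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
print(out.shape)  # torch.Size([1, 32, 128, 64])
```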