[CUDA] Update Flash Attention to support head_sink for smooth softmax in GQA #25358
Conversation
You can commit the suggested changes from lintrunner.
```python
# Triton-based implementation for CUDA
def rotary_embedding_cuda(*args, **kwargs):
    from rotary_flash import apply_rotary_emb
```
Check notice — Code scanning / lintrunner
RUFF/PLC0415 (import outside top level)
See https://docs.astral.sh/ruff/rules/import-outside-top-level
```python
    return rotary_embedding_cuda(x, cos, sin, seqlen_offsets=pos, interleaved=interleaved)
except ImportError:
    print("WARNING: Triton-based rotary embedding not found. Falling back to PyTorch version.")
    use_cuda_triton = False
```
Check notice — Code scanning / CodeQL
Unused local variable
Copilot Autofix (AI, 7 months ago):
To fix the issue, the unused assignment to use_cuda_triton on line 194 should be removed. This will eliminate the redundant code while preserving the function's behavior. Since the variable is already used earlier in the function (line 188), its presence is still meaningful for the initial conditional logic. No additional changes are required.
```diff
@@ -193,3 +193,2 @@
     print("WARNING: Triton-based rotary embedding not found. Falling back to PyTorch version.")
-    use_cuda_triton = False
```
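Taken together with the lintrunner note above, both findings point at the same pattern. As a rough illustration only (not code from this PR), the guarded import could live at module level, which satisfies RUFF/PLC0415 and removes the need to reassign an unused local inside the `except` block; the `_HAS_ROTARY_FLASH` flag and the wrapper signature below are assumptions for the sketch:

```python
# Hypothetical restructuring for illustration; the flag name and signature are not from the PR.
try:
    from rotary_flash import apply_rotary_emb  # Triton-based rotary embedding kernel

    _HAS_ROTARY_FLASH = True  # assumed module-level availability flag
except ImportError:
    print("WARNING: Triton-based rotary embedding not found. Falling back to PyTorch version.")
    _HAS_ROTARY_FLASH = False


def rotary_embedding_cuda(x, cos, sin, pos, interleaved=False):
    # Thin wrapper over the Triton kernel; callers branch on _HAS_ROTARY_FLASH
    # instead of mutating a local flag inside an except block.
    return apply_rotary_emb(x, cos, sin, seqlen_offsets=pos, interleaved=interleaved)
```

A caller would then check `_HAS_ROTARY_FLASH` once and take the PyTorch reference path when it is `False`, preserving the fallback behavior that the unused local was meant to signal.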
```python
# Fill NaNs with 0
if window_size[0] >= 0 or window_size[1] >= 0:
    attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
```
Check failure — Code scanning / CodeQL
Potentially uninitialized local variable
Copilot Autofix (AI, 7 months ago):
To fix the issue, we need to ensure that local_mask is always initialized before it is used. The best approach is to initialize local_mask to None at the start of the function. Then, before using local_mask on line 748, we can check if it is None and handle that case appropriately. This ensures that the variable is always defined, regardless of the conditions.
```diff
@@ -730,2 +730,3 @@
 
+    local_mask = None
     if window_size[0] >= 0 or window_size[1] >= 0:
@@ -746,3 +747,3 @@
     # Fill NaNs with 0
-    if window_size[0] >= 0 or window_size[1] >= 0:
+    if window_size[0] >= 0 or window_size[1] >= 0 and local_mask is not None:
         attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
```
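One small caveat about the suggested guard, offered as an observation rather than a change to the autofix: in Python, `and` binds tighter than `or`, so `a or b and c` parses as `a or (b and c)`. If the intent is to run the `masked_fill` only when a local window is in effect and `local_mask` was actually built, explicit parentheses state that directly; a minimal sketch:

```python
# Illustrative only; explicit parentheses avoid relying on and/or precedence.
if (window_size[0] >= 0 or window_size[1] >= 0) and local_mask is not None:
    attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
```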
Description
Update Flash Attention to support head_sink for smooth softmax in GQA.
Changes:
Note: The memory-efficient attention change will be in a separate PR.
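For context on the feature itself: "smooth softmax" is generally described as a softmax whose denominator carries an extra per-head sink term, so the attention weights of the real keys can sum to less than one. The sketch below illustrates that idea in plain PyTorch; the function name, tensor shapes, and the sink-as-extra-column formulation are all assumptions for illustration, not the fused CUDA flash-attention kernel added by this PR:

```python
import torch


def smooth_softmax_with_head_sink(scores: torch.Tensor, head_sink: torch.Tensor) -> torch.Tensor:
    """Reference sketch of the assumed semantics, not the fused CUDA kernel.

    scores:    (batch, num_heads, q_len, kv_len) attention logits
    head_sink: (num_heads,) learnable per-head sink logit
    """
    batch, num_heads, q_len, _ = scores.shape
    # Treat the sink as one extra virtual key column per head ...
    sink = head_sink.view(1, num_heads, 1, 1).expand(batch, num_heads, q_len, 1)
    probs = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)
    # ... then drop it, so each row of weights sums to less than one.
    return probs[..., :-1]


# Tiny example: 1 batch, 2 heads, 3 query positions, 4 key positions.
weights = smooth_softmax_with_head_sink(torch.randn(1, 2, 3, 4), torch.zeros(2))
print(weights.sum(dim=-1))  # every row sums to < 1 because of the sink term
```

With the sink logit fixed at zero this reduces to a denominator of `1 + sum(exp(scores))`; a fused flash-attention kernel would presumably fold the sink into the online-softmax running maximum and sum rather than materializing an extra column.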
Motivation and Context
#25269