Conversation


@tianleiwu (Contributor) commented Jul 10, 2025

Description

Update Flash Attention to support head_sink for smooth softmax in GQA.

Changes:

  • Update flash attention to support head_sink
  • Add test_gqa.py to test it
  • Remove test_gqa_cuda.py

Note: The memory efficient attention change will be in a separate PR.
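
For context, here is a minimal PyTorch sketch of the smooth-softmax-with-sink computation this targets, assuming head_sink supplies one learnable logit per head that joins the softmax normalization but whose probability mass is discarded (names and shapes below are illustrative, not the actual kernel or test code):

```python
import torch


def smooth_softmax_with_head_sink(scores: torch.Tensor, head_sink: torch.Tensor) -> torch.Tensor:
    """Reference sketch: softmax over attention scores plus a per-head sink logit.

    scores:    (batch, num_heads, q_len, kv_len) raw attention logits
    head_sink: (num_heads,) learnable sink logit per head
    """
    # Broadcast the sink logit to one extra "virtual key" column per query row.
    sink = head_sink.view(1, -1, 1, 1).expand(*scores.shape[:3], 1)
    logits = torch.cat([scores, sink], dim=-1)
    probs = torch.softmax(logits, dim=-1)
    # Drop the sink column: it only absorbs probability mass in the denominator.
    return probs[..., :-1]
```

With the sink in the denominator, a query row whose real keys are all masked still normalizes to finite weights (zero on the real keys) instead of producing NaNs.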

Motivation and Context

#25269

@tianleiwu marked this pull request as draft July 10, 2025 18:26

@github-actions bot left a comment


You can commit the suggested changes from lintrunner.


# Triton-based implementation for CUDA
def rotary_embedding_cuda(*args, **kwargs):
    from rotary_flash import apply_rotary_emb

Check notice

Code scanning / lintrunner

RUFF/PLC0415 (Note, test)

import should be at the top-level of a file.
See https://docs.astral.sh/ruff/rules/import-outside-top-level
            return rotary_embedding_cuda(x, cos, sin, seqlen_offsets=pos, interleaved=interleaved)
        except ImportError:
            print("WARNING: Triton-based rotary embedding not found. Falling back to PyTorch version.")
            use_cuda_triton = False

Check notice

Code scanning / CodeQL

Unused local variable (Note, test)

Variable use_cuda_triton is not used.

Copilot Autofix

AI 7 months ago

To fix the issue, the unused assignment to use_cuda_triton on line 194 should be removed. This will eliminate the redundant code while preserving the function's behavior. Since the variable is already used earlier in the function (line 188), its presence is still meaningful for the initial conditional logic. No additional changes are required.


Suggested changeset 1
onnxruntime/test/python/transformers/test_gqa.py

Autofix patch

Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py
--- a/onnxruntime/test/python/transformers/test_gqa.py
+++ b/onnxruntime/test/python/transformers/test_gqa.py
@@ -193,3 +193,2 @@
             print("WARNING: Triton-based rotary embedding not found. Falling back to PyTorch version.")
-            use_cuda_triton = False
 
EOF
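As a sketch of one possible restructuring that would address both the lint finding and the unused-variable finding, assuming the flag is only consumed (not reassigned) elsewhere in test_gqa.py:

```python
# Guarded import at module level, per RUFF/PLC0415, setting the fallback flag once
# instead of reassigning it inside the function's except block.
try:
    from rotary_flash import apply_rotary_emb  # Triton-based rotary embedding, used by the CUDA path

    use_cuda_triton = True
except ImportError:
    print("WARNING: Triton-based rotary embedding not found. Falling back to PyTorch version.")
    use_cuda_triton = False
```

Whether this fits depends on how the module is laid out; the autofix above takes the narrower route of simply dropping the dead assignment.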

    # Fill NaNs with 0
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable (Error, test)

Local variable 'local_mask' may be used before it is initialized.

Copilot Autofix

AI 7 months ago

To fix the issue, we need to ensure that local_mask is always initialized before it is used. The best approach is to initialize local_mask to None at the start of the function. Then, before using local_mask on line 748, we can check if it is None and handle that case appropriately. This ensures that the variable is always defined, regardless of the conditions.


Suggested changeset 1
onnxruntime/test/python/transformers/test_gqa.py

Autofix patch

Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py
--- a/onnxruntime/test/python/transformers/test_gqa.py
+++ b/onnxruntime/test/python/transformers/test_gqa.py
@@ -730,2 +730,3 @@
 
+    local_mask = None
     if window_size[0] >= 0 or window_size[1] >= 0:
@@ -746,3 +747,3 @@
     # Fill NaNs with 0
-    if window_size[0] >= 0 or window_size[1] >= 0:
+    if window_size[0] >= 0 or window_size[1] >= 0 and local_mask is not None:
         attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
EOF
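One nuance in the suggested condition: Python's `and` binds tighter than `or`, so `window_size[0] >= 0 or window_size[1] >= 0 and local_mask is not None` parses as `window_size[0] >= 0 or (window_size[1] >= 0 and local_mask is not None)`, and the None check does not guard the `window_size[0]` branch. A sketch of the fully guarded form (same variables as the patch above) uses explicit parentheses:

```python
# Fill NaNs with 0 only when a sliding-window mask was actually built.
if (window_size[0] >= 0 or window_size[1] >= 0) and local_mask is not None:
    attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
```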
@tianleiwu closed this Jul 11, 2025