
Conversation


@tianleiwu tianleiwu commented Jul 17, 2025

Description

Update Flash Attention to support softmax sink in GQA.

Changes:

  • Update flash attention to support head_sink
  • Add test_gqa.py to test CUDA, and remove test_gqa_cuda.py.

Note that the sink is treated as already scaled, while the elements produced by the QK GEMMs are not. The sink value needs neither scaling nor softcap; it joins the softmax together with the scaled or soft-capped values (the combined softmax is written out after the list below). There are two ways to add the sink to the softmax:

  • One way is to [patch normalize_softmax_lse](https://github.com/microsoft/onnxruntime/blob/1cf1aa786f6e7f7e6abd6fba8b8aea2e7a43092c/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h#L143-L178) to use the sink to update the max and sum. Pro: the change is confined to one function. Con: the logic is a little complex, since row_max is unscaled while row_sum is scaled.
  • Another way is to change softmax_rescale_o to handle the sink directly in the first block of the online softmax, using an unscaled sink value. This is a robust way to keep the core algorithm consistent. Con: it requires changes in multiple places, and it is harder to combine with softcap.
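
For reference, both approaches compute the same quantity: a softmax in which the sink joins the denominator unscaled. With scale $t$ applied to the raw scores $s_j$ (the notation here is illustrative, not from the PR):

$$p_j = \frac{e^{t\,s_j}}{e^{\mathrm{sink}} + \sum_k e^{t\,s_k}}$$

The sink contributes only to the denominator, which is why merging it into the max/sum bookkeeping is sufficient.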

This PR uses the first approach for easy integration.
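
Below is a minimal NumPy sketch of the first approach under the formula above; the function and variable names are illustrative, not the actual kernel code in softmax.h:

```python
import numpy as np

def normalize_with_sink(o_acc, row_max, row_sum, sink):
    """Fold a sink logit into the final softmax normalization (approach 1).

    o_acc:   unnormalized output, sum_j exp(t*s_j - row_max) * v_j
    row_max: per-row running max of the scaled scores, shape (rows, 1)
    row_sum: per-row sum_j exp(t*s_j - row_max), shape (rows, 1)
    sink:    sink logit, used as-is (no scaling, no softcap)
    """
    new_max = np.maximum(row_max, sink)        # the sink may raise the max
    correction = np.exp(row_max - new_max)     # rebase old terms onto new_max
    new_sum = row_sum * correction + np.exp(sink - new_max)
    return o_acc * correction / new_sum

# Self-check against a direct softmax that includes the sink term.
rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 8))               # already-scaled QK scores
v = rng.normal(size=(8, 16))
sink = 0.5

row_max = scores.max(axis=-1, keepdims=True)
row_sum = np.exp(scores - row_max).sum(axis=-1, keepdims=True)
o_acc = np.exp(scores - row_max) @ v

expected = np.exp(scores) @ v / (np.exp(scores).sum(-1, keepdims=True) + np.exp(sink))
assert np.allclose(normalize_with_sink(o_acc, row_max, row_sum, sink), expected)
```

This mirrors what updating the max and sum in normalize_softmax_lse has to do: the row max may move once the sink is taken into account, so the running sum and the accumulated output are rescaled by the same correction factor before the final division.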

Note: the memory efficient attention change will be in a separate PR.

Motivation and Context

#25269

@tianleiwu tianleiwu merged commit e6c84b8 into main Jul 17, 2025
88 of 90 checks passed
@tianleiwu tianleiwu deleted the tlwu/gqa_head_sink_cuda branch July 17, 2025 22:52
qti-yuduo pushed a commit to CodeLinaro/onnxruntime that referenced this pull request Aug 8, 2025
sanketkaleoss pushed a commit to sanketkaleoss/onnxruntime that referenced this pull request Aug 11, 2025