[Sm75] Add README link for initial Turing support #2379
Merged
tridao merged 1 commit into Dao-AILab:main on Mar 25, 2026
Conversation
Member
Thanks!
Contributor (Author)
Thanks for the comment! Sure, I’ll keep the Turing implementation in a separate repo and share it here once it’s cleaned up.
Contributor (Author)
Hi @tridao, I just cleaned up the Turing repo. I think it's good to go now. Thanks again!
tridao approved these changes on Mar 25, 2026
FlashAttention Turing
This PR adds a link to the flash-attention-turing repo, which provides support for the Turing (SM75) architecture in FlashAttention, following #1533.
Features
Supports:
Does not support:
Performance
Benchmarks are reported on Nvidia T4 GPUs.
Forward pass
Up to 2.19x and 1.95x faster than PyTorch's Attention for non-causal and causal workloads, respectively.
On Turing GPUs, PyTorch's Attention uses Memory-Efficient Attention from xformers, since FlashAttention does not provide optimized kernels for SM75.
For long sequences, the forward kernel reaches up to 66% of peak compute throughput.
Backward pass
The backward pass is split into two kernels: one for `dQ` and one for `dK` and `dV`.
Up to 1.35x and 1.51x faster than PyTorch's Attention for non-causal and causal workloads, respectively.
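The dQ versus dK/dV split mirrors the structure of the backward math: once `dS` is known, `dQ` comes from one matmul and `dK`/`dV` from two independent ones. A hedged NumPy sketch for a single head (my own illustrative functions, not the actual kernels):

```python
import numpy as np

def attn_fwd(q, k, v):
    """Single-head forward pass; also returns the probability matrix P,
    which the backward sketch below consumes."""
    scale = q.shape[1] ** -0.5
    s = (q @ k.T) * scale
    p = np.exp(s - s.max(axis=1, keepdims=True))
    p = p / p.sum(axis=1, keepdims=True)
    return p @ v, p

def attn_bwd(q, k, v, p, do):
    """Attention backward for O = P @ V with P = softmax(Q K^T * scale).
    dQ depends only on (dS, K), while dK and dV depend on transposed
    products; the two groups can therefore run as separate kernels."""
    scale = q.shape[1] ** -0.5
    dv = p.T @ do                                        # (seqlen_k, d)
    dp = do @ v.T                                        # (seqlen_q, seqlen_k)
    ds = p * (dp - (dp * p).sum(axis=1, keepdims=True))  # softmax backward
    dq = (ds @ k) * scale
    dk = (ds.T @ q) * scale
    return dq, dk, dv
```

A finite-difference check on small random inputs is an easy way to validate each of the three gradients independently.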
For long sequences, the backward kernels reach up to 49% of peak compute throughput for `dK` and `dV`, and 45% for `dQ`.
Correctness and numerical differences
From our tests in `test_flash_attn.py`, we consistently observe maximum and mean absolute differences of ~1e-3 and ~1e-5, respectively, relative to PyTorch's attention kernels.
Thanks!