🤖FFPA: Yet another Faster Flash Prefill Attention with O(1)⚡️GPU SRAM complexity for large headdim🐑
🤖[WIP] FFPA: Yet another Faster Flash Prefill Attention with O(1) SRAM complexity & O(d/4) or O(1) register complexity for large headdim (D > 256), 1.8x~3x 🎉 faster than SDPA EA with or without MMA Acc F32 on many devices: 📈L20 ~1.9x↑🎉, 📈A30 ~1.8x↑🎉, 📈3080 ~2.9x↑🎉, 📈4090 ~2.1x↑🎉. FFPA Attention Algo: fine-grained tiling for large headdim; FA-2 Attention Algo: coarse-grained tiling for small headdim.
💡NOTE: This project is still in an early development stage and currently provides some kernels and benchmarks for reference. More features will be added in the future. (Welcome to 🌟👆🏻star this repo to support me ~)
@misc{ffpa-attn-mma@2025,
title={FFPA: Yet another Faster Flash Prefill Attention for large headdim.},
url={https://github.com/DefTruth/ffpa-attn-mma.git},
note={Open-source software available at https://github.com/DefTruth/ffpa-attn-mma.git},
author={DefTruth etc},
year={2025}
}
- 📖 Installation⚙️
- 📖 Python Testing👇
- 📖 FFPA L1~L3 Design💡
- 📈 FFPA L1: L20 ~1.9x↑🎉
- 📈 FFPA L1: A30 ~1.8x↑🎉
- 📈 FFPA L1: 3080 ~2.9x↑🎉
- 📈 FFPA L1: 4090 ~2.1x↑🎉
- 📖 Fully Fused MLA w/ FFPA🎉
We have extended FlashAttention for large headdim (D > 256) by implementing Fine-grained Tiling at the MMA level (GEMM style) for the Q@K^T and P@V matmul. This approach results in a constant SRAM usage of Br * 16 or Bc * 16 (Br = Bc) for Q, K, and V, leading to an overall SRAM complexity of O(2 * Br * 16) ≈ O(1) and a register complexity of O(d/4) or O(1). Consequently, this method allows us to extend headdim beyond 256 and achieve faster performance compared to SDPA with or without MMA Accumulation F32 (1.8x~3x 🎉 faster than SDPA EA).
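To make the tiling concrete, here is a minimal NumPy sketch of the idea. This is not the CUDA kernel; it only shows that Q@K^T and P@V can be accumulated from 16-wide slices of the head dimension while the online softmax stays FA-2 style. For brevity it assumes N is divisible by Br/Bc and d is divisible by 16.

```python
import numpy as np

def ffpa_tiling_sketch(Q, K, V, Br=64, Bc=64, kc=16):
    """Reference-only sketch: assumes N % Br == N % Bc == 0 and d % kc == 0."""
    N, d = Q.shape
    O = np.zeros((N, d))
    for i in range(0, N, Br):                              # query tile [Br, d]
        m = np.full(Br, -np.inf)                           # running row max
        l = np.zeros(Br)                                   # running softmax denominator
        acc = np.zeros((Br, d))                            # unnormalized output
        for j in range(0, N, Bc):                          # key/value tile [Bc, d]
            # Q@K^T accumulated over the head dim in kc-wide slices, so only
            # [Br, kc] / [Bc, kc] pieces of Q / K need to be resident at once.
            S = np.zeros((Br, Bc))
            for k0 in range(0, d, kc):
                S += Q[i:i+Br, k0:k0+kc] @ K[j:j+Bc, k0:k0+kc].T
            S /= np.sqrt(d)
            # FA-2 style online softmax rescaling.
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])
            scale = np.exp(m - m_new)
            l = l * scale + P.sum(axis=1)
            acc *= scale[:, None]
            # P@V accumulated over Bc in kc-wide slices as well.
            for k0 in range(0, Bc, kc):
                acc += P[:, k0:k0+kc] @ V[j+k0:j+k0+kc, :]
            m = m_new
        O[i:i+Br] = acc / l[:, None]
    return O
```

On random inputs this matches a naive softmax(QKᵀ/√d)V reference up to floating-point error; the CUDA kernels follow roughly the same accumulation pattern, with MMA m16n8k16 atoms in place of the NumPy matmuls.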
We have named this new attention tiling technique FFPA: Faster Flash Prefill Attention. We have designed three levels (L1~L3) of FFPA based on SRAM and register complexity considerations. None of the levels introduces additional VRAM requirements, so the HBM memory complexity remains the same as FlashAttention. 👇
- 📚L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, ≈O(d/4) register complexity.
- 📚L2: level 2, O(2xBrx16)≈O(1) SRAM complexity, ≈O(1) register complexity + Q@K^T recomputation.
- 📚L3: level 3, O(2xBrx16)≈O(1) SRAM complexity, ≈O(1) register complexity + scaling O via HBM offloading.
By leveraging this approach, we can achieve better performance than SDPA EA for very large headdim (D > 256, not supported by FA-2). The approximate SRAM and register complexity for FFPA L1~L3 is as follows (d=headdim, C, Br, Bc=constants, Br=Bc, let O(C)≈O(1)); a short numeric illustration follows the table. 👇
📚Complexity | 📚FFPA L1 | 📚FFPA L2 | 📚FFPA L3 | 📚FA-2 |
---|---|---|---|---|
SRAM | O(2xBrx16)≈O(1) | O(2xBrx16)≈O(1) | O(2xBrx16)≈O(1) | ≈O(3xBrxd), d↑ |
Register | ≈O(d/4), d↑ | O((Bc/16)x4+2C)≈O(1) | O((Bc/16)x4+2C)≈O(1) | ≈O(d/2), d↑ |
HBM | ≈FA2≈O(Nd), O | ≈FA2≈O(Nd), O | ≈FA2≈O(Nd), O | ≈O(Nd), O |
Extra HBM | ≈FA2≈O(N), m,l | ≈FA2≈O(N), m,l | ≈FA2≈O(N), m,l | ≈O(N), m,l |
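For a concrete sense of scale, here is a back-of-the-envelope estimate of the per-block SRAM implied by the two formulas above, assuming fp16 elements and Br = Bc = 64 (the default L1 template values); the real kernels add multi-stage and swizzle details not modeled here.

```python
# Back-of-the-envelope SRAM-per-block estimate from the table above
# (fp16 elements, Br = Bc = 64 as in the default L1 template).
def smem_kib(d, Br=64, bytes_per_elem=2):
    ffpa = 2 * Br * 16 * bytes_per_elem   # O(2 x Br x 16): independent of d
    fa2 = 3 * Br * d * bytes_per_elem     # O(3 x Br x d) : grows with d
    return ffpa / 1024, fa2 / 1024

for d in (64, 256, 512, 1024):
    ffpa, fa2 = smem_kib(d)
    print(f"d={d:4d}  FFPA ≈ {ffpa:5.1f} KiB   FA-2 style ≈ {fa2:6.1f} KiB")
```

At d = 1024 the FA-2-style estimate (~384 KiB) is far beyond the shared memory of a single SM, while the FFPA estimate stays at ~4 KiB regardless of d, which is the point of the fine-grained tiling.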
📚👇Core Features🎉🎉: I have implemented FFPA L1~L3 using pure MMA PTX instructions, which supports many features such as Split-Q, SMEM Swizzle/Padding, QKV Multi-Stages(1~4), Tile MMAs/Warps, Mixed MMA F32/F16 Acc (Q@K^T MMA Acc F32 + P@V MMA Acc F16), Fully Shared QKV SMEM, Prefetch QKV g2s, Persist Q s2r/g2s, Fully QKV Fine-grained Tiling(GEMM style), Collective Store, etc.
📚Feature | 📚Feature | 📚Feature | 📚Feature |
---|---|---|---|
✔️Tensor Cores | ✔️MMA(m16n8k16) | ✔️Tile Block(Br, Bc) | ✔️Tile MMA/Warp |
✔️Split Q(FA-2) | ✔️Pack LDST(128 bits) | ✔️SMEM Swizzle/Pad | ✔️Copy Async |
✔️Reg Double Buffers | ✔️QKV Multi-Stages(1~4) | ✔️Collective Store(Shfl) | ✔️Prefetch QKV g2s |
✔️QKV Fine-grained Tiling | ✔️Shared QKV SMEM | ✔️Mixed MMA Acc | ✔️Persist Q s2r/g2s |
- 📚 case: FFPA `L1` kernel template signature: ffpa_attn_templates_L1.cuh
template<
const int kHeadDim, // Headdim, 32~1024
const int kMmaAtomM, // MMA Atom M, 16
const int kMmaAtomN, // MMA Atom N, 8
const int kMmaAtomK, // MMA Atom K, 16
const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M
const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N
const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M
const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|...
const int kMmaAccFloat32QK, // 0/1, Q@K^T, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
const int kMmaAccFloat32PV, // 0/1, P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
const int kOStorageAccFloat32, // 0/1, MMA Acc is always f32/f16, but O storage can be fp32 or half.
const int kPrefetchQK, // Prefetch QK at the Appropriate Time Point.
const int kPrefetchPV, // Prefetch V at the Appropriate Time Point.
const int kShareSmemQKV, // QKV share the same shared memory, reuse QK smem for V.
const int kPersistQs2r, // Persist load Q s2r for headdim < 512, more registers, but still keep O(1) SRAM.
const int kPersistQg2s, // Persist load Q g2s for headdim <= 320, more SRAM, but register usage stays the same.
const int kRegPipeKV, // Registers Ping pong double buffers for ldmatrix s2r & mma computation overlapping.
const int kStageQK, // <= 4, may apply different multi stages policy for QK and V (<=4)
const int kStagePV, // <= 4, may apply different multi stages policy for QK and V (<=4)
const int kPadQ, // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
const int kPadK, // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
const int kPadV // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
> __global__ void // Q, K, V, O -> [B, H, N, D]
// FFPA Attention Algo: Fine-grained tiling at MMA level for large headdim (d>=256),
// which can achieve 1.8x~3x🎉 faster than SDPA EA with or without MMA Acc F32.
ffpa_mma_stages_split_q_L1_large_d_template(half* Q, half* K, half* V, half* O, ...);
// FA-2 Attention Algo: Coarse-grained tiling at Attention level for small headdim (d<256),
// which achieves 95%~150%🎉 of SDPA FA-2 BE performance with MMA Acc F32 for N<=4096,
// and is almost 1.2x~1.4x🎉 faster than SDPA FA-2 via Mixed MMA Acc (Q@K^T F32 +
// P@V F16) for the full range of N.
ffpa_mma_stages_split_q_L1_small_d_template(half* Q, half* K, half* V, half* O, ...);
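For orientation, the block tile sizes Br/Bc follow from these template parameters. The quick check below is just a reading of the default values noted in the comments above (MMA atom m16n8k16, 4x1 MMA tiles, 1x8 warp tiles), not an authoritative formula:

```python
# How Br/Bc are composed from the L1 template parameters, using the default
# values noted in the comments above.
kMmaAtomM, kMmaAtomN = 16, 8
kMmaTileSeqLenQ, kMmaTileSeqLenK = 4, 1
kWarpTileSeqLenQ, kWarpTileSeqLenK = 1, 8

Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ   # 16 * 4 * 1 = 64 rows of Q per block
Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK   # 8 * 1 * 8  = 64 rows of K per block
warps = kMmaTileSeqLenQ * kMmaTileSeqLenK             # 4 warps per block (Split-Q along M)
print(Br, Bc, warps)                                  # 64 64 4
```

Per the comments, kWarpTileHeadDimV then scales the N dimension of the P@V tile in steps of 8, which is how headdims up to 1024 are covered.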
- Python >= 3.10
- PyTorch >= 2.4.0, CUDA >= 12.4
- flash-attention >= 2.6.3 (for test)
- Recommended: PyTorch 2.5.1, CUDA 12.5
- Docker: nvcr.io/nvidia/pytorch:24.10-py3
The FFPA implemented in this repo can be installed as a Python library, namely, the ffpa-attn library (optional).
git clone https://github.com/DefTruth/ffpa-attn-mma.git
# clone, then, run bash .dev/install.sh directly or run commands:
python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl # pip uninstall ffpa-attn -y
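After installation, a quick sanity check might look like the following; note this assumes the wheel exposes a module named `ffpa_attn` (derived from the wheel name, not confirmed here), so adjust the import if the package layout differs.

```python
# Hypothetical post-install check: the module name `ffpa_attn` is assumed,
# and a CUDA-capable GPU is required for the kernels.
import torch
import ffpa_attn  # noqa: F401  (assumed import name)

assert torch.cuda.is_available(), "FFPA kernels require a CUDA device"
print(torch.cuda.get_device_name(0))
```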
L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, O(d/4) register complexity, and the same GPU HBM memory complexity as FlashAttention. Benchmark setup: B=1, H=48, N=8192, D=320~1024 (FA2 not supported 👀). (Notes: `*`=MMA Acc F32, `^`=MMA Acc F16, Softmax Acc dtype is always F32, T=TFLOPS, 👇Benchmark)
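The SDPA EA baseline in the tables below refers to PyTorch's scaled_dot_product_attention restricted to the memory-efficient (EA) backend. A minimal sketch of that baseline setup, assuming PyTorch >= 2.4 and the same [B, H, N, D] layout used by the FFPA kernels:

```python
# Minimal sketch of the SDPA EA (efficient-attention) baseline used for comparison.
# Assumes PyTorch >= 2.4; tensor layout is [B, H, N, D] as in the FFPA kernels.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

B, H, N, D = 1, 48, 8192, 320
q, k, v = (torch.randn(B, H, N, D, dtype=torch.half, device="cuda") for _ in range(3))

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):  # force the EA backend
    o = F.scaled_dot_product_attention(q, k, v)
print(o.shape)  # torch.Size([1, 48, 8192, 320])
```

The FA-2 backend cannot be used as the reference here because it does not support D > 256, which is why EA is the baseline.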
- 📚 NVIDIA L20 (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS, ~1.8x↑🎉)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 56T | 63T | 58T | 58T | 55T | 56T | 54T | 55T | 54T | 55T | 54T | 56T |
FFPA L1* | 102T | 102T | 103T | 104T | 103T | 95T | 95T | 95T | 95T | 96T | 95T | 94T |
Speedup | 1.82x | 1.62x | 1.78x | 1.79x | 1.87x | 1.7x | 1.76x | 1.73x | 1.76x | 1.75x | 1.76x | 1.68x |
FFPA L1^ | 104T | 103T | 103T | 102T | 104T | 103T | 102T | 94T | 94T | 94T | 100T | 100T |
Speedup | 1.86x | 1.63x | 1.78x | 1.76x | 1.89x | 1.84x | 1.89x | 1.71x | 1.74x | 1.71x | 1.85x | 1.79x |
- 📚 NVIDIA L20 (`*`=MMA Acc: QK F32 + PV F16, `^`=MMA Acc F16, `T`=TFLOPS, ~1.9x↑🎉)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 56T | 64T | 58T | 58T | 55T | 56T | 54T | 55T | 54T | 55T | 54T | 56T |
FFPA L1* | 105T | 102T | 104T | 103T | 105T | 95T | 95T | 94T | 94T | 94T | 102T | 101T |
Speedup | 1.88x | 1.59x | 1.79x | 1.78x | 1.91x | 1.7x | 1.76x | 1.71x | 1.74x | 1.71x | 1.89x | 1.8x |
FFPA L1^ | 104T | 103T | 103T | 102T | 103T | 103T | 102T | 94T | 94T | 94T | 100T | 100T |
Speedup | 1.86x | 1.61x | 1.78x | 1.76x | 1.87x | 1.84x | 1.89x | 1.71x | 1.74x | 1.71x | 1.85x | 1.79x |
- 📚 NVIDIA A30 (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS, ~1.8x↑🎉)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 25T | 25T | 24T | 24T | 24T | 24T | 23T | 22T | 22T | 22T | 22T | 18T |
FFPA L1* | 45T | 44T | 44T | 43T | 43T | 38T | 37T | 37T | 37T | 36T | 33T | 32T |
Speedup | 1.8x | 1.76x | 1.83x | 1.79x | 1.79x | 1.58x | 1.61x | 1.68x | 1.68x | 1.64x | 1.5x | 1.78x |
FFPA L1^ | 48T | 46T | 45T | 43T | 44T | 44T | 44T | 38T | 37T | 36T | 40T | 34T |
Speedup | 1.92x | 1.84x | 1.88x | 1.79x | 1.83x | 1.83x | 1.91x | 1.73x | 1.68x | 1.64x | 1.82x | 1.89x |
- 📚 NVIDIA A30 (`*`=MMA Acc: QK F32 + PV F16, `^`=MMA Acc F16, `T`=TFLOPS, ~1.9x↑🎉)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 25T | 25T | 24T | 24T | 24T | 24T | 23T | 22T | 22T | 22T | 22T | 18T |
FFPA L1* | 48T | 46T | 46T | 43T | 44T | 38T | 38T | 38T | 37T | 36T | 40T | 34T |
Speedup | 1.92x | 1.84x | 1.92x | 1.79x | 1.83x | 1.58x | 1.65x | 1.73x | 1.68x | 1.64x | 1.82x | 1.89x |
FFPA L1^ | 48T | 46T | 45T | 43T | 44T | 44T | 44T | 38T | 37T | 36T | 39T | 34T |
Speedup | 1.92x | 1.84x | 1.88x | 1.79x | 1.83x | 1.83x | 1.91x | 1.73x | 1.68x | 1.64x | 1.77x | 1.89x |
- 📚 NVIDIA RTX 3080 Laptop (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS, ~2.5x↑🎉)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 13T | 16T | 11T | 16T | 15T | 15T | 15T | 15T | 14T | 14T | 14T | 14T |
FFPA L1* | 33T | 31T | 30T | 30T | 30T | 27T | 27T | 26T | 26T | 26T | 26T | 25T |
Speedup | 2.54x | 1.94x | 2.73x | 1.88x | 2.0x | 1.8x | 1.8x | 1.73x | 1.86x | 1.86x | 1.86x | 1.79x |
FFPA L1^ | 43T | 41T | 39T | 39T | 39T | 39T | 39T | 36T | 34T | 33T | 31T | 33T |
Speedup | 3.31x | 2.56x | 3.55x | 2.44x | 2.6x | 2.6x | 2.6x | 2.4x | 2.43x | 2.36x | 2.21x | 2.36x |
- 📚 NVIDIA RTX 3080 Laptop (`*`=MMA Acc: QK F32 + PV F16, `^`=MMA Acc F16, `T`=TFLOPS, ~2.9x↑🎉)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 13T | 15T | 12T | 15T | 14T | 15T | 14T | 14T | 14T | 14T | 14T | 14T |
FFPA L1* | 38T | 36T | 34T | 35T | 34T | 31T | 32T | 31T | 30T | 28T | 27T | 27T |
Speedup | 2.92x | 2.4x | 2.83x | 2.33x | 2.43x | 2.07x | 2.29x | 2.21x | 2.14x | 2.0x | 1.93x | 1.93x |
FFPA L1^ | 44T | 41T | 39T | 39T | 38T | 39T | 39T | 36T | 34T | 32T | 31T | 33T |
Speedup | 3.38x | 2.73x | 3.25x | 2.6x | 2.71x | 2.6x | 2.79x | 2.57x | 2.43x | 2.29x | 2.21x | 2.36x |
- 📚 NVIDIA RTX 4090 (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS, ~1.8x↑🎉)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 81T | 94T | 85T | 85T | 79T | 81T | 79T | 80T | 79T | 80T | 78T | 78T |
FFPA L1* | 149T | 150T | 150T | 150T | 150T | 140T | 140T | 140T | 139T | 139T | 137T | 134T |
Speedup | 1.84x | 1.6x | 1.76x | 1.76x | 1.9x | 1.73x | 1.77x | 1.75x | 1.76x | 1.74x | 1.76x | 1.72x |
FFPA L1^ | 194T | 194T | 189T | 191T | 197T | 188T | 184T | 180T | 177T | 172T | 171T | 171T |
Speedup | 2.4x | 2.06x | 2.22x | 2.25x | 2.49x | 2.32x | 2.33x | 2.25x | 2.24x | 2.15x | 2.19x | 2.19x |
- 📚 NVIDIA RTX 4090 (`*`=MMA Acc: QK F32 + PV F16, `^`=MMA Acc F16, `T`=TFLOPS, ~2.1x↑🎉)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 82T | 92T | 85T | 84T | 78T | 81T | 79T | 80T | 78T | 79T | 77T | 78T |
FFPA L1* | 176T | 170T | 171T | 171T | 171T | 161T | 160T | 161T | 160T | 158T | 165T | 164T |
Speedup | 2.15x | 1.85x | 2.01x | 2.04x | 2.19x | 1.99x | 2.03x | 2.01x | 2.05x | 2.0x | 2.14x | 2.1x |
FFPA L1^ | 200T | 191T | 189T | 191T | 188T | 188T | 186T | 179T | 175T | 173T | 172T | 170T |
Speedup | 2.44x | 2.08x | 2.22x | 2.27x | 2.41x | 2.32x | 2.35x | 2.24x | 2.24x | 2.19x | 2.23x | 2.18x |
👇You can test many custom FFPA kernels via Python and compare their performance. The `--gen-bench` and `--plot` options help you generate a Markdown-style benchmark table and speedup bar plots on your device. Contributions of your benchmark tables and plots via a PR are welcome 🎉🎉.
- 📚 case: B=1, H=48, N=8192, D=320 (FA2 not supported)
# You can test on many devices, such as Volta, Ampere, Ada, Hopper, ...
cd tests && python3 test_ffpa_attn.py --B 1 --H 48 --N 8192 --show-all --D 320
---------------------------------------B=1, H=48, N=8192, D=320, Warmup: 1, Iters: 5--------------------
(sdpa): ['-0.02380371'], time:73.66518ms, TFLOPS:56.19 (+0.00 %)(~1.00x)
(ffpa+acc+f32+L1+stage1): ['-0.02378845'], time:52.87361ms, TFLOPS:78.28 (+39.32%)(~1.39x)
(ffpa+acc+f32+L1+stage2): ['-0.02378845'], time:40.84062ms, TFLOPS:101.35(+29.46%)(~1.80x)
(ffpa+acc+f32+L1+stage3): ['-0.02378845'], time:40.49534ms, TFLOPS:102.21(+0.85 %)(~1.82x)
(ffpa+acc+f32+L1+stage4): ['-0.02378845'], time:40.88177ms, TFLOPS:101.25(+0.00 %)(~1.80x)
(ffpa+acc+f16+L1+stage1): ['-0.02378845'], time:53.43298ms, TFLOPS:77.46 (+0.00 %)(~1.38x)
(ffpa+acc+f16+L1+stage2): ['-0.02378845'], time:39.76068ms, TFLOPS:104.10(+1.85 %)(~1.85x)
(ffpa+acc+f16+L1+stage3): ['-0.02378845'], time:39.54901ms, TFLOPS:104.66(+0.54 %)(~1.86x)
(ffpa+acc+f16+L1+stage4): ['-0.02378845'], time:41.06554ms, TFLOPS:100.79(+0.00 %)(~1.79x)
--------------------------------------------------------------------------------------------------------
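The TFLOPS column in these logs corresponds to the usual attention FLOP estimate of roughly 4·B·H·N²·D (two matmuls of 2·N²·D each, softmax ignored), so a quick sanity check lands close to, though not exactly on, the reported value:

```python
# Rough TFLOPS sanity check against the `sdpa` line in the log above.
B, H, N, D = 1, 48, 8192, 320
time_ms = 73.66518                       # reported sdpa time
flops = 4 * B * H * N * N * D            # Q@K^T (2*N^2*D) + P@V (2*N^2*D), softmax ignored
print(f"{flops / (time_ms * 1e-3) / 1e12:.2f} TFLOPS")  # ≈ 55.97, close to the reported 56.19
```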
- 📚 case: Generate a benchmark table and speedup bar plots on your device.
cd tests && pip install matplotlib && python3 test_ffpa_attn.py --gen-bench --show-all --plot
- 📚 case: Compare small headdim (d<256, e.g. 64), FFPA-L1 vs SDPA FA-2 BE.
# Enable the ffpa-attn small-d kernels, which use the coarse-grained tiling method.
export ENABLE_FFPA_PERSIST_Q_G2S=1 && export ENABLE_FFPA_PERSIST_KV_G2S=1
cd tests && python3 test_ffpa_attn.py --B 1 --H 32 --N 1024 --check --show-all --D 64 # NVIDIA L20
---------------------------------------B=1, H=32, N=1024, D=64, Warmup: 1, Iters: 5--------------------
(sdpa): ['0.00802612'], time:0.148057ms, TFLOPS:59.14 (+0.00 %)(~1.00x)
(ffpa+acc+f32+L1+stage1): ['0.00803375'], time:0.103807ms, TFLOPS:84.34 (+42.63%)(~1.43x)
(ffpa+acc+f32+L1+stage2): ['0.00803375'], time:0.102233ms, TFLOPS:85.64 (+1.54 %)(~1.45x)
(ffpa+acc+f32+L1+stage3): ['0.00803375'], time:0.102519ms, TFLOPS:85.40 (+0.00 %)(~1.44x)
(ffpa+acc+f32+L1+stage4): ['0.00803375'], time:0.102043ms, TFLOPS:85.80 (+0.19 %)(~1.45x)
(ffpa+acc+f16+L1+stage1): ['0.00795746'], time:0.104713ms, TFLOPS:83.61 (+0.00 %)(~1.41x)
(ffpa+acc+f16+L1+stage2): ['0.00795746'], time:0.102949ms, TFLOPS:85.05 (+0.00 %)(~1.44x)
(ffpa+acc+f16+L1+stage3): ['0.00795746'], time:0.108957ms, TFLOPS:80.36 (+0.00 %)(~1.36x)
(ffpa+acc+f16+L1+stage4): ['0.00795746'], time:0.103282ms, TFLOPS:84.77 (+0.00 %)(~1.43x)
--------------------------------------------------------------------------------------------------------
cd tests && python3 test_ffpa_attn.py --B 1 --H 32 --N 4096 --check --show-all --D 64 # NVIDIA L20
-------------------------B=1, H=32, N=4096, D=64, Warmup: 1, Iters: 5-----------------------------------
(sdpa): ['0.01959229'], time:1.397752ms, TFLOPS:100.24(+0.00 %)(~1.00x)
(ffpa+acc+f32+L1+stage1): ['0.01959229'], time:1.368856ms, TFLOPS:102.36(+2.11 %)(~1.02x)
(ffpa+acc+f32+L1+stage2): ['0.01959229'], time:1.367807ms, TFLOPS:102.44(+0.08 %)(~1.02x)
(ffpa+acc+f32+L1+stage3): ['0.01959229'], time:1.367855ms, TFLOPS:102.43(+0.00 %)(~1.02x)
(ffpa+acc+f32+L1+stage4): ['0.01959229'], time:1.368045ms, TFLOPS:102.42(+0.00 %)(~1.02x)
(ffpa+acc+f16+L1+stage1): ['0.01957703'], time:1.389312ms, TFLOPS:100.85(+0.00 %)(~1.01x)
(ffpa+acc+f16+L1+stage2): ['0.01957703'], time:1.388311ms, TFLOPS:100.92(+0.00 %)(~1.01x)
(ffpa+acc+f16+L1+stage3): ['0.01957703'], time:1.386976ms, TFLOPS:101.02(+0.00 %)(~1.01x)
(ffpa+acc+f16+L1+stage4): ['0.01957703'], time:1.387834ms, TFLOPS:100.96(+0.00 %)(~1.01x)
--------------------------------------------------------------------------------------------------------
💡NOTE: Please check all configurable environment variables in env.py.
Extending FA support to large headdim is meaningful in the context of DeepSeek MLA. For example, once FA supports headdim values greater than 512, we can fuse MLA fully into a single CUDA kernel, after W_UK/W_UV are absorbed into W_Q/W_O (resulting in C_kv/C_q with dc/dc' >= 512); a sketch of this absorption follows the TODO list below. TODO list👇:
- 📚Fully Fused MLA into a single CUDA kernel using FFPA Algo and Tensor Cores.
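For context, the weight absorption referred to above follows the standard MLA identity (a sketch only; RoPE-decoupled components are omitted, and W_Q here denotes the query up-projection acting on the compressed latent c^Q):

$$
q_t^\top k_j = (W_Q\, c^Q_t)^\top (W_{UK}\, c^{KV}_j) = (c^Q_t)^\top \big[W_Q^\top W_{UK}\big]\, c^{KV}_j,
\qquad
o_t = W_O \sum_j p_{t,j}\, W_{UV}\, c^{KV}_j = \big[W_O W_{UV}\big] \sum_j p_{t,j}\, c^{KV}_j
$$

After absorbing W_UK into the query path and W_UV into W_O, attention runs directly on the compressed latents C_kv/C_q whose dimensions dc/dc' are >= 512, which is exactly the large-headdim regime FFPA targets.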
GNU General Public License v3.0
How to contribute? Welcome to star⭐️ this repo to support me👆🏻 ~