📚 Modern CUDA Learn Notes with PyTorch for Beginners: it covers Tensor/CUDA Cores, TF32/F16/BF16/F8, 📖150+ CUDA Kernels🔥🔥 with PyTorch bindings, 📖100+ LLM/VLM/CV/CUDA/CuTe🔥 blogs, 📖toy-hgemm⚡️⚡️, which can achieve 98%~100% of cuBLAS's performance, and 📖flash-attention-mma⚡️⚡️, which uses Tensor Cores through pure MMA PTX. Please 🌟👆🏻star this repo to support me, many thanks ~ 🎉🎉
Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, the HGEMM (WMMA/MMA/CuTe) kernels in this repo (blue 🔵) achieve 98%~100% of the performance of cuBLAS's default Tensor Cores algorithm (orange 🟠). Please check the toy-hgemm library⚡️⚡️ or the hgemm-tensorcores-mma⚡️⚡️ repo for more details.
CUDA Cores | Sliced K (Loop over K) | Tile Block (BMxBK) | Tile Thread (t 8x8) |
---|---|---|---|
✔️ | ✔️ | ✔️ | ✔️ |
WMMA (m16n16k16) | MMA (m16n8k16) | Pack LDST (128 bits) | SMEM Padding |
✔️ | ✔️ | ✔️ | ✔️ |
Copy Async | Tile MMA (More Threads) | Tile Warp (More Values) | Multi Stages (2/3/4) |
✔️ | ✔️ | ✔️ | ✔️ |
Reg Double Buffers | Block Swizzle | Warp Swizzle | SMEM Swizzle (CuTe) |
✔️ | ✔️ | ✔️ | ✔️ |
Collective Store (Warp Shfl) | Row Major (NN) | Col Major (TN) | SGEMM FP32/TF32 |
✔️ | ✔️ | ✔️ | ✔️ |
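The Tile Block and Sliced K (Loop over K) entries can be pictured with a plain-Python sketch of the loop nest. It only mirrors the loop structure on the CPU and is not the repo's CUDA kernel; the tile sizes BM, BN, BK and the function names are hypothetical. Each (bm, bn) iteration stands in for one thread block, and the bk loop walks K in slices that a real kernel would stage through shared memory.

```python
def naive_gemm(A, B):
    """Reference C = A @ B with a plain triple loop."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def tiled_gemm(A, B, BM=4, BN=4, BK=2):
    """C = A @ B computed one BM x BN output tile at a time, slicing K into BK chunks."""
    M, K, N = len(A), len(A[0]), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for bm in range(0, M, BM):          # block tile row (one "thread block" per (bm, bn))
        for bn in range(0, N, BN):      # block tile column
            for bk in range(0, K, BK):  # Sliced K: on a GPU, A[bm:, bk:] and B[bk:, bn:] go to SMEM here
                for i in range(bm, min(bm + BM, M)):
                    for j in range(bn, min(bn + BN, N)):
                        acc = C[i][j]
                        for k in range(bk, min(bk + BK, K)):
                            acc += A[i][k] * B[k][j]
                        C[i][j] = acc   # partial sum carried to the next K slice
    return C
```

The `min(...)` bounds let the tile sizes be arbitrary; real kernels usually require M, N, K to be multiples of the tile sizes or add explicit boundary checks.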
I have also implemented FlashAttention-2 using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp, Shared KV SMEM, Fully Shared QKV SMEM, Prefetch Q s2r, Collective Store, etc. Please refer to flash-attention-mma⚡️⚡️ for more details.
Tensor Cores | Loop over Seqlen/Headdim | Tile Block (Br, Bc) | MMA (m16n8k16) |
---|---|---|---|
✔️ | ✔️ | ✔️ | ✔️ |
Pack LDST (128 bits) | SMEM Padding | Copy Async | Tile MMA (More Threads) |
✔️ | ✔️ | ✔️ | ✔️ |
Tile Warp (More Values) | Multi Stages (1/2) | Collective Store (Shfl) | Split KV/Q |
✔️ | ✔️ | ✔️ | ✔️ |
Shared KV SMEM | Fully Shared QKV SMEM | Prefetch Q s2r | SMEM/Block Swizzle |
✔️ | ✔️ | ✔️ | ❔ |
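The "Loop over Seqlen" idea behind FlashAttention-2 is the online softmax: the key/value sequence is consumed in Bc-sized tiles while a running max, a running sum and an unnormalized output are rescaled as each tile arrives. Below is a minimal single-query-row sketch in plain Python (not the MMA kernel); the function names and the Bc value are illustrative, and the 1/sqrt(D) score scaling is omitted for brevity.

```python
import math

def attention_reference(q, K, V):
    """softmax(q . K^T) V for one query row, computed over the whole sequence at once."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(w[j] * V[j][c] for j in range(len(V))) / total for c in range(len(V[0]))]

def attention_online(q, K, V, Bc=2):
    """Same result, but K/V are visited tile by tile with running statistics."""
    m = -math.inf              # running max of the scores seen so far
    l = 0.0                    # running sum of exp(score - m)
    acc = [0.0] * len(V[0])    # running unnormalized output row
    for start in range(0, len(K), Bc):
        s = [sum(a * b for a, b in zip(q, k)) for k in K[start:start + Bc]]
        m_new = max(m, max(s))
        scale = math.exp(m - m_new)  # rescales old l/acc to the new max (0.0 on the first tile)
        p = [math.exp(x - m_new) for x in s]
        l = l * scale + sum(p)
        acc = [acc[c] * scale + sum(p[j] * V[start + j][c] for j in range(len(p)))
               for c in range(len(acc))]
        m = m_new
    return [a / l for a in acc]  # final normalization, done once after the last tile
```

Because normalization is deferred to the end, only O(Bc) scores are live at a time, which is what lets a kernel keep a whole KV tile's work in registers and SMEM.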
Currently, for small-scale attention (B<=4, H<=48, SeqLen<=8192), these kernels can run faster than the official FA2 on some devices. However, for large-scale attention a performance gap remains. Performance is continuously being optimized, so stay tuned for updates ~ Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop):
python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop
------------------------------------------------------------------------------------------------------------------------
B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 1617, Warmup: 1, Iters: 10
------------------------------------------------------------------------------------------------------------------------
B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
mma(split-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.586338ms, TFLOPS:25.08
mma(split-kv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.326223ms, TFLOPS:26.31
mma(split-q+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:3.834152ms, TFLOPS:36.54
mma(split-q+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:4.328346ms, TFLOPS:32.37
mma(split-q+share-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.636528ms, TFLOPS:53.15
mma(split-q+share-qkv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.594471ms, TFLOPS:54.01
mma(split-q+share-qkv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.574611ms, TFLOPS:54.42
(flash): ['0.01963806 ', '0.0145874 ', '-0.02593994 '], time:3.764462ms, TFLOPS:37.22
-----------------------------------------------------------------------------------------------------------------------
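As a sanity check, the TFLOPS column can be approximated from B, H, N, D and the measured time with the common convention of 4·B·H·N²·D FLOPs for the two attention matmuls (softmax ignored). That convention is an assumption of this sketch: the figures printed above are consistently about 2% higher, so the benchmark script evidently counts some extra work that this formula leaves out.

```python
def attn_tflops(B, H, N, D, time_ms):
    """TFLOPS under the common 4*B*H*N^2*D convention (two matmuls, softmax ignored)."""
    flops = 4 * B * H * N * N * D   # S = Q K^T and O = P V, 2*N*N*D flops each per head
    return flops / (time_ms * 1e-3) / 1e12

# Row "mma(split-kv+stage1)" above: 5.586338 ms
print(f"{attn_tflops(1, 8, 8192, 64, 5.586338):.2f}")  # prints 24.60
# Speedup of mma(split-q+share-qkv+stage2) over (flash) in the same run
print(f"{3.764462 / 2.574611:.2f}x")                    # prints 1.46x
```

Ratios of times from the same run compare kernels directly and do not depend on which FLOP convention is used.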
Both the Split KV and Split Q implementations are provided in flash-attention-mma⚡️⚡️ for performance comparison. The Split KV method, which splits all of Q, K and V across MMAs (warps), is slower than the Split Q policy, which splits only Q across MMAs (warps) and lets every MMA (warp) access all of K and V. In the benchmark above, for example, mma(split-kv+stage1) takes 5.59 ms versus 3.83 ms for mma(split-q+stage1).
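Why Split KV costs more can be seen in a small CPU sketch (plain Python, not the CUDA kernels): when one query row's keys are spread over several warps, each warp ends up with only a partial softmax (its own max, sum and unnormalized output), and these partials must be exchanged and rescaled before the row is finished. Under Split Q one warp owns the whole row and sees all of K and V, so it normalizes locally. The function names are illustrative.

```python
import math

def partial_attention(q, K, V):
    """One warp's view of a query row over its K/V slice: (max, sum of exps, unnormalized output)."""
    s = [sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    acc = [sum(p[j] * V[j][c] for j in range(len(p))) for c in range(len(V[0]))]
    return m, sum(p), acc

def merge(parts):
    """The cross-warp reduction Split KV needs: rescale every partial to the global max, then normalize."""
    m = max(pm for pm, _, _ in parts)
    l = sum(pl * math.exp(pm - m) for pm, pl, _ in parts)
    d = len(parts[0][2])
    return [sum(pa[c] * math.exp(pm - m) for pm, _, pa in parts) / l for c in range(d)]

def split_kv_row(q, K, V, n_warps=2):
    """Split KV: warps see different K/V slices of the same row, so partials must be merged."""
    step = (len(K) + n_warps - 1) // n_warps
    parts = [partial_attention(q, K[i:i + step], V[i:i + step]) for i in range(0, len(K), step)]
    return merge(parts)

def split_q_row(q, K, V):
    """Split Q: one warp owns the row and sees all of K/V, so no merge step is needed."""
    _, l, acc = partial_attention(q, K, V)
    return [a / l for a in acc]
```

Both paths give the same numbers; the difference is that `merge` corresponds to data that, on the GPU, has to move between warps through SMEM or shuffles.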
- 📚 Split KV (Basic, FlashAttention-1)
// Split QKV across MMA(Warps) using naive matmul MMA&Warp tiling policy.
// case: The layout of 8 MMA(2x4) [after] kWarpTileSeqLenQxkWarpTileSeqLenK(2x2) -> 32x2,32x2=64x64:
// | [64,64] | warp_KV 0 | warp_KV 1 | warp_KV 2 | warp_KV 3 |
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
__global__ void
flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
half* K, // [B, H, D, N] K^T transposed
half* V, // [B, H, N, D]
half* O, // [B, H, N, D]
int QKV_seqlen);
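Reading the Split KV diagram as code: with m16n8 MMA fragments, a 2x4 warp grid and a 2x2 warp tile, each of the 8 warps owns a 32x16 patch of the 64x64 score tile. The sketch below assumes warps are numbered down the warp_QP axis first, which is how the MMA labels in the diagram are arranged; the constant and function names are hypothetical, and the actual kernel may compute its indices differently.

```python
MMA_M, MMA_N = 16, 8        # m16n8k16 output fragment: 16 rows x 8 columns
WARPS_QP, WARPS_KV = 2, 4   # 8 warps arranged 2 (along Q) x 4 (along KV)
TILE_Q, TILE_K = 2, 2       # kWarpTileSeqLenQ x kWarpTileSeqLenK

def split_kv_warp_tile(warp_id):
    """(row0, col0, rows, cols) of the S = Q K^T patch owned by one warp."""
    warp_qp = warp_id % WARPS_QP    # assumed ordering: MMA 0,1 stacked along QP, then next KV column
    warp_kv = warp_id // WARPS_QP
    rows = MMA_M * TILE_Q           # 16 x 2 = 32 query rows
    cols = MMA_N * TILE_K           # 8 x 2 = 16 key columns
    return warp_qp * rows, warp_kv * cols, rows, cols
```

Two warp rows of 32 and four warp columns of 16 give the 64x64 block from the comment, with every element owned by exactly one warp.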
- 📚 Split Q (Faster, FlashAttention-2)
// Split Q across MMAs (warps) and let every MMA (warp) access all of KV,
// in order to reduce communication between warps via smem and warp shuffle.
// case: MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
// | 64x64 | warp_KV 0 |
// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
__global__ void
flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
half* K, // [B, H, D, N] K^T transposed
half* V, // [B, H, N, D]
half* O, // [B, H, N, D]
int QKV_seqlen);
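For the Split Q layout the same exercise shows why inter-warp traffic drops: each of the 4 warps owns a 16-row slab of Q and issues 8 m16n8 MMAs along the KV axis, so every query row, and therefore its softmax max and sum, lives in exactly one warp. In the Split KV diagram, by contrast, each row spans the four warp_KV columns. The names below are hypothetical; only Br, Bc and the MMA shape come from the comment above.

```python
BR, BC = 64, 64          # Br = 16x4, Bc = 8x8 from the kernel comment
MMA_M, MMA_N = 16, 8     # m16n8k16 output fragment
N_WARPS = BR // MMA_M    # 4 warps, one per 16-row slab of Q

def split_q_warp_tiles(warp_id):
    """The 8 (row0, col0) MMA output tiles computed by one warp: same rows, all KV columns."""
    row0 = warp_id * MMA_M
    return [(row0, j * MMA_N) for j in range(BC // MMA_N)]

def owners_of_row(row):
    """Warps whose MMA tiles contain any element of this query row."""
    return sorted({w for w in range(N_WARPS)
                   for (r0, _) in split_q_warp_tiles(w) if r0 <= row < r0 + MMA_M})
```

Since `owners_of_row` always returns a single warp, the row-wise softmax reductions stay inside one warp (at most a few lane shuffles) instead of crossing warps through SMEM.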
- 📚 Split Q + Shared KV SMEM (Faster+)
// K and V share the same shared memory buffer, which improves block occupancy.
__global__ void
flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
half* K,
half* V,
half* O,
int QKV_seqlen);
- 📚 Split Q + Fully Shared QKV SMEM (Faster++)
// Q, K and V fully share the same shared memory, and Q is prefetched from smem to registers (s2r); this improves block occupancy.
__global__ void
flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
half* K,
half* V,
half* O,
int QKV_seqlen);
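The occupancy benefit of sharing SMEM can be estimated with a back-of-the-envelope model. The assumptions are mine, not taken from the kernels: Br = Bc = D = 64, fp16, a single stage, no SMEM padding, and, for the fully shared variant, that Q is moved to registers before its buffer is reused for K/V. The 100 KB budget is a hypothetical per-SM limit, and the count ignores register, warp and reserved-SMEM limits.

```python
BYTES_PER_HALF = 2  # fp16

def smem_bytes(scheme, Br=64, Bc=64, D=64):
    """Single-stage, unpadded SMEM footprint of one thread block (simplified model)."""
    q = Br * D * BYTES_PER_HALF       # Q tile
    kv = Bc * D * BYTES_PER_HALF      # one K or V tile
    if scheme == "separate":          # Q, K and V each in their own buffer
        return q + 2 * kv
    if scheme == "shared_kv":         # K and V take turns in a single buffer
        return q + kv
    if scheme == "shared_qkv":        # Q moved to registers first, then the buffer is reused for K/V
        return max(q, kv)
    raise ValueError(f"unknown scheme: {scheme}")

def blocks_per_sm(scheme, budget=100 * 1024):
    """Upper bound on resident blocks from SMEM alone (ignores registers, warps, reserved SMEM)."""
    return budget // smem_bytes(scheme)

for s in ("separate", "shared_kv", "shared_qkv"):
    print(s, smem_bytes(s), blocks_per_sm(s))
```

Under these assumptions the footprint drops from 24 KB to 16 KB to 8 KB, and the SMEM-limited bound rises from 4 to 6 to 12 blocks per SM; real occupancy also depends on registers and on the extra buffers that multi-stage pipelines need.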
@misc{CUDA-Learn-Notes@2024,
title={CUDA-Learn-Notes: A Modern CUDA Learn Notes with PyTorch for Beginners},
url={https://github.com/DefTruth/CUDA-Learn-Notes},
note={Open-source software available at https://github.com/DefTruth/CUDA-Learn-Notes},
author={DefTruth etc},
year={2024}
}
📖 150+ CUDA Kernels 🔥🔥 (common interview questions)
Workflow: custom CUDA kernel impl -> PyTorch Python bindings -> Run tests. 👉TIPS: * = Tensor Cores (WMMA/MMA), otherwise CUDA Cores; / = not supported; ✔️ = supported; ❔ = in my plan.
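📖 LLM|Multimodal|Diffusion|Inference Optimization (written by me)
📖 CV Inference Deployment|C++|Algorithms|Tech Notes (written by me)
📖 CUTLASS|CuTe|NCCL|CUDA|Recommended Articles (by other authors)
💡Note: This section collects articles I particularly like. PRs recommending more great articles are welcome!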
©️License
GNU General Public License v3.0
🎉Contribute
How to contribute? Please check 🌤🌤CONTRIBUTE🎉🎉.