From ce095b5c67521bfd8d29812abecdd698ceaa84fc Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Sat, 19 Oct 2024 17:57:56 +0800 Subject: [PATCH] [HGEMM] Add GeForce RTX 3080 Laptop benchmark (#94) * update hgemm benchmark * update hgemm benchmark --- hgemm/README.md | 258 ++++++++- hgemm/hgemm.cu | 9 +- hgemm/hgemm.py | 143 +++-- hgemm/hgemm_wmma_stage.cu | 1106 +++++++++++++++++++++++++++++++------ 4 files changed, 1278 insertions(+), 238 deletions(-) diff --git a/hgemm/README.md b/hgemm/README.md index ea0e25c7..fc313709 100755 --- a/hgemm/README.md +++ b/hgemm/README.md @@ -29,8 +29,14 @@ ## 目前性能 +- NVIDIA L20 + 目前最优的实现,在L20上(理论Tensor Cores FP16算力为 119.5 TFLOPS),能达到cuBLAS大概95%~98%左右的性能(105-113 TFLOPS vs 105-115 TFLOPS),部分case会超越cuBLAS。已知问题为bank conflicts没有完全消除,目前通过padding的方式缓解bank conflicts会导致shared memory浪费,也会影响SM occupancy。并且尚未手工实现Warp swizzle(受限于WMMA API的灵活性以及本人的能力),后续将会尝试通过MMA PTX实现warp swizzle。 +- NVIDIA GeForce RTX 3080 Laptop + +在NVIDIA GeForce RTX 3080 Laptop上测试,使用mma4x4_warp4x4(16 MMA m16n16k16 ops, warp tile 64x64)以及Thread block swizzle,大部分case能持平甚至超过cuBLAS。 + ## 共享内存 Bank Conflicts 含义:在访问shared memory时,因多个线程读写同一个Bank中的不同数据地址时,导致shared memory 并发读写 退化 成顺序读写的现象叫做Bank Conflict; @@ -229,7 +235,9 @@ nsys profile --stats=true -t cuda,osrt,nvtx -o hgemm.prof --force-overwrite true ```bash # 只测试Ada架构 不指定默认编译所有架构 耗时较长: Volta, Ampere, Ada, Hopper, ... export TORCH_CUDA_ARCH_LIST=Ada -python3 hgemm.py +python3 hgemm.py # default, test some wmma kernels for all MNK +python3 hgemm.py --wmma # test all wmma kernels for all MNK +python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --wmma # test all wmma kernels for specific MNK ``` 输出: @@ -723,3 +731,251 @@ python3 hgemm.py f16_th: ['-67.375 ', '14.9609375'], time:38.53211ms, swizzle: NOOP, TFLOPS: 114.14(+0.46%) ---------------------------------------------------------------------------------------------------------------------------------- ``` + +- NVIDIA GeForce RTX 3080 Laptop +```bash +python3 hgemm.py --wmma --no-default +---------------------------------------------------------------------------------------------------------------------------------- + M=4096, N=4096, K=2048, Warmup=5, Iters=20, 1/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['-34.9375 ', '2.25585938'], time:1.397085ms, swizzle: NOOP, TFLOPS: 49.19 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['-34.9375 ', '2.25585938'], time:1.632452ms, swizzle: NOOP, TFLOPS: 42.10 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:1.392316ms, swizzle: 1024, TFLOPS: 49.36 (+0.34%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:1.537656ms, swizzle: 1024, TFLOPS: 44.69 + (cublas): ['-34.90625 ', '2.21875 '], time:1.072788ms, swizzle: NOOP, TFLOPS: 64.06 (+29.78%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=4096, N=4096, K=4096, Warmup=5, Iters=20, 2/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['10.8515625', '9.4140625 '], time:3.154301ms, swizzle: NOOP, TFLOPS: 43.57 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['10.8515625', '9.4140625 '], time:3.152799ms, swizzle: NOOP, TFLOPS: 43.59 (+0.05%) + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:2.640366ms, swizzle: 1024, TFLOPS: 52.05 (+19.41%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:3.021883ms, swizzle: 1024, TFLOPS: 45.48 + (cublas): ['10.8515625', '9.4140625 '], time:2.330613ms, swizzle: NOOP, TFLOPS: 58.97 (+13.29%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=4096, N=4096, K=8192, Warmup=5, Iters=20, 3/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['68.375 ', '-2.234375 '], time:5.776286ms, swizzle: NOOP, TFLOPS: 47.59 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['68.375 ', '-2.234375 '], time:6.212115ms, swizzle: NOOP, TFLOPS: 44.25 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:5.236458ms, swizzle: 1024, TFLOPS: 52.49 (+10.31%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:5.674219ms, swizzle: 1024, TFLOPS: 48.44 + (cublas): ['68.375 ', '-2.234375 '], time:5.311441ms, swizzle: NOOP, TFLOPS: 51.75 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=4096, N=8192, K=2048, Warmup=5, Iters=20, 4/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['-34.9375 ', '2.25585938'], time:3.303718ms, swizzle: NOOP, TFLOPS: 41.60 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['-34.9375 ', '2.25585938'], time:3.193497ms, swizzle: NOOP, TFLOPS: 43.04 (+3.45%) + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:2.624654ms, swizzle: 2048, TFLOPS: 52.36 (+21.67%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:2.863550ms, swizzle: 2048, TFLOPS: 48.00 + (cublas): ['-34.90625 ', '2.21875 '], time:2.649235ms, swizzle: NOOP, TFLOPS: 51.88 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=4096, N=8192, K=4096, Warmup=5, Iters=20, 5/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['10.8515625', '9.4140625 '], time:5.747509ms, swizzle: NOOP, TFLOPS: 47.83 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['10.8515625', '9.4140625 '], time:6.356692ms, swizzle: NOOP, TFLOPS: 43.24 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:5.048251ms, swizzle: 2048, TFLOPS: 54.45 (+13.85%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:5.489063ms, swizzle: 2048, TFLOPS: 50.08 + (cublas): ['10.8515625', '9.4140625 '], time:6.013441ms, swizzle: NOOP, TFLOPS: 45.71 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=4096, N=8192, K=8192, Warmup=5, Iters=20, 6/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['68.375 ', '-2.234375 '], time:11.15694ms, swizzle: NOOP, TFLOPS: 49.27 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['68.375 ', '-2.234375 '], time:12.09821ms, swizzle: NOOP, TFLOPS: 45.44 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:9.958195ms, swizzle: 2048, TFLOPS: 55.21 (+12.04%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:10.67364ms, swizzle: 2048, TFLOPS: 51.51 + (cublas): ['68.375 ', '-2.234375 '], time:12.02430ms, swizzle: NOOP, TFLOPS: 45.72 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=4096, N=16384, K=2048, Warmup=5, Iters=20, 7/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['-34.9375 ', '2.25585938'], time:6.608533ms, swizzle: NOOP, TFLOPS: 41.59 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['-34.9375 ', '2.25585938'], time:6.812095ms, swizzle: NOOP, TFLOPS: 40.35 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:5.446910ms, swizzle: 4096, TFLOPS: 50.46 (+21.33%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:5.769944ms, swizzle: 4096, TFLOPS: 47.64 + (cublas): ['-34.90625 ', '2.21875 '], time:6.295609ms, swizzle: NOOP, TFLOPS: 43.66 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=4096, N=16384, K=4096, Warmup=5, Iters=20, 8/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['10.8515625', '9.4140625 '], time:11.90752ms, swizzle: NOOP, TFLOPS: 46.17 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['10.8515625', '9.4140625 '], time:12.66958ms, swizzle: NOOP, TFLOPS: 43.39 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:10.72070ms, swizzle: 4096, TFLOPS: 51.28 (+11.07%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:11.09249ms, swizzle: 4096, TFLOPS: 49.56 + (cublas): ['10.8515625', '9.4140625 '], time:9.910416ms, swizzle: NOOP, TFLOPS: 55.47 (+8.18%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=4096, N=16384, K=8192, Warmup=5, Iters=20, 9/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['68.375 ', '-2.234375 '], time:23.75357ms, swizzle: NOOP, TFLOPS: 46.29 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['68.375 ', '-2.234375 '], time:25.33891ms, swizzle: NOOP, TFLOPS: 43.39 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:20.78440ms, swizzle: 4096, TFLOPS: 52.90 (+14.29%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:22.58212ms, swizzle: 4096, TFLOPS: 48.69 + (cublas): ['68.375 ', '-2.234375 '], time:23.13928ms, swizzle: NOOP, TFLOPS: 47.52 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=8192, N=4096, K=2048, Warmup=5, Iters=20, 10/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['-34.9375 ', '2.25585938'], time:3.206682ms, swizzle: NOOP, TFLOPS: 42.86 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['-34.9375 ', '2.25585938'], time:3.255009ms, swizzle: NOOP, TFLOPS: 42.22 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:2.551007ms, swizzle: 1024, TFLOPS: 53.88 (+25.70%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:2.943944ms, swizzle: 1024, TFLOPS: 46.69 + (cublas): ['-34.90625 ', '2.21875 '], time:2.616691ms, swizzle: NOOP, TFLOPS: 52.52 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=8192, N=4096, K=4096, Warmup=5, Iters=20, 11/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['10.8515625', '9.4140625 '], time:5.581545ms, swizzle: NOOP, TFLOPS: 49.25 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['10.8515625', '9.4140625 '], time:5.918717ms, swizzle: NOOP, TFLOPS: 46.44 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:5.013823ms, swizzle: 1024, TFLOPS: 54.82 (+11.32%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:5.475091ms, swizzle: 1024, TFLOPS: 50.21 + (cublas): ['10.8515625', '9.4140625 '], time:5.620026ms, swizzle: NOOP, TFLOPS: 48.91 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=8192, N=4096, K=8192, Warmup=5, Iters=20, 12/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['68.375 ', '-2.234375 '], time:10.63799ms, swizzle: NOOP, TFLOPS: 51.68 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['68.375 ', '-2.234375 '], time:11.95423ms, swizzle: NOOP, TFLOPS: 45.99 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:10.08455ms, swizzle: 1024, TFLOPS: 54.51 (+5.49%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:10.80915ms, swizzle: 1024, TFLOPS: 50.86 + (cublas): ['68.375 ', '-2.234375 '], time:12.14854ms, swizzle: NOOP, TFLOPS: 45.25 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=8192, N=8192, K=2048, Warmup=5, Iters=20, 13/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['-34.9375 ', '2.25585938'], time:6.046414ms, swizzle: NOOP, TFLOPS: 45.46 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['-34.9375 ', '2.25585938'], time:6.623530ms, swizzle: NOOP, TFLOPS: 41.50 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:5.341410ms, swizzle: 2048, TFLOPS: 51.46 (+13.20%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:5.689215ms, swizzle: 2048, TFLOPS: 48.32 + (cublas): ['-34.90625 ', '2.21875 '], time:6.602764ms, swizzle: NOOP, TFLOPS: 41.63 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=8192, N=8192, K=4096, Warmup=5, Iters=20, 14/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['10.8515625', '9.4140625 '], time:11.54751ms, swizzle: NOOP, TFLOPS: 47.61 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['10.8515625', '9.4140625 '], time:12.49833ms, swizzle: NOOP, TFLOPS: 43.99 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:10.34743ms, swizzle: 2048, TFLOPS: 53.13 (+11.60%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:10.89727ms, swizzle: 2048, TFLOPS: 50.45 + (cublas): ['10.8515625', '9.4140625 '], time:11.89055ms, swizzle: NOOP, TFLOPS: 46.23 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=8192, N=8192, K=8192, Warmup=5, Iters=20, 15/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['68.375 ', '-2.234375 '], time:23.22742ms, swizzle: NOOP, TFLOPS: 47.34 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['68.375 ', '-2.234375 '], time:25.00588ms, swizzle: NOOP, TFLOPS: 43.97 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:20.04830ms, swizzle: 2048, TFLOPS: 54.84 (+15.86%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:21.89767ms, swizzle: 2048, TFLOPS: 50.21 + (cublas): ['68.375 ', '-2.234375 '], time:23.18794ms, swizzle: NOOP, TFLOPS: 47.42 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=8192, N=16384, K=2048, Warmup=5, Iters=20, 16/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['-34.9375 ', '2.25585938'], time:12.24069ms, swizzle: NOOP, TFLOPS: 44.91 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['-34.9375 ', '2.25585938'], time:13.07930ms, swizzle: NOOP, TFLOPS: 42.03 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:10.82205ms, swizzle: 4096, TFLOPS: 50.80 (+13.11%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:11.43186ms, swizzle: 4096, TFLOPS: 48.09 + (cublas): ['-34.90625 ', '2.21875 '], time:13.87636ms, swizzle: NOOP, TFLOPS: 39.62 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=8192, N=16384, K=4096, Warmup=5, Iters=20, 17/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['10.8515625', '9.4140625 '], time:23.84941ms, swizzle: NOOP, TFLOPS: 46.10 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['10.8515625', '9.4140625 '], time:31.07695ms, swizzle: NOOP, TFLOPS: 35.38 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:23.16045ms, swizzle: 4096, TFLOPS: 47.47 (+2.97%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:25.17983ms, swizzle: 4096, TFLOPS: 43.67 + (cublas): ['10.8515625', '9.4140625 '], time:20.92361ms, swizzle: NOOP, TFLOPS: 52.55 (+10.69%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=8192, N=16384, K=8192, Warmup=5, Iters=20, 18/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['68.375 ', '-2.234375 '], time:48.17764ms, swizzle: NOOP, TFLOPS: 45.64 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['68.375 ', '-2.234375 '], time:51.66683ms, swizzle: NOOP, TFLOPS: 42.56 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:42.50290ms, swizzle: 4096, TFLOPS: 51.74 (+13.35%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:46.67718ms, swizzle: 4096, TFLOPS: 47.11 + (cublas): ['68.375 ', '-2.234375 '], time:45.62001ms, swizzle: NOOP, TFLOPS: 48.20 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=4096, K=2048, Warmup=5, Iters=20, 19/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['-34.9375 ', '2.25585938'], time:5.999112ms, swizzle: NOOP, TFLOPS: 45.82 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['-34.9375 ', '2.25585938'], time:6.952166ms, swizzle: NOOP, TFLOPS: 39.54 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:5.714607ms, swizzle: 1024, TFLOPS: 48.10 (+4.98%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:5.846762ms, swizzle: 1024, TFLOPS: 47.01 + (cublas): ['-34.9375 ', '2.25585938'], time:5.578041ms, swizzle: NOOP, TFLOPS: 49.28 (+2.45%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=4096, K=4096, Warmup=5, Iters=20, 20/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['10.8515625', '9.4140625 '], time:11.36004ms, swizzle: NOOP, TFLOPS: 48.39 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['10.8515625', '9.4140625 '], time:12.24460ms, swizzle: NOOP, TFLOPS: 44.90 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:10.57424ms, swizzle: 1024, TFLOPS: 51.99 (+7.43%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:11.31019ms, swizzle: 1024, TFLOPS: 48.61 + (cublas): ['10.8515625', '9.4140625 '], time:10.14137ms, swizzle: NOOP, TFLOPS: 54.21 (+4.27%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=4096, K=8192, Warmup=5, Iters=20, 21/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['68.375 ', '-2.234375 '], time:21.54934ms, swizzle: NOOP, TFLOPS: 51.02 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['68.375 ', '-2.234375 '], time:25.34153ms, swizzle: NOOP, TFLOPS: 43.39 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:21.18096ms, swizzle: 1024, TFLOPS: 51.91 (+1.74%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:22.19107ms, swizzle: 1024, TFLOPS: 49.55 + (cublas): ['68.375 ', '-2.234375 '], time:23.78721ms, swizzle: NOOP, TFLOPS: 46.22 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=8192, K=2048, Warmup=5, Iters=20, 22/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['-34.9375 ', '2.25585938'], time:12.14342ms, swizzle: NOOP, TFLOPS: 45.27 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['-34.9375 ', '2.25585938'], time:13.07780ms, swizzle: NOOP, TFLOPS: 42.04 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:10.68298ms, swizzle: 2048, TFLOPS: 51.46 (+13.67%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:11.51511ms, swizzle: 2048, TFLOPS: 47.74 + (cublas): ['-34.9375 ', '2.25585938'], time:12.36820ms, swizzle: NOOP, TFLOPS: 44.45 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=8192, K=4096, Warmup=5, Iters=20, 23/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['10.8515625', '9.4140625 '], time:23.26002ms, swizzle: NOOP, TFLOPS: 47.27 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['10.8515625', '9.4140625 '], time:25.28347ms, swizzle: NOOP, TFLOPS: 43.49 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:20.98624ms, swizzle: 2048, TFLOPS: 52.39 (+10.83%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:22.29118ms, swizzle: 2048, TFLOPS: 49.32 + (cublas): ['10.8515625', '9.4140625 '], time:23.58868ms, swizzle: NOOP, TFLOPS: 46.61 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=8192, K=8192, Warmup=5, Iters=20, 24/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['68.375 ', '-2.234375 '], time:46.57695ms, swizzle: NOOP, TFLOPS: 47.21 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['68.375 ', '-2.234375 '], time:50.11103ms, swizzle: NOOP, TFLOPS: 43.88 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:39.97759ms, swizzle: 2048, TFLOPS: 55.01 (+16.51%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:45.07379ms, swizzle: 2048, TFLOPS: 48.79 + (cublas): ['68.375 ', '-2.234375 '], time:46.13645ms, swizzle: NOOP, TFLOPS: 47.66 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=16384, K=2048, Warmup=5, Iters=20, 25/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['-34.9375 ', '2.25585938'], time:24.82917ms, swizzle: NOOP, TFLOPS: 44.28 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['-34.9375 ', '2.25585938'], time:26.81517ms, swizzle: NOOP, TFLOPS: 41.00 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:22.22962ms, swizzle: 4096, TFLOPS: 49.46 (+11.69%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['-34.9375 ', '2.25585938'], time:23.27709ms, swizzle: 4096, TFLOPS: 47.24 + (cublas): ['-34.90625 ', '2.21875 '], time:25.84185ms, swizzle: NOOP, TFLOPS: 42.55 +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=16384, K=4096, Warmup=5, Iters=20, 26/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['10.8515625', '9.4140625 '], time:48.43459ms, swizzle: NOOP, TFLOPS: 45.40 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['10.8515625', '9.4140625 '], time:52.00080ms, swizzle: NOOP, TFLOPS: 42.29 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:43.28680ms, swizzle: 4096, TFLOPS: 50.80 (+11.89%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['10.8515625', '9.4140625 '], time:47.73476ms, swizzle: 4096, TFLOPS: 46.07 + (cublas): ['10.8515625', '9.4140625 '], time:40.64793ms, swizzle: NOOP, TFLOPS: 54.10 (+6.49%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=16384, K=8192, Warmup=5, Iters=20, 27/27 +---------------------------------------------------------------------------------------------------------------------------------- + (mma4x4+warp4x4+stage3+dsmem): ['68.375 ', '-2.234375 '], time:96.91984ms, swizzle: NOOP, TFLOPS: 45.38 (+0.00%) + (mma4x4+warp4x4+stage2+dsmem): ['68.375 ', '-2.234375 '], time:102.8722ms, swizzle: NOOP, TFLOPS: 42.75 + (mma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:85.65800ms, swizzle: 4096, TFLOPS: 51.34 (+13.15%) + (mma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:95.70884ms, swizzle: 4096, TFLOPS: 45.95 + (cublas): ['68.375 ', '-2.234375 '], time:104.2092ms, swizzle: NOOP, TFLOPS: 42.20 +---------------------------------------------------------------------------------------------------------------------------------- +``` diff --git a/hgemm/hgemm.cu b/hgemm/hgemm.cu index f46cade8..c4997035 100644 --- a/hgemm/hgemm.cu +++ b/hgemm/hgemm.cu @@ -1235,11 +1235,14 @@ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages(torch::Tensor a, torch::Tensor b int stages, bool swizzle, int swizzle_stride); void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); +void hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, + int stages, bool swizzle, int swizzle_stride); void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); void hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, - int stages, bool swizzle, int swizzle_stride); - + int stages, bool swizzle, int swizzle_stride); +void hgemm_wmma_m16n16k16_mma4x4_warp2x2x2_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, + int stages, bool swizzle, int swizzle_stride); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // CUDA Cores FP16 @@ -1285,7 +1288,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // stage, thread block swizzle, dsmem TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages) TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem) + TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem) TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem) TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem) + TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x4_warp2x2x2_stages_dsmem) } diff --git a/hgemm/hgemm.py b/hgemm/hgemm.py index 21c0f4a9..4372a5d8 100644 --- a/hgemm/hgemm.py +++ b/hgemm/hgemm.py @@ -7,7 +7,26 @@ torch.set_grad_enabled(False) +def get_args(): + parser = argparse.ArgumentParser(description="hgemm benchmark") + parser.add_argument("--M", type=int, default=None, help="Matrix M size") + parser.add_argument("--N", type=int, default=None, help="Matrix N size") + parser.add_argument("--K", type=int, default=None, help="Matrix K size") + parser.add_argument("--warmup", "--w", type=int, default=5, help="Warmup iters") + parser.add_argument("--iters", "--i", type=int, default=20, help="Benchmark iters") + parser.add_argument("--enable-mma-all", "--mma", action="store_true", help="Enable all MMA kernel tests") + parser.add_argument("--enable-wmma-all", "--wmma", action="store_true", help="Enable all WMMA kernel tests") + parser.add_argument("--enable-cuda-all", "--cuda", action="store_true", help="Enable all CUDA kernel tests") + parser.add_argument("--enable-torch", "--torch", action="store_true", help="Enable torch matmul") + parser.add_argument("--enable-cublas", "--cublas", action="store_true", default=True, help="Enable cublas hgemm") + parser.add_argument("--disable-default", "--no-default", action="store_true", default=False, help="Disable default tests") + return parser.parse_args() + +args = get_args() +print(args) + # Load the CUDA kernel as a python module +print("Loading hgemm lib ...") lib = load(name='hgemm_lib', sources=['hgemm.cu', 'hgemm_async.cu', 'hgemm_wmma.cu', 'hgemm_wmma_stage.cu', 'hgemm_cublas.cu'], @@ -23,6 +42,7 @@ ], extra_cflags=['-std=c++17']) + MAX_TFLOPS = -1 def run_benchmark(perf_func: callable, @@ -30,7 +50,8 @@ def run_benchmark(perf_func: callable, tag: str, out: Optional[torch.Tensor] = None, stages: int = -1, swizzle: bool = False, swizzle_stride: int = 1, - warmup: int = 5, iters: int = 20, + warmup: int = args.warmup, + iters: int = args.iters, show_all: bool = False): global MAX_TFLOPS @@ -77,6 +98,7 @@ def run_benchmark(perf_func: callable, for i in range(iters): out = perf_func(a, b) torch.cuda.synchronize() + end = time.time() total_time = (end - start) * 1000 # ms mean_time = total_time / iters @@ -96,83 +118,96 @@ def run_benchmark(perf_func: callable, else: improve = 0 MAX_TFLOPS = TFLOPS - print(f"{out_info:>35}: {out_val}, time:{mean_time}ms, " + print(f"{out_info:>40}: {out_val}, time:{mean_time}ms, " f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}(+{improve:.2f}%)") else: - print(f"{out_info:>35}: {out_val}, time:{mean_time}ms, " + print(f"{out_info:>40}: {out_val}, time:{mean_time}ms, " f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}") if show_all: print(out) return out, mean_time -def get_args(): - parser = argparse.ArgumentParser(description="hgemm benchmark") - parser.add_argument("--enable-mma-all", "-ma", action="store_true") - parser.add_argument("--enable-wmma-all", "-wa", action="store_true") - parser.add_argument("--enable-cuda-all", "-ca", action="store_true") - return parser.parse_args() - - -args = get_args() Ms = [4096, 8192, 16384] Ns = [4096, 8192, 16384] Ks = [2048, 4096, 8192] -MAX_M, MAX_N, MAX_K = 16384, 16384, 8192 +if args.M and args.N and args.K: + Ms = [args.M] + Ns = [args.N] + Ks = [args.K] +MAX_M, MAX_N, MAX_K = max(Ms), max(Ns), max(Ks) # pre allocate for fast profiling. +torch.cuda.synchronize() +start = time.time() +print(f"pre allocate for fast profiling start, MAX_M={MAX_M}, MAX_N={MAX_N}, MAX_K={MAX_K}") A = torch.randn((MAX_M, MAX_K), dtype=torch.half).cuda() B = torch.randn((MAX_K, MAX_N), dtype=torch.half).cuda() C = torch.randn((MAX_M, MAX_N), dtype=torch.half).cuda() torch.cuda.synchronize() - +end = time.time() +print(f"pre allocate for fast profiling done, time: {(end - start) * 1000} ms") MNKs = [(M, N, K) for M in Ms for N in Ns for K in Ks] + +PERF_COUNT = 0 for (M, N, K) in MNKs: MAX_TFLOPS = -1 + PERF_COUNT += 1 + print("-" * 130) + print(" " * 30 + f"M={M}, N={N}, K={K}, Warmup={args.warmup}, Iters={args.iters}, {PERF_COUNT}/{len(MNKs)}") print("-" * 130) - print(" " * 55 + f"M={M}, N={N}, K={K}") a = A[:M, :K].contiguous() b = B[:K, :N].contiguous() c = C[:M, :N].contiguous() torch.cuda.synchronize() - - if args.enable_cuda_all: + if args.enable_cuda_all: # more cuda cores kernel tests. # CUDA Cores FP16 - run_benchmark(lib.hgemm_naive_f16, a, b, "f16(naive)", c) - run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "f16x8pack(t8x8+bcf)", c) - - run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf, a, b, "f16x8pack(t8x8+dbuf)", c) - run_benchmark(lib.hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf, a, b, "f16x8pack(t8x8+k16+dbuf)", c) - - print("-" * 68 + "WMMA" + "-" * 58) - # wmma api, stages, dsmem, swizzle - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2, a, b, "(mma4x2)", c) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4, a, b, "(mma4x2+warp2x4)", c) - - # prefer on NVIDIA L20 device. - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+warp2x4+stage3)", c, stages=3) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+warp2x4+stage2)", c, stages=2) - - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma2x4+...+stage3+dsmem)", c, stages=3) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma2x4+...+stage2+dsmem)", c, stages=2) - - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+...+stage3+swizzle)", c, stages=3, swizzle=True) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+...+stage2+swizzle)", c, stages=2, swizzle=True) - - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(...+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(...+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) - - if args.enable_wmma_all: + run_benchmark(lib.hgemm_naive_f16, a, b, "(naive)", c) + run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "(f16x8pack+t8x8+bcf)", c) + if not args.disable_default: + run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf, a, b, "(f16x8pack+t8x8+dbuf)", c) + run_benchmark(lib.hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf, a, b, "(f16x8pack+t8x8+k16+dbuf)", c) + print("-" * 68 + "WMMA" + "-" * 58) + # wmma api, stages, dsmem, swizzle + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2, a, b, "(mma4x2)", c) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4, a, b, "(mma4x2+warp2x4)", c) + # prefer on NVIDIA L20 device. + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma4x2+warp2x4+stage3)", c, stages=3) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma4x2+warp2x4+stage2)", c, stages=2) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma4x2+warp2x4+stage3+dsmem)", c, stages=3) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma4x2+warp2x4+stage2+dsmem)", c, stages=2) + # swizzle + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma4x2+warp2x4+stage3+swizzle)", c, stages=3, swizzle=True) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma4x2+warp2x4+stage2+swizzle)", c, stages=2, swizzle=True) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma4x2+warp2x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma4x2+warp2x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) + # TODO: add MMA PTX kernel tests. + if args.enable_wmma_all: # more wmma kernel tests. + # TODO: add more stages tests for mma2x4/mma4x4, 4,5 etc. # prefer on NVIDIA TRX 3080 Laptop 16GB GDDR6 device. - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+...+stage3+dsmem)", c, stages=3) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+...+stage2+dsmem)", c, stages=2) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+...+stage3+dsmem)", c, stages=3) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+...+stage2+dsmem)", c, stages=2) - - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) - - run_benchmark(lib.hgemm_cublas_tensor_op, a, b, "f16(cublas)", c) - run_benchmark(partial(torch.matmul, out=c), a, b, "f16_th") + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+warp4x4+stage3+dsmem)", c, stages=3) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+warp4x4+stage2+dsmem)", c, stages=2) + # may not get good performance for warp_tile_k, e.g. warp2x2x2, warp2x4x2 etc. + # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem, a, b, "(mma4x2+warp4x4+stage3+dsmem)", c, stages=3) + # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem, a, b, "(mma4x2+warp4x4+stage2+dsmem)", c, stages=2) + # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(mma4x2+warp2x4x2+stage3+dsmem)", c, stages=3) + # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(mma4x2+warp2x4x2+stage2+dsmem)", c, stages=2) + # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp2x2x2_stages_dsmem, a, b, "(mma4x4+warp2x2x2+stage3+dsmem)", c, stages=3) + # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp2x2x2_stages_dsmem, a, b, "(mma4x4+warp2x2x2+stage2+dsmem)", c, stages=2) + # swizzle + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) + # may not get good performance for warp_tile_k, e.g. warp2x2x2, warp2x4x2 etc. + # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem, a, b, "(mma4x2+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) + # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem, a, b, "(mma4x2+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) + # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(mma4x2+warp2x4x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) + # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(mma4x2+warp2x4x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) + # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp2x2x2_stages_dsmem, a, b, "(mma4x4+warp2x2x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) + # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp2x2x2_stages_dsmem, a, b, "(mma4x4+warp2x2x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) + if args.enable_mma_all: # more mma kernel tests. + print("-" * 68 + "MMA" + "-" * 59) + pass + if args.enable_cublas: + run_benchmark(lib.hgemm_cublas_tensor_op, a, b, "(cublas)", c) + if args.enable_torch: + run_benchmark(partial(torch.matmul, out=c), a, b, "(torch)") torch.cuda.synchronize() print("-" * 130) diff --git a/hgemm/hgemm_wmma_stage.cu b/hgemm/hgemm_wmma_stage.cu index cb635439..b583f685 100644 --- a/hgemm/hgemm_wmma_stage.cu +++ b/hgemm/hgemm_wmma_stage.cu @@ -34,7 +34,7 @@ using namespace nvcuda; HOST_DEVICE_INLINE int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); } -// stage2/3/4 (stage2=double buffers+copy async) +// stage2/3/4 (stage2=double buffers+copy async), 128x128, warp2x4(32,64,16) // 1. When using shared memory exceeds 48 KB, dynamic shared memory needs to be used, // i.e., declare a block of dynamic shared memory with extern shared half smem[];. // When calling the kernel, the size of the dynamic shared memory needs to be specified, @@ -249,7 +249,7 @@ hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_kernel( } } -// stage2/3/4 (stage2=double buffers+copy async) +// stage2/3/4 (stage2=double buffers+copy async), 128x128, warp2x4(32,64,16) // 1. When using shared memory exceeds 48 KB, dynamic shared memory needs to be used, // i.e., declare a block of dynamic shared memory with extern shared half smem[];. // When calling the kernel, the size of the dynamic shared memory needs to be specified, @@ -480,7 +480,7 @@ hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem_kernel( } } -// stage with 256x256 block, dynamic smem +// stage with 256x256 block, warp4x4(64,64,16), dynamic smem // __launch_bounds__: avoid error 'too many resources required for launch' // reference: https://blog.csdn.net/feng__shuai/article/details/124395023 template +__global__ void __launch_bounds__(512) +hgemm_wmma_m16n16k16_mma4x4_warp2x2x2_stages_dsmem_kernel( + half* A, half* B, half* C, int M, int N, int K) { + // 512 threads(16 warps) per block. + // const int bx = blockIdx.x; + // BLOCK_SWIZZLE 0/1 control use block swizzle or not. + const int bx = ((int) BLOCK_SWIZZLE) * blockIdx.z * gridDim.x + blockIdx.x; + const int by = blockIdx.y; + const int NUM_K_TILES = div_ceil(K, WMMA_K * WARP_TILE_K); + constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; // 16x4*2=128 + constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; // 16x4*2=128 + constexpr int BK = WMMA_K * WARP_TILE_K; // 16*2=32 + // s2: 2*128*(32)*2=16KB, 2*32*(128+16)*2=18KB, ~42KB + // s3: 3*128*(32)*2=24KB, 3*32*(128+16)*2=27KB, ~51KB + // s4: 4*128*(32)*2=32KB, 4*32*(128+16)*2=36KB, ~68KB + // s5: 5*128*(32)*2=40KB, 5*32*(128+16)*2=45KB, ~85KB + extern __shared__ half smem[]; + half* s_a = smem; + half* s_b = smem + K_STAGE * BM * (BK + A_PAD); + constexpr int s_a_stage_offset = BM * (BK + A_PAD); + constexpr int s_b_stage_offset = BK * (BN + B_PAD); -// --------------------- PyTorch bindings for custom kernel ----------------------- -#define STRINGFY(str) #str -#define TORCH_BINDING_COMMON_EXTENSION(func) \ - m.def(STRINGFY(func), &func, STRINGFY(func)); + // 要保证相同的warp下thread执行相同的指令 + const int tid = threadIdx.y * blockDim.x + threadIdx.x; + const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block + const int warp_m = warp_id / 4; // 0,1,2,3 + const int warp_n = warp_id % 4; // 0,1,2,3 + + // 先计算shared memory中的索引 + // tid和需要加载的smem s_a[BM][BK] 之间的索引关系 BM=128 BK=32 按行读取 A行主序 + // 对于s_a每行32个数据,每个线程读取8个,需要4个线程;总共128行,需要128x4刚好512线程 + int load_smem_a_m = tid / 4; // row 0~127 + int load_smem_a_k = (tid % 4) * 8; // col 0,8,16,24 + // tid和需要加载的smem s_b[BK][BN] 之间的索引关系 BK=32 BN=128 按行读取 B行主序 + // 对于s_b每行128个数据,每个线程读8个数据,需要16个线程;总共32行,需要32x16=256个线程 + int load_smem_b_k = tid / 16; // row 0~31 + int load_smem_b_n = (tid % 16) * 8; // col 0,8,...,120 + // 再计算全局内存中的索引 + // 要加载到s_a中的元素对应到A全局内存中的行数 每个block负责出C中大小为BM*BN的块 + int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c + int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c -#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ -if(((T).options().dtype() != (th_type))) { \ - std::cout << "Tensor Info:" << (T).options() << std::endl; \ - throw std::runtime_error("values must be "#th_type); \ -} + wmma::fragment + C_frag[WARP_TILE_M][WARP_TILE_N]; + + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + wmma::fill_fragment(C_frag[i][j], 0.0); + } + } -#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \ -if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ - throw std::runtime_error("Tensor size mismatch!"); \ -} + // only cvta smem base ptr once for cp.async. + uint32_t smem_a_base_ptr = __cvta_generic_to_shared(s_a); + uint32_t smem_b_base_ptr = __cvta_generic_to_shared(s_b); -// 128x128 w/o dynamic smem -#define LAUNCH_161616_STAGE_SWIZZLE_KERNEL(stages, stride) \ -{ \ - const int N_SWIZZLE = (N + (stride) - 1) / (stride); \ - dim3 block(NUM_THREADS); \ - dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE, \ - div_ceil(M, BM), \ - N_SWIZZLE); \ - hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_kernel< \ - WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ - WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, \ - (stages), true><<>>( \ - reinterpret_cast(a.data_ptr()), \ - reinterpret_cast(b.data_ptr()), \ - reinterpret_cast(c.data_ptr()), \ - M, N, K \ - ); \ -} + #pragma unroll + for (int k = 0; k < (K_STAGE - 1); ++k) { // 0, 1 + // k * WMMA_K, WMMA_K=16 -> (k << 4) + int load_gmem_a_k = k * (WMMA_K * WARP_TILE_K) + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + int load_gmem_b_k = k * (WMMA_K * WARP_TILE_K) + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + + uint32_t load_smem_a_ptr = ( + smem_a_base_ptr + (k * s_a_stage_offset + + load_smem_a_m * (BK + A_PAD) + + load_smem_a_k) * sizeof(half) + ); -#define LAUNCH_161616_STAGE_NO_SWIZZLE_KERNEL(stages) \ -{ \ - dim3 block(NUM_THREADS); \ - dim3 grid(div_ceil(N, BN), div_ceil(M, BM)); \ - hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_kernel< \ - WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ - WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, \ - (stages), false><<>>( \ - reinterpret_cast(a.data_ptr()), \ - reinterpret_cast(b.data_ptr()), \ - reinterpret_cast(c.data_ptr()), \ - M, N, K \ - ); \ -} + uint32_t load_smem_b_ptr = ( + smem_b_base_ptr + (k * s_b_stage_offset + + load_smem_b_k * (BN + B_PAD) + + load_smem_b_n) * sizeof(half) + ); -// 128x128 stage 2/3/4 w/o block swizzle across N dim, static smem < 48KB -void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages( - torch::Tensor a, torch::Tensor b, torch::Tensor c, - int stages, bool swizzle, int swizzle_stride) { - CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) - CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) - CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) - const int M = a.size(0); - const int K = a.size(1); - const int N = b.size(1); - CHECK_TORCH_TENSOR_SHAPE(a, M, K) - CHECK_TORCH_TENSOR_SHAPE(b, K, N) - CHECK_TORCH_TENSOR_SHAPE(c, M, N) - constexpr int WMMA_M = 16; - constexpr int WMMA_N = 16; - constexpr int WMMA_K = 16; - constexpr int WMMA_TILE_M = 4; - constexpr int WMMA_TILE_N = 2; - constexpr int WARP_TILE_M = 2; - constexpr int WARP_TILE_N = 4; - // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts. - // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts. - // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts. - // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus, - // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD. - constexpr int A_PAD = 0; // 0,8,16 - constexpr int B_PAD = 16; // 0,8,16 - constexpr int NUM_THREADS= ( - WMMA_TILE_M * WMMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 - constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; - constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; - constexpr int BK = WMMA_K; - // s2: 2*128*(16)*2=8KB, 2*16*(128+16)*2=9KB, ~17KB - // s3: 3*128*(16)*2=12KB, 3*16*(128+16)*2=13.5KB, ~26KB - // s4: 4*128*(16)*2=16KB, 4*16*(128+16)*2=18KB, ~34KB - // s5: 5*128*(16)*2=20KB, 5*16*(128+16)*2=22.5KB, ~43KB - if (swizzle) { - assert(swizzle_stride % 256 == 0); - switch (stages) - { - case 2: // ~17KB - LAUNCH_161616_STAGE_SWIZZLE_KERNEL(2, swizzle_stride); - break; - case 3: // ~26KB - LAUNCH_161616_STAGE_SWIZZLE_KERNEL(3, swizzle_stride); - break; - case 4: // ~34KB - LAUNCH_161616_STAGE_SWIZZLE_KERNEL(4, swizzle_stride); - break; - case 5: // ~43KB - LAUNCH_161616_STAGE_SWIZZLE_KERNEL(5, swizzle_stride); - break; - default: - LAUNCH_161616_STAGE_SWIZZLE_KERNEL(2, swizzle_stride); - break; - } - } else { - switch (stages) - { - case 2: - LAUNCH_161616_STAGE_NO_SWIZZLE_KERNEL(2); - break; - case 3: - LAUNCH_161616_STAGE_NO_SWIZZLE_KERNEL(3); - break; - case 4: - LAUNCH_161616_STAGE_NO_SWIZZLE_KERNEL(4); - break; - default: - LAUNCH_161616_STAGE_NO_SWIZZLE_KERNEL(2); - break; - } + CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); + CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); + + CP_ASYNC_COMMIT_GROUP(); } -} -// 128x128 w dynamic smem, 98304=96KB < Ampere, Ada, Hopper ... -#define LAUNCH_161616_STAGE_SWIZZLE_DSMEM_KERNEL(stages, stride) \ -{ \ - const int smem_max_size = ( \ - (stages) * BM * (BK + A_PAD) * sizeof(half) + \ - (stages) * BK * (BN + B_PAD) * sizeof(half)); \ - cudaFuncSetAttribute( \ - hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem_kernel< \ - WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ - WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true>, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - 98304); \ - const int N_SWIZZLE = (N + (stride) - 1) / (stride); \ - dim3 block(NUM_THREADS); \ - dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE, \ - div_ceil(M, BM), \ - N_SWIZZLE); \ - hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem_kernel< \ - WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ - WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true><<< \ - grid, block, smem_max_size>>>( \ - reinterpret_cast(a.data_ptr()), \ - reinterpret_cast(b.data_ptr()), \ - reinterpret_cast(c.data_ptr()), \ - M, N, K \ - ); \ -} + CP_ASYNC_WAIT_GROUP(K_STAGE-2); // s2->0, s3->1, s4->2 + __syncthreads(); -#define LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_KERNEL(stages) \ -{ \ - const int smem_max_size = ( \ - (stages) * BM * (BK + A_PAD) * sizeof(half) + \ - (stages) * BK * (BN + B_PAD) * sizeof(half)); \ - cudaFuncSetAttribute( \ - hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem_kernel< \ - WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ - WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), false>,\ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - 98304); \ - dim3 block(NUM_THREADS); \ - dim3 grid(div_ceil(N, BN), div_ceil(M, BM)); \ - hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem_kernel< \ - WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ - WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), false><<<\ - grid, block, smem_max_size>>>( \ - reinterpret_cast(a.data_ptr()), \ - reinterpret_cast(b.data_ptr()), \ + #pragma unroll + for (int k = (K_STAGE - 1); k < NUM_K_TILES; k++) { + // s2/4 can use bitwise ops but s3 can not, so, we use mod + // ops for all stages kernel. s2: (k + 1)&1, s4: (k + 1)&3 + // s3: (k + 1) % 3 + int smem_sel = (k + 1) % K_STAGE; // s3 k 2->0, k 3->1, k 4->2... + int smem_sel_next = k % K_STAGE; // s3 k 2->2, k 3->0, k 4->1... + + // k * WMMA_K, WMMA_K=16 -> (k << 4) + int load_gmem_a_k = k * (WMMA_K * WARP_TILE_K) + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + int load_gmem_b_k = k * (WMMA_K * WARP_TILE_K) + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + + // load stage 2, k start from 2 + uint32_t load_smem_a_ptr = ( + smem_a_base_ptr + (smem_sel_next * s_a_stage_offset + + load_smem_a_m * (BK + A_PAD) + + load_smem_a_k) * sizeof(half) + ); + + uint32_t load_smem_b_ptr = ( + smem_b_base_ptr + (smem_sel_next * s_b_stage_offset + + load_smem_b_k * (BN + B_PAD) + + load_smem_b_n) * sizeof(half) + ); + + CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); + CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); + + CP_ASYNC_COMMIT_GROUP(); + + // WARP_TILE_K=2 + for (int warp_k = 0; warp_k < WARP_TILE_K; ++warp_k) { + wmma::fragment A_frag[WARP_TILE_M]; + wmma::fragment B_frag[WARP_TILE_N]; + const int warp_smem_k = warp_k * WMMA_K; // 0,16 + + // compute stage 0 + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + // load 2 tiles -> reg, smem a -> frags a, warp_m 0~3 + int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M; + half* load_smem_a_frag_ptr = (s_a + smem_sel * s_a_stage_offset + + warp_smem_a_m * (BK + A_PAD) + + warp_smem_k); + wmma::load_matrix_sync(A_frag[i], load_smem_a_frag_ptr, BK + A_PAD); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + // load 4 tiles -> reg, smem b -> frags b, warp_n 0~2 + int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N; + half* load_smem_b_frag_ptr = (s_b + smem_sel * s_b_stage_offset + + warp_smem_k * (BN + B_PAD) + + warp_smem_b_n); + wmma::load_matrix_sync(B_frag[j], load_smem_b_frag_ptr, BN + B_PAD); + } + + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]); + } + } + } + + CP_ASYNC_WAIT_GROUP(K_STAGE-2); + __syncthreads(); + } + + // make sure all memory issues ready. + if ((K_STAGE - 2) > 0) { + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + + // processing last (K_STAGE-1) k iters. + { + #pragma unroll + for (int k = 0; k < (K_STAGE - 1); k++) { + const int stage_sel = ((NUM_K_TILES - (K_STAGE - 1) + k) % K_STAGE); + + #pragma unroll + for (int warp_k = 0; warp_k < WARP_TILE_K; ++warp_k) { + wmma::fragment A_frag[WARP_TILE_M]; + wmma::fragment B_frag[WARP_TILE_N]; + const int warp_smem_k = warp_k * WMMA_K; // 0,16 + + // compute stage 0 + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + // load 2 tiles -> reg, smem a -> frags a, warp_m 0~3 + int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M; + half* load_smem_a_frag_ptr = (s_a + stage_sel * s_a_stage_offset + + warp_smem_a_m * (BK + A_PAD) + + warp_smem_k); + wmma::load_matrix_sync(A_frag[i], load_smem_a_frag_ptr, BK + A_PAD); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + // load 4 tiles -> reg, smem b -> frags b, warp_n 0~2 + int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N; + half* load_smem_b_frag_ptr = (s_b + stage_sel * s_b_stage_offset + + warp_smem_k * (BN + B_PAD) + + warp_smem_b_n); + wmma::load_matrix_sync(B_frag[j], load_smem_b_frag_ptr, BN + B_PAD); + } + + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]); + } + } + } + } + } + + // finally, store back to C matrix. + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + const int store_gmem_a_m = by * BM + warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M; + const int store_gmem_a_n = bx * BN + warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N; + wmma::store_matrix_sync(C + store_gmem_a_m * N + store_gmem_a_n, C_frag[i][j], N, + wmma::mem_row_major); + } + } +} + +// TODO: 256x128, Stages + K32 + Reg Buffers, mma4x2, warp4x4x2(64,64,16) +template +__global__ void __launch_bounds__(256) +hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem_kernel( + half* A, half* B, half* C, int M, int N, int K) { + // 256 threads(8 warps) per block. + // const int bx = blockIdx.x; + // BLOCK_SWIZZLE 0/1 control use block swizzle or not. + const int bx = ((int) BLOCK_SWIZZLE) * blockIdx.z * gridDim.x + blockIdx.x; + const int by = blockIdx.y; + const int NUM_K_TILES = div_ceil(K, WMMA_K * WARP_TILE_K); + constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; // 16x4*4=256 + constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; // 16x2*4=128 + constexpr int BK = WMMA_K * WARP_TILE_K; // 16*2=32 + // s2: 2*128*(32)*2=16KB, 2*32*(128+16)*2=18KB, ~42KB + // s3: 3*128*(32)*2=24KB, 3*32*(128+16)*2=27KB, ~51KB + // s4: 4*128*(32)*2=32KB, 4*32*(128+16)*2=36KB, ~68KB + // s4: 5*128*(32)*2=40KB, 5*32*(128+16)*2=45KB, ~85KB + extern __shared__ half smem[]; + half* s_a = smem; + half* s_b = smem + K_STAGE * BM * (BK + A_PAD); + constexpr int s_a_stage_offset = BM * (BK + A_PAD); + constexpr int s_b_stage_offset = BK * (BN + B_PAD); + + // 要保证相同的warp下thread执行相同的指令 + const int tid = threadIdx.y * blockDim.x + threadIdx.x; + const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block + const int warp_m = warp_id / 2; // 0,1,2,3 + const int warp_n = warp_id % 2; // 0,1 + + // 先计算shared memory中的索引 + // tid和需要加载的smem s_a[BM][BK] 之间的索引关系 BM=256 BK=32 按行读取 A行主序 + // 对于s_a每行16个数据,每个线程读取16个,需要1个线程;总共256行,刚好256线程 + int load_smem_a_m = tid; // row 0~255 + int load_smem_a_k = 0; // col 0,16 + // tid和需要加载的smem s_b[BK][BN] 之间的索引关系 BK=16 BN=128 按行读取 B行主序 + // 对于s_b每行128个数据,每个线程读8个数据,需要16个线程;总共16行,需要16x16=256个线程 + int load_smem_b_k = tid / 16; // row 0~15 + int load_smem_b_n = (tid % 16) * 8; // col 0,8,...,120 + // 再计算全局内存中的索引 + // 要加载到s_a中的元素对应到A全局内存中的行数 每个block负责出C中大小为BM*BN的块 + int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c + int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c + + wmma::fragment + C_frag[WARP_TILE_M][WARP_TILE_N]; + + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + wmma::fill_fragment(C_frag[i][j], 0.0); + } + } + + // only cvta smem base ptr once for cp.async. + uint32_t smem_a_base_ptr = __cvta_generic_to_shared(s_a); + uint32_t smem_b_base_ptr = __cvta_generic_to_shared(s_b); + + #pragma unroll + for (int k = 0; k < (K_STAGE - 1); ++k) { // 0, 1 + // k * WMMA_K, WMMA_K=16 -> (k << 4) + int load_gmem_a_k = k * (WMMA_K * WARP_TILE_K) + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + int load_gmem_b_k = k * (WMMA_K * WARP_TILE_K) + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + + uint32_t load_smem_a_ptr = ( + smem_a_base_ptr + (k * s_a_stage_offset + + load_smem_a_m * (BK + A_PAD) + + load_smem_a_k) * sizeof(half) + ); + + uint32_t load_smem_b_ptr = ( + smem_b_base_ptr + (k * s_b_stage_offset + + load_smem_b_k * (BN + B_PAD) + + load_smem_b_n) * sizeof(half) + ); + + CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); + CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); + CP_ASYNC_CG(load_smem_a_ptr + 16, &A[load_gmem_a_addr + 8], 16); + + CP_ASYNC_COMMIT_GROUP(); + } + + CP_ASYNC_WAIT_GROUP(K_STAGE-2); // s2->0, s3->1, s4->2 + __syncthreads(); + + #pragma unroll + for (int k = (K_STAGE - 1); k < NUM_K_TILES; k++) { + // s2/4 can use bitwise ops but s3 can not, so, we use mod + // ops for all stages kernel. s2: (k + 1)&1, s4: (k + 1)&3 + // s3: (k + 1) % 3 + int smem_sel = (k + 1) % K_STAGE; // s3 k 2->0, k 3->1, k 4->2... + int smem_sel_next = k % K_STAGE; // s3 k 2->2, k 3->0, k 4->1... + + // k * WMMA_K, WMMA_K=16 -> (k << 4) + int load_gmem_a_k = k * (WMMA_K * WARP_TILE_K) + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + int load_gmem_b_k = k * (WMMA_K * WARP_TILE_K) + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + + // load stage 2, k start from 2 + uint32_t load_smem_a_ptr = ( + smem_a_base_ptr + (smem_sel_next * s_a_stage_offset + + load_smem_a_m * (BK + A_PAD) + + load_smem_a_k) * sizeof(half) + ); + + uint32_t load_smem_b_ptr = ( + smem_b_base_ptr + (smem_sel_next * s_b_stage_offset + + load_smem_b_k * (BN + B_PAD) + + load_smem_b_n) * sizeof(half) + ); + + CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); + CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); + CP_ASYNC_CG(load_smem_a_ptr + 16, &A[load_gmem_a_addr + 8], 16); + + CP_ASYNC_COMMIT_GROUP(); + + for (int warp_k = 0; warp_k < WARP_TILE_K; ++warp_k) { + wmma::fragment A_frag[WARP_TILE_M]; + wmma::fragment B_frag[WARP_TILE_N]; + const int warp_smem_k = warp_k * WMMA_K; // 0,16 + + // compute stage 0 + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + // load 2 tiles -> reg, smem a -> frags a, warp_m 0~3 + int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M; + half* load_smem_a_frag_ptr = (s_a + smem_sel * s_a_stage_offset + + warp_smem_a_m * (BK + A_PAD) + + warp_smem_k); + wmma::load_matrix_sync(A_frag[i], load_smem_a_frag_ptr, BK + A_PAD); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + // load 4 tiles -> reg, smem b -> frags b, warp_n 0~2 + int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N; + half* load_smem_b_frag_ptr = (s_b + smem_sel * s_b_stage_offset + + warp_smem_k * (BN + B_PAD) + + warp_smem_b_n); + wmma::load_matrix_sync(B_frag[j], load_smem_b_frag_ptr, BN + B_PAD); + } + + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]); + } + } + } + + CP_ASYNC_WAIT_GROUP(K_STAGE-2); + __syncthreads(); + } + + // make sure all memory issues ready. + if ((K_STAGE - 2) > 0) { + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + + // processing last (K_STAGE-1) k iters. + { + #pragma unroll + for (int k = 0; k < (K_STAGE - 1); k++) { + const int stage_sel = ((NUM_K_TILES - (K_STAGE - 1) + k) % K_STAGE); + + #pragma unroll + for (int warp_k = 0; warp_k < WARP_TILE_K; ++warp_k) { + wmma::fragment A_frag[WARP_TILE_M]; + wmma::fragment B_frag[WARP_TILE_N]; + const int warp_smem_k = warp_k * WMMA_K; // 0,16 + + // compute stage 0 + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + // load 2 tiles -> reg, smem a -> frags a, warp_m 0~3 + int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M; + half* load_smem_a_frag_ptr = (s_a + stage_sel * s_a_stage_offset + + warp_smem_a_m * (BK + A_PAD) + + warp_smem_k); + wmma::load_matrix_sync(A_frag[i], load_smem_a_frag_ptr, BK + A_PAD); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + // load 4 tiles -> reg, smem b -> frags b, warp_n 0~2 + int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N; + half* load_smem_b_frag_ptr = (s_b + stage_sel * s_b_stage_offset + + warp_smem_k * (BN + B_PAD) + + warp_smem_b_n); + wmma::load_matrix_sync(B_frag[j], load_smem_b_frag_ptr, BN + B_PAD); + } + + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]); + } + } + } + } + } + + // finally, store back to C matrix. + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + const int store_gmem_a_m = by * BM + warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M; + const int store_gmem_a_n = bx * BN + warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N; + wmma::store_matrix_sync(C + store_gmem_a_m * N + store_gmem_a_n, C_frag[i][j], N, + wmma::mem_row_major); + } + } +} + +// TODO: Warp swizzle/permute support ? (MMA, not WMMA) + +// --------------------- PyTorch bindings for custom kernel ----------------------- +#define STRINGFY(str) #str +#define TORCH_BINDING_COMMON_EXTENSION(func) \ + m.def(STRINGFY(func), &func, STRINGFY(func)); + +#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ +if(((T).options().dtype() != (th_type))) { \ + std::cout << "Tensor Info:" << (T).options() << std::endl; \ + throw std::runtime_error("values must be "#th_type); \ +} + +#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \ +if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ + throw std::runtime_error("Tensor size mismatch!"); \ +} + +// 128x128 warp2x4(32,64) w/o dynamic smem +#define LAUNCH_161616_STAGE_SWIZZLE_KERNEL(stages, stride) \ +{ \ + const int N_SWIZZLE = (N + (stride) - 1) / (stride); \ + dim3 block(NUM_THREADS); \ + dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE, \ + div_ceil(M, BM), \ + N_SWIZZLE); \ + hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, \ + (stages), true><<>>( \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(b.data_ptr()), \ + reinterpret_cast(c.data_ptr()), \ + M, N, K \ + ); \ +} + +#define LAUNCH_161616_STAGE_NO_SWIZZLE_KERNEL(stages) \ +{ \ + dim3 block(NUM_THREADS); \ + dim3 grid(div_ceil(N, BN), div_ceil(M, BM)); \ + hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, \ + (stages), false><<>>( \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(b.data_ptr()), \ + reinterpret_cast(c.data_ptr()), \ + M, N, K \ + ); \ +} + +// 128x128 warp2x4(32,64) stage 2/3/4 w/o block swizzle across N dim, static smem < 48KB +void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages( + torch::Tensor a, torch::Tensor b, torch::Tensor c, + int stages, bool swizzle, int swizzle_stride) { + CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) + const int M = a.size(0); + const int K = a.size(1); + const int N = b.size(1); + CHECK_TORCH_TENSOR_SHAPE(a, M, K) + CHECK_TORCH_TENSOR_SHAPE(b, K, N) + CHECK_TORCH_TENSOR_SHAPE(c, M, N) + constexpr int WMMA_M = 16; + constexpr int WMMA_N = 16; + constexpr int WMMA_K = 16; + constexpr int WMMA_TILE_M = 4; + constexpr int WMMA_TILE_N = 2; + constexpr int WARP_TILE_M = 2; + constexpr int WARP_TILE_N = 4; + // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts. + // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts. + // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts. + // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus, + // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD. + constexpr int A_PAD = 0; // 0,8,16 + constexpr int B_PAD = 16; // 0,8,16 + constexpr int NUM_THREADS= ( + WMMA_TILE_M * WMMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 + constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; + constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; + constexpr int BK = WMMA_K; + // s2: 2*128*(16)*2=8KB, 2*16*(128+16)*2=9KB, ~17KB + // s3: 3*128*(16)*2=12KB, 3*16*(128+16)*2=13.5KB, ~26KB + // s4: 4*128*(16)*2=16KB, 4*16*(128+16)*2=18KB, ~34KB + // s5: 5*128*(16)*2=20KB, 5*16*(128+16)*2=22.5KB, ~43KB + if (swizzle) { + assert(swizzle_stride % 256 == 0); + switch (stages) + { + case 2: // ~17KB + LAUNCH_161616_STAGE_SWIZZLE_KERNEL(2, swizzle_stride); + break; + case 3: // ~26KB + LAUNCH_161616_STAGE_SWIZZLE_KERNEL(3, swizzle_stride); + break; + case 4: // ~34KB + LAUNCH_161616_STAGE_SWIZZLE_KERNEL(4, swizzle_stride); + break; + case 5: // ~43KB + LAUNCH_161616_STAGE_SWIZZLE_KERNEL(5, swizzle_stride); + break; + default: + LAUNCH_161616_STAGE_SWIZZLE_KERNEL(2, swizzle_stride); + break; + } + } else { + switch (stages) + { + case 2: + LAUNCH_161616_STAGE_NO_SWIZZLE_KERNEL(2); + break; + case 3: + LAUNCH_161616_STAGE_NO_SWIZZLE_KERNEL(3); + break; + case 4: + LAUNCH_161616_STAGE_NO_SWIZZLE_KERNEL(4); + break; + default: + LAUNCH_161616_STAGE_NO_SWIZZLE_KERNEL(2); + break; + } + } +} + +// 128x128 warp2x4(32,64) w dynamic smem, 98304=96KB < Ampere, Ada, Hopper ... +#define LAUNCH_161616_STAGE_SWIZZLE_DSMEM_KERNEL(stages, stride) \ +{ \ + const int smem_max_size = ( \ + (stages) * BM * (BK + A_PAD) * sizeof(half) + \ + (stages) * BK * (BN + B_PAD) * sizeof(half)); \ + cudaFuncSetAttribute( \ + hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true>, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + 98304); \ + const int N_SWIZZLE = (N + (stride) - 1) / (stride); \ + dim3 block(NUM_THREADS); \ + dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE, \ + div_ceil(M, BM), \ + N_SWIZZLE); \ + hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true><<< \ + grid, block, smem_max_size>>>( \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(b.data_ptr()), \ + reinterpret_cast(c.data_ptr()), \ + M, N, K \ + ); \ +} + +#define LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_KERNEL(stages) \ +{ \ + const int smem_max_size = ( \ + (stages) * BM * (BK + A_PAD) * sizeof(half) + \ + (stages) * BK * (BN + B_PAD) * sizeof(half)); \ + cudaFuncSetAttribute( \ + hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), false>,\ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + 98304); \ + dim3 block(NUM_THREADS); \ + dim3 grid(div_ceil(N, BN), div_ceil(M, BM)); \ + hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), false><<<\ + grid, block, smem_max_size>>>( \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(b.data_ptr()), \ reinterpret_cast(c.data_ptr()), \ M, N, K \ ); \ } -// 128x128 stage 2/3/4 + dynamic smem, w/o block swizzle across N dim +// 128x128 warp2x4(32,64) stage 2/3/4 + dynamic smem, w/o block swizzle across N dim void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem( torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride) { @@ -1220,7 +1701,7 @@ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem( } } -// 256x256 w dynamic smem, 98304=96KB < Ampere, Ada, Hopper ... +// 256x256 warp4x4(64,64,32) w dynamic smem, 98304=96KB < Ampere, Ada, Hopper ... #define LAUNCH_161616_STAGE_SWIZZLE_DSMEM_256x256_KERNEL(stages, stride) \ { \ const int smem_max_size = ( \ @@ -1272,7 +1753,7 @@ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem( ); \ } -// 256x256 stage 2/3/4 + dynamic smem, w/o block swizzle across N dim +// 256x256 warp4x4(64,64,32) stage 2/3/4 + dynamic smem, w/o block swizzle across N dim void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem( torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride) { @@ -1343,7 +1824,7 @@ void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem( } } -// 128x128 warp2x4x2 w dynamic smem, 98304=96KB < Ampere, Ada, Hopper ... +// 128x128 warp2x4x2(32,64,32) w dynamic smem, 98304=96KB < Ampere, Ada, Hopper ... #define LAUNCH_161616_STAGE_SWIZZLE_DSMEM_K32_KERNEL(stages, stride)\ { \ const int smem_max_size = ( \ @@ -1399,6 +1880,7 @@ void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem( ); \ } +// 128x128 warp2x4x2(32,64,32) void hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem( torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride) { @@ -1473,3 +1955,265 @@ void hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem( } } } + +// 128x128 warp2x2x2(32,32,32) w dynamic smem, 98304=96KB < Ampere, Ada, Hopper ... +#define LAUNCH_161616_STAGE_SWIZZLE_DSMEM_K32_MMA4x4_KERNEL(stages, stride)\ +{ \ + const int smem_max_size = ( \ + (stages) * BM * (BK + A_PAD) * sizeof(half) + \ + (stages) * BK * (BN + B_PAD) * sizeof(half)); \ + cudaFuncSetAttribute( \ + hgemm_wmma_m16n16k16_mma4x4_warp2x2x2_stages_dsmem_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, \ + A_PAD, B_PAD, (stages), true>, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + 98304); \ + const int N_SWIZZLE = (N + (stride) - 1) / (stride); \ + dim3 block(NUM_THREADS); \ + dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE, \ + div_ceil(M, BM), \ + N_SWIZZLE); \ + hgemm_wmma_m16n16k16_mma4x4_warp2x2x2_stages_dsmem_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, \ + A_PAD, B_PAD, (stages), true><<< \ + grid, block, smem_max_size>>>( \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(b.data_ptr()), \ + reinterpret_cast(c.data_ptr()), \ + M, N, K \ + ); \ +} + +#define LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_K32_MMA4x4_KERNEL(stages) \ +{ \ + const int smem_max_size = ( \ + (stages) * BM * (BK + A_PAD) * sizeof(half) + \ + (stages) * BK * (BN + B_PAD) * sizeof(half)); \ + cudaFuncSetAttribute( \ + hgemm_wmma_m16n16k16_mma4x4_warp2x2x2_stages_dsmem_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, \ + A_PAD, B_PAD, (stages), false>, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + 98304); \ + dim3 block(NUM_THREADS); \ + dim3 grid(div_ceil(N, BN), div_ceil(M, BM)); \ + hgemm_wmma_m16n16k16_mma4x4_warp2x2x2_stages_dsmem_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, \ + A_PAD, B_PAD, (stages), false><<< \ + grid, block, smem_max_size>>>( \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(b.data_ptr()), \ + reinterpret_cast(c.data_ptr()), \ + M, N, K \ + ); \ +} + +void hgemm_wmma_m16n16k16_mma4x4_warp2x2x2_stages_dsmem( + torch::Tensor a, torch::Tensor b, torch::Tensor c, + int stages, bool swizzle, int swizzle_stride) { + CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) + const int M = a.size(0); + const int K = a.size(1); + const int N = b.size(1); + CHECK_TORCH_TENSOR_SHAPE(a, M, K) + CHECK_TORCH_TENSOR_SHAPE(b, K, N) + CHECK_TORCH_TENSOR_SHAPE(c, M, N) + constexpr int WMMA_M = 16; + constexpr int WMMA_N = 16; + constexpr int WMMA_K = 16; + constexpr int WMMA_TILE_M = 4; + constexpr int WMMA_TILE_N = 4; + constexpr int WARP_TILE_M = 2; + constexpr int WARP_TILE_N = 2; + constexpr int WARP_TILE_K = 2; + // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts. + // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts. + // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts. + // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus, + // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD. + constexpr int A_PAD = 0; // 0,8,16 + constexpr int B_PAD = 16; // 0,8,16 + constexpr int NUM_THREADS= ( + WMMA_TILE_M * WMMA_TILE_N * WARP_SIZE); // 4 * 4 * 32 = 512 + constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; + constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; + constexpr int BK = WMMA_K * WARP_TILE_K; + + if (swizzle) { + assert(swizzle_stride % 256 == 0); + switch (stages) + { + case 2: + LAUNCH_161616_STAGE_SWIZZLE_DSMEM_K32_MMA4x4_KERNEL(2, swizzle_stride); + break; + case 3: + LAUNCH_161616_STAGE_SWIZZLE_DSMEM_K32_MMA4x4_KERNEL(3, swizzle_stride); + break; + case 4: + LAUNCH_161616_STAGE_SWIZZLE_DSMEM_K32_MMA4x4_KERNEL(4, swizzle_stride); + break; + case 5: + LAUNCH_161616_STAGE_SWIZZLE_DSMEM_K32_MMA4x4_KERNEL(5, swizzle_stride); + break; + default: + LAUNCH_161616_STAGE_SWIZZLE_DSMEM_K32_MMA4x4_KERNEL(2, swizzle_stride); + break; + } + } else { + switch (stages) + { + case 2: + LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_K32_MMA4x4_KERNEL(2); + break; + case 3: + LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_K32_MMA4x4_KERNEL(3); + break; + case 4: + LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_K32_MMA4x4_KERNEL(4); + break; + case 5: + LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_K32_MMA4x4_KERNEL(5); + break; + default: + LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_K32_MMA4x4_KERNEL(2); + break; + } + } +} + +// 256x128 warp4x4(64,64,16) w dynamic smem, 98304=96KB < Ampere, Ada, Hopper ... +#define LAUNCH_161616_STAGE_SWIZZLE_DSMEM_WARP4X4_KERNEL(stages, stride) \ +{ \ + const int smem_max_size = ( \ + (stages) * BM * (BK + A_PAD) * sizeof(half) + \ + (stages) * BK * (BN + B_PAD) * sizeof(half)); \ + cudaFuncSetAttribute( \ + hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, \ + A_PAD, B_PAD, (stages), true>, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + 98304); \ + const int N_SWIZZLE = (N + (stride) - 1) / (stride); \ + dim3 block(NUM_THREADS); \ + dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE, \ + div_ceil(M, BM), \ + N_SWIZZLE); \ + hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, \ + A_PAD, B_PAD, (stages), true><<< \ + grid, block, smem_max_size>>>( \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(b.data_ptr()), \ + reinterpret_cast(c.data_ptr()), \ + M, N, K \ + ); \ +} + +#define LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_WARP4x4_KERNEL(stages) \ +{ \ + const int smem_max_size = ( \ + (stages) * BM * (BK + A_PAD) * sizeof(half) + \ + (stages) * BK * (BN + B_PAD) * sizeof(half)); \ + cudaFuncSetAttribute( \ + hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, \ + A_PAD, B_PAD, (stages), false>, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + 98304); \ + dim3 block(NUM_THREADS); \ + dim3 grid(div_ceil(N, BN), div_ceil(M, BM)); \ + hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, \ + A_PAD, B_PAD, (stages), false><<< \ + grid, block, smem_max_size>>>( \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(b.data_ptr()), \ + reinterpret_cast(c.data_ptr()), \ + M, N, K \ + ); \ +} + +void hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem( + torch::Tensor a, torch::Tensor b, torch::Tensor c, + int stages, bool swizzle, int swizzle_stride) { + CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) + const int M = a.size(0); + const int K = a.size(1); + const int N = b.size(1); + CHECK_TORCH_TENSOR_SHAPE(a, M, K) + CHECK_TORCH_TENSOR_SHAPE(b, K, N) + CHECK_TORCH_TENSOR_SHAPE(c, M, N) + constexpr int WMMA_M = 16; + constexpr int WMMA_N = 16; + constexpr int WMMA_K = 16; + constexpr int WMMA_TILE_M = 4; + constexpr int WMMA_TILE_N = 2; + constexpr int WARP_TILE_M = 4; + constexpr int WARP_TILE_N = 4; + constexpr int WARP_TILE_K = 1; + // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts. + // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts. + // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts. + // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus, + // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD. + constexpr int A_PAD = 0; // 0,8,16 + constexpr int B_PAD = 16; // 0,8,16 + constexpr int NUM_THREADS= ( + WMMA_TILE_M * WMMA_TILE_N * WARP_SIZE); // 4 * 2 * 32 = 256 + constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; + constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; + constexpr int BK = WMMA_K * WARP_TILE_K; + + if (swizzle) { + assert(swizzle_stride % 256 == 0); + switch (stages) + { + case 2: + LAUNCH_161616_STAGE_SWIZZLE_DSMEM_WARP4X4_KERNEL(2, swizzle_stride); + break; + case 3: + LAUNCH_161616_STAGE_SWIZZLE_DSMEM_WARP4X4_KERNEL(3, swizzle_stride); + break; + case 4: + LAUNCH_161616_STAGE_SWIZZLE_DSMEM_WARP4X4_KERNEL(4, swizzle_stride); + break; + case 5: + LAUNCH_161616_STAGE_SWIZZLE_DSMEM_WARP4X4_KERNEL(5, swizzle_stride); + break; + default: + LAUNCH_161616_STAGE_SWIZZLE_DSMEM_WARP4X4_KERNEL(2, swizzle_stride); + break; + } + } else { + switch (stages) + { + case 2: + LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_WARP4x4_KERNEL(2); + break; + case 3: + LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_WARP4x4_KERNEL(3); + break; + case 4: + LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_WARP4x4_KERNEL(4); + break; + case 5: + LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_WARP4x4_KERNEL(5); + break; + default: + LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_WARP4x4_KERNEL(2); + break; + } + } +} \ No newline at end of file