From 523a610574c4464d508b81a3338453b25cc0dda7 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Sat, 19 Oct 2024 10:57:38 +0800 Subject: [PATCH] [Docs] rename mat_transpose -> mat-transpose (#93) * Update sgemm_wmma_tf32_stage.cu * Update sgemm.py * Update README.md * Update sgemm_wmma_tf32_stage.cu * Update hgemm_wmma_stage.cu * Update hgemm.cu * Update hgemm.py * Update hgemm.py * rename mat_transpose->mat-transpose * update hgemm benchmark * update hgemm benchmark --- README.md | 10 +- hgemm/hgemm.cu | 3 + hgemm/hgemm.py | 78 +- hgemm/hgemm_wmma_stage.cu | 399 ++++++++- {mat_transpose => mat-transpose}/.gitignore | 0 {mat_transpose => mat-transpose}/README.md | 0 .../mat_transpose.cu | 0 .../mat_transpose.py | 0 sgemm/README.md | 783 +++++++++--------- sgemm/sgemm.py | 8 +- sgemm/sgemm_wmma_tf32_stage.cu | 163 ++-- 11 files changed, 922 insertions(+), 522 deletions(-) rename {mat_transpose => mat-transpose}/.gitignore (100%) rename {mat_transpose => mat-transpose}/README.md (100%) rename {mat_transpose => mat-transpose}/mat_transpose.cu (100%) rename {mat_transpose => mat-transpose}/mat_transpose.py (100%) diff --git a/README.md b/README.md index 31509fda..ecd58b35 100644 --- a/README.md +++ b/README.md @@ -62,11 +62,11 @@ | ✔️ [embedding_f16x2](./embedding/embedding.cu)|f16|/|[link](./embedding/)|⭐️| | ✔️ [embedding_f16x8](./embedding/embedding.cu)|f16|/|[link](./embedding/)|⭐️| | ✔️ [embedding_f16x8_pack](./embedding/embedding.cu)|f16|/|[link](./embedding/)|⭐️⭐️| -| ✔️ [mat_trans_f32_col2row{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️| -| ✔️ [mat_trans_f32_row2col{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️| -| ✔️ [mat_trans_f32_diagonal2d](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️⭐️| -| ✔️ [mat_trans_f32x4_col2row{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️⭐️| -| ✔️ [mat_trans_f32x4_row2col{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️⭐️| +| ✔️ [mat_trans_f32_col2row{2d}](./mat-transpose/mat_transpose.cu)|f32|/|[link](./mat-transpose/)|⭐️| +| ✔️ [mat_trans_f32_row2col{2d}](./mat-transpose/mat_transpose.cu)|f32|/|[link](./mat-transpose/)|⭐️| +| ✔️ [mat_trans_f32_diagonal2d](./mat-transpose/mat_transpose.cu)|f32|/|[link](./mat-transpose/)|⭐️⭐️| +| ✔️ [mat_trans_f32x4_col2row{2d}](./mat-transpose/mat_transpose.cu)|f32|/|[link](./mat-transpose/)|⭐️⭐️| +| ✔️ [mat_trans_f32x4_row2col{2d}](./mat-transpose/mat_transpose.cu)|f32|/|[link](./mat-transpose/)|⭐️⭐️| | ✔️ [warp_reduce_[all]](./reduce/reduce.cu)|all|all|[link](./reduce/)|⭐️⭐️| | ✔️ [reduce_f32_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️| | ✔️ [reduce_f32x4_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️| diff --git a/hgemm/hgemm.cu b/hgemm/hgemm.cu index 16da85c7..f46cade8 100644 --- a/hgemm/hgemm.cu +++ b/hgemm/hgemm.cu @@ -1237,6 +1237,8 @@ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem(torch::Tensor a, torch::Te int stages, bool swizzle, int swizzle_stride); void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); +void hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, + int stages, bool swizzle, int swizzle_stride); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -1284,5 +1286,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages) TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem) TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem) + TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem) } diff --git a/hgemm/hgemm.py b/hgemm/hgemm.py index 473c3a7d..21c0f4a9 100644 --- a/hgemm/hgemm.py +++ b/hgemm/hgemm.py @@ -3,6 +3,7 @@ from torch.utils.cpp_extension import load from functools import partial from typing import Optional +import argparse torch.set_grad_enabled(False) @@ -95,15 +96,24 @@ def run_benchmark(perf_func: callable, else: improve = 0 MAX_TFLOPS = TFLOPS - print(f"{out_info:>40}: {out_val}, time:{mean_time}ms, " + print(f"{out_info:>35}: {out_val}, time:{mean_time}ms, " f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}(+{improve:.2f}%)") else: - print(f"{out_info:>40}: {out_val}, time:{mean_time}ms, " + print(f"{out_info:>35}: {out_val}, time:{mean_time}ms, " f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}") if show_all: print(out) return out, mean_time +def get_args(): + parser = argparse.ArgumentParser(description="hgemm benchmark") + parser.add_argument("--enable-mma-all", "-ma", action="store_true") + parser.add_argument("--enable-wmma-all", "-wa", action="store_true") + parser.add_argument("--enable-cuda-all", "-ca", action="store_true") + return parser.parse_args() + + +args = get_args() Ms = [4096, 8192, 16384] Ns = [4096, 8192, 16384] Ks = [2048, 4096, 8192] @@ -124,44 +134,44 @@ def run_benchmark(perf_func: callable, c = C[:M, :N].contiguous() torch.cuda.synchronize() - # CUDA Cores FP16 - # run_benchmark(lib.hgemm_naive_f16, a, b, "f16(naive)", c) - # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "f16x8pack(t8x8+bcf)", c) + if args.enable_cuda_all: + # CUDA Cores FP16 + run_benchmark(lib.hgemm_naive_f16, a, b, "f16(naive)", c) + run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "f16x8pack(t8x8+bcf)", c) + run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf, a, b, "f16x8pack(t8x8+dbuf)", c) run_benchmark(lib.hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf, a, b, "f16x8pack(t8x8+k16+dbuf)", c) print("-" * 68 + "WMMA" + "-" * 58) - # run_benchmark(lib.hgemm_wmma_m16n16k16_naive, a, b, "f16wmma(naive)", c) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2, a, b, "f16wmma(mma4x2)", c) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4, a, b, "f16wmma(mma4x2+warp2x4)", c) - # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async_offset, a, b, "f16wmma(mma2x4+warp2x4+dbuf)", c) - - # Stages, dsmem - # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+warp2x4+stage4)", c, stages=4) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+warp2x4+stage3)", c, stages=3) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+warp2x4+stage2)", c, stages=2) - - # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(mma2x4+...+stage4+dsmem)", c, stages=4) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(mma2x4+...+stage3+dsmem)", c, stages=3) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(mma2x4+...+stage2+dsmem)", c, stages=2) - - # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+...+stage4+dsmem)", c, stages=4) - # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+...+stage3+dsmem)", c, stages=3) - # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+...+stage2+dsmem)", c, stages=2) + # wmma api, stages, dsmem, swizzle + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2, a, b, "(mma4x2)", c) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4, a, b, "(mma4x2+warp2x4)", c) - # Thread block swizzle - # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+...+stage4+swizzle)", c, stages=4, swizzle=True) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+...+stage3+swizzle)", c, stages=3, swizzle=True) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+...+stage2+swizzle)", c, stages=2, swizzle=True) - - # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(...+stage4+dsmem+swizzle)", c, stages=4, swizzle=True) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(...+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) - run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(...+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) - - # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+stage4+dsmem+swizzle)", c, stages=4, swizzle=True) - # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) - # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) + # prefer on NVIDIA L20 device. + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+warp2x4+stage3)", c, stages=3) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+warp2x4+stage2)", c, stages=2) + + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma2x4+...+stage3+dsmem)", c, stages=3) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma2x4+...+stage2+dsmem)", c, stages=2) + + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+...+stage3+swizzle)", c, stages=3, swizzle=True) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+...+stage2+swizzle)", c, stages=2, swizzle=True) + + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(...+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(...+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) + if args.enable_wmma_all: + # prefer on NVIDIA TRX 3080 Laptop 16GB GDDR6 device. + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+...+stage3+dsmem)", c, stages=3) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+...+stage2+dsmem)", c, stages=2) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+...+stage3+dsmem)", c, stages=3) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+...+stage2+dsmem)", c, stages=2) + + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) + run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) + run_benchmark(lib.hgemm_cublas_tensor_op, a, b, "f16(cublas)", c) run_benchmark(partial(torch.matmul, out=c), a, b, "f16_th") torch.cuda.synchronize() diff --git a/hgemm/hgemm_wmma_stage.cu b/hgemm/hgemm_wmma_stage.cu index 3e5319c8..cb635439 100644 --- a/hgemm/hgemm_wmma_stage.cu +++ b/hgemm/hgemm_wmma_stage.cu @@ -195,12 +195,13 @@ hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_kernel( CP_ASYNC_WAIT_GROUP(K_STAGE-2); __syncthreads(); } - + // make sure all memory issues ready. if ((K_STAGE - 2) > 0) { CP_ASYNC_WAIT_GROUP(0); __syncthreads(); } + // processing last (K_STAGE-1) k iters. { #pragma unroll @@ -266,7 +267,7 @@ template + const bool BLOCK_SWIZZLE=false> __global__ void __launch_bounds__(256) hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem_kernel( half* A, half* B, half* C, int M, int N, int K) { @@ -419,12 +420,13 @@ hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem_kernel( CP_ASYNC_WAIT_GROUP(K_STAGE-2); __syncthreads(); } - + // make sure all memory issues ready. if ((K_STAGE - 2) > 0) { CP_ASYNC_WAIT_GROUP(0); __syncthreads(); } + // processing last (K_STAGE-1) k iters. { #pragma unroll @@ -491,7 +493,7 @@ template + const bool BLOCK_SWIZZLE=false> __global__ void __launch_bounds__(512) hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem_kernel( half* A, half* B, half* C, int M, int N, int K) { @@ -640,12 +642,13 @@ hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem_kernel( CP_ASYNC_WAIT_GROUP(K_STAGE-2); __syncthreads(); } - + // make sure all memory issues ready. if ((K_STAGE - 2) > 0) { CP_ASYNC_WAIT_GROUP(0); __syncthreads(); } + // processing last (K_STAGE-1) k iters. { #pragma unroll @@ -699,7 +702,260 @@ hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem_kernel( } } -// TODO: K32 ? +// 128x128, Stages + K32 + Reg Buffers? +// stage2/3/4 (stage2=double buffers+copy async) +// 1. When using shared memory exceeds 48 KB, dynamic shared memory needs to be used, +// i.e., declare a block of dynamic shared memory with extern shared half smem[];. +// When calling the kernel, the size of the dynamic shared memory needs to be specified, +// and smem addressing should be used in a one-dimensional array manner. +// 2. Improve L2 Cache locality (Thread Block Swizzle): https://zhuanlan.zhihu.com/p/555339335 +// 3. __launch_bounds__: avoid error 'too many resources required for launch' +// reference: https://blog.csdn.net/feng__shuai/article/details/124395023 +template +__global__ void __launch_bounds__(256) +hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem_kernel( + half* A, half* B, half* C, int M, int N, int K) { + // 256 threads(8 warps) per block. + // const int bx = blockIdx.x; + // BLOCK_SWIZZLE 0/1 control use block swizzle or not. + const int bx = ((int) BLOCK_SWIZZLE) * blockIdx.z * gridDim.x + blockIdx.x; + const int by = blockIdx.y; + const int NUM_K_TILES = div_ceil(K, WMMA_K * WARP_TILE_K); + constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; // 16x4*2=128 + constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; // 16x2*4=128 + constexpr int BK = WMMA_K * WARP_TILE_K; // 16*2=32 + // s2: 2*128*(32)*2=16KB, 2*32*(128+16)*2=18KB, ~42KB + // s3: 3*128*(32)*2=24KB, 3*32*(128+16)*2=27KB, ~51KB + // s4: 4*128*(32)*2=32KB, 4*32*(128+16)*2=36KB, ~68KB + // s4: 5*128*(32)*2=40KB, 5*32*(128+16)*2=45KB, ~85KB + extern __shared__ half smem[]; + half* s_a = smem; + half* s_b = smem + K_STAGE * BM * (BK + A_PAD); + constexpr int s_a_stage_offset = BM * (BK + A_PAD); + constexpr int s_b_stage_offset = BK * (BN + B_PAD); + + // 要保证相同的warp下thread执行相同的指令 + const int tid = threadIdx.y * blockDim.x + threadIdx.x; + const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block + const int warp_m = warp_id / 2; // 0,1,2,3 + const int warp_n = warp_id % 2; // 0,1 + + // 先计算shared memory中的索引 + // tid和需要加载的smem s_a[BM][BK] 之间的索引关系 BM=128 BK=32 按行读取 A行主序 + // 对于s_a每行32个数据,每个线程读取16个,需要2个线程;总共128行,需要128x2刚好256线程 + int load_smem_a_m = tid / 2; // row 0~127 + int load_smem_a_k = (tid % 2 == 0) ? 0 : 16; // col 0,16 + // tid和需要加载的smem s_b[BK][BN] 之间的索引关系 BK=32 BN=128 按行读取 B行主序 + // 对于s_b每行128个数据,每个线程读16个数据,需要8个线程;总共32行,需要32x16=256个线程 + int load_smem_b_k = tid / 8; // row 0~31 + int load_smem_b_n = (tid % 8) * 16; // col 0,16,...,127 + // 再计算全局内存中的索引 + // 要加载到s_a中的元素对应到A全局内存中的行数 每个block负责出C中大小为BM*BN的块 + int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c + int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c + + wmma::fragment + C_frag[WARP_TILE_M][WARP_TILE_N]; + + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + wmma::fill_fragment(C_frag[i][j], 0.0); + } + } + + // only cvta smem base ptr once for cp.async. + uint32_t smem_a_base_ptr = __cvta_generic_to_shared(s_a); + uint32_t smem_b_base_ptr = __cvta_generic_to_shared(s_b); + + #pragma unroll + for (int k = 0; k < (K_STAGE - 1); ++k) { // 0, 1 + // k * WMMA_K, WMMA_K=16 -> (k << 4) + int load_gmem_a_k = k * (WMMA_K * WARP_TILE_K) + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + int load_gmem_b_k = k * (WMMA_K * WARP_TILE_K) + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + + uint32_t load_smem_a_ptr = ( + smem_a_base_ptr + (k * s_a_stage_offset + + load_smem_a_m * (BK + A_PAD) + + load_smem_a_k) * sizeof(half) + ); + + uint32_t load_smem_b_ptr = ( + smem_b_base_ptr + (k * s_b_stage_offset + + load_smem_b_k * (BN + B_PAD) + + load_smem_b_n) * sizeof(half) + ); + // first part + CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); + CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); + CP_ASYNC_COMMIT_GROUP(); + // second part + CP_ASYNC_CG(load_smem_a_ptr + 16, &A[load_gmem_a_addr + 8], 16); + CP_ASYNC_CG(load_smem_b_ptr + 16, &B[load_gmem_b_addr + 8], 16); + + CP_ASYNC_COMMIT_GROUP(); + } + + CP_ASYNC_WAIT_GROUP(K_STAGE-2); // s2->0, s3->1, s4->2 + __syncthreads(); + + #pragma unroll + for (int k = (K_STAGE - 1); k < NUM_K_TILES; k++) { + // s2/4 can use bitwise ops but s3 can not, so, we use mod + // ops for all stages kernel. s2: (k + 1)&1, s4: (k + 1)&3 + // s3: (k + 1) % 3 + int smem_sel = (k + 1) % K_STAGE; // s3 k 2->0, k 3->1, k 4->2... + int smem_sel_next = k % K_STAGE; // s3 k 2->2, k 3->0, k 4->1... + + // k * WMMA_K, WMMA_K=16 -> (k << 4) + int load_gmem_a_k = k * (WMMA_K * WARP_TILE_K) + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + int load_gmem_b_k = k * (WMMA_K * WARP_TILE_K) + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + + // load stage 2, k start from 2 + uint32_t load_smem_a_ptr = ( + smem_a_base_ptr + (smem_sel_next * s_a_stage_offset + + load_smem_a_m * (BK + A_PAD) + + load_smem_a_k) * sizeof(half) + ); + + uint32_t load_smem_b_ptr = ( + smem_b_base_ptr + (smem_sel_next * s_b_stage_offset + + load_smem_b_k * (BN + B_PAD) + + load_smem_b_n) * sizeof(half) + ); + + // first part + CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); + CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); + CP_ASYNC_COMMIT_GROUP(); + // second part + CP_ASYNC_CG(load_smem_a_ptr + 16, &A[load_gmem_a_addr + 8], 16); + CP_ASYNC_CG(load_smem_b_ptr + 16, &B[load_gmem_b_addr + 8], 16); + CP_ASYNC_COMMIT_GROUP(); + + // WARP_TILE_K=2 + for (int warp_k = 0; warp_k < WARP_TILE_K; ++warp_k) { + wmma::fragment A_frag[WARP_TILE_M]; + wmma::fragment B_frag[WARP_TILE_N]; + const int warp_smem_k = warp_k * WMMA_K; // 0,16 + + // compute stage 0 + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + // load 2 tiles -> reg, smem a -> frags a, warp_m 0~3 + int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M; + half* load_smem_a_frag_ptr = (s_a + smem_sel * s_a_stage_offset + + warp_smem_a_m * (BK + A_PAD) + + warp_smem_k); + wmma::load_matrix_sync(A_frag[i], load_smem_a_frag_ptr, BK + A_PAD); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + // load 4 tiles -> reg, smem b -> frags b, warp_n 0~2 + int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N; + half* load_smem_b_frag_ptr = (s_b + smem_sel * s_b_stage_offset + + warp_smem_k * (BN + B_PAD) + + warp_smem_b_n); + wmma::load_matrix_sync(B_frag[j], load_smem_b_frag_ptr, BN + B_PAD); + } + + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]); + } + } + } + + CP_ASYNC_WAIT_GROUP(K_STAGE-2); + __syncthreads(); + } + + // make sure all memory issues ready. + if ((K_STAGE - 2) > 0) { + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + + // processing last (K_STAGE-1) k iters. + { + #pragma unroll + for (int k = 0; k < (K_STAGE - 1); k++) { + const int stage_sel = ((NUM_K_TILES - (K_STAGE - 1) + k) % K_STAGE); + + #pragma unroll + for (int warp_k = 0; warp_k < WARP_TILE_K; ++warp_k) { + wmma::fragment A_frag[WARP_TILE_M]; + wmma::fragment B_frag[WARP_TILE_N]; + const int warp_smem_k = warp_k * WMMA_K; // 0,16 + + // compute stage 0 + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + // load 2 tiles -> reg, smem a -> frags a, warp_m 0~3 + int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M; + half* load_smem_a_frag_ptr = (s_a + stage_sel * s_a_stage_offset + + warp_smem_a_m * (BK + A_PAD) + + warp_smem_k); + wmma::load_matrix_sync(A_frag[i], load_smem_a_frag_ptr, BK + A_PAD); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + // load 4 tiles -> reg, smem b -> frags b, warp_n 0~2 + int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N; + half* load_smem_b_frag_ptr = (s_b + stage_sel * s_b_stage_offset + + warp_smem_k * (BN + B_PAD) + + warp_smem_b_n); + wmma::load_matrix_sync(B_frag[j], load_smem_b_frag_ptr, BN + B_PAD); + } + + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]); + } + } + } + } + } + + // finally, store back to C matrix. + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + const int store_gmem_a_m = by * BM + warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M; + const int store_gmem_a_n = bx * BN + warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N; + wmma::store_matrix_sync(C + store_gmem_a_m * N + store_gmem_a_n, C_frag[i][j], N, + wmma::mem_row_major); + } + } +} + // TODO: Warp swizzle/permute support ? (MMA, not WMMA) // --------------------- PyTorch bindings for custom kernel ----------------------- @@ -1086,3 +1342,134 @@ void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem( } } } + +// 128x128 warp2x4x2 w dynamic smem, 98304=96KB < Ampere, Ada, Hopper ... +#define LAUNCH_161616_STAGE_SWIZZLE_DSMEM_K32_KERNEL(stages, stride)\ +{ \ + const int smem_max_size = ( \ + (stages) * BM * (BK + A_PAD) * sizeof(half) + \ + (stages) * BK * (BN + B_PAD) * sizeof(half)); \ + cudaFuncSetAttribute( \ + hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, \ + A_PAD, B_PAD, (stages), true>, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + 98304); \ + const int N_SWIZZLE = (N + (stride) - 1) / (stride); \ + dim3 block(NUM_THREADS); \ + dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE, \ + div_ceil(M, BM), \ + N_SWIZZLE); \ + hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, \ + A_PAD, B_PAD, (stages), true><<< \ + grid, block, smem_max_size>>>( \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(b.data_ptr()), \ + reinterpret_cast(c.data_ptr()), \ + M, N, K \ + ); \ +} + +#define LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_K32_KERNEL(stages)\ +{ \ + const int smem_max_size = ( \ + (stages) * BM * (BK + A_PAD) * sizeof(half) + \ + (stages) * BK * (BN + B_PAD) * sizeof(half)); \ + cudaFuncSetAttribute( \ + hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, \ + A_PAD, B_PAD, (stages), false>, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + 98304); \ + dim3 block(NUM_THREADS); \ + dim3 grid(div_ceil(N, BN), div_ceil(M, BM)); \ + hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem_kernel< \ + WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, \ + A_PAD, B_PAD, (stages), false><<< \ + grid, block, smem_max_size>>>( \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(b.data_ptr()), \ + reinterpret_cast(c.data_ptr()), \ + M, N, K \ + ); \ +} + +void hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem( + torch::Tensor a, torch::Tensor b, torch::Tensor c, + int stages, bool swizzle, int swizzle_stride) { + CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) + const int M = a.size(0); + const int K = a.size(1); + const int N = b.size(1); + CHECK_TORCH_TENSOR_SHAPE(a, M, K) + CHECK_TORCH_TENSOR_SHAPE(b, K, N) + CHECK_TORCH_TENSOR_SHAPE(c, M, N) + constexpr int WMMA_M = 16; + constexpr int WMMA_N = 16; + constexpr int WMMA_K = 16; + constexpr int WMMA_TILE_M = 4; + constexpr int WMMA_TILE_N = 2; + constexpr int WARP_TILE_M = 2; + constexpr int WARP_TILE_N = 4; + constexpr int WARP_TILE_K = 2; + // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts. + // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts. + // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts. + // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus, + // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD. + constexpr int A_PAD = 0; // 0,8,16 + constexpr int B_PAD = 16; // 0,8,16 + constexpr int NUM_THREADS= ( + WMMA_TILE_M * WMMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 + constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; + constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; + constexpr int BK = WMMA_K * WARP_TILE_K; + + if (swizzle) { + assert(swizzle_stride % 256 == 0); + switch (stages) + { + case 2: + LAUNCH_161616_STAGE_SWIZZLE_DSMEM_K32_KERNEL(2, swizzle_stride); + break; + case 3: + LAUNCH_161616_STAGE_SWIZZLE_DSMEM_K32_KERNEL(3, swizzle_stride); + break; + case 4: + LAUNCH_161616_STAGE_SWIZZLE_DSMEM_K32_KERNEL(4, swizzle_stride); + break; + case 5: + LAUNCH_161616_STAGE_SWIZZLE_DSMEM_K32_KERNEL(5, swizzle_stride); + break; + default: + LAUNCH_161616_STAGE_SWIZZLE_DSMEM_K32_KERNEL(2, swizzle_stride); + break; + } + } else { + switch (stages) + { + case 2: + LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_K32_KERNEL(2); + break; + case 3: + LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_K32_KERNEL(3); + break; + case 4: + LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_K32_KERNEL(4); + break; + case 5: + LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_K32_KERNEL(5); + break; + default: + LAUNCH_161616_STAGE_NO_SWIZZLE_DSMEM_K32_KERNEL(2); + break; + } + } +} diff --git a/mat_transpose/.gitignore b/mat-transpose/.gitignore similarity index 100% rename from mat_transpose/.gitignore rename to mat-transpose/.gitignore diff --git a/mat_transpose/README.md b/mat-transpose/README.md similarity index 100% rename from mat_transpose/README.md rename to mat-transpose/README.md diff --git a/mat_transpose/mat_transpose.cu b/mat-transpose/mat_transpose.cu similarity index 100% rename from mat_transpose/mat_transpose.cu rename to mat-transpose/mat_transpose.cu diff --git a/mat_transpose/mat_transpose.py b/mat-transpose/mat_transpose.py similarity index 100% rename from mat_transpose/mat_transpose.py rename to mat-transpose/mat_transpose.py diff --git a/sgemm/README.md b/sgemm/README.md index cca40378..f7cca426 100755 --- a/sgemm/README.md +++ b/sgemm/README.md @@ -150,515 +150,488 @@ python3 sgemm.py ```bash ---------------------------------------------------------------------------------------------------------------------------------- M=4096, N=4096, K=2048 - out_f32(naive): ['-17.842391', '0.18722232'], time:20.40503ms, swizzle: NOOP, TFLOPS: 3.37 (+0.00%) - out_f32x4(t8x8sk): ['-17.842391', '0.18722232'], time:2.430105ms, swizzle: NOOP, TFLOPS: 28.28 (+739.68%) - out_f32x4(t8x8bcf): ['-17.842391', '0.18722232'], time:2.102661ms, swizzle: NOOP, TFLOPS: 32.68 (+15.57%) - out_f32x4(t8x8dbuf): ['-17.842391', '0.18722232'], time:1.985645ms, swizzle: NOOP, TFLOPS: 34.61 (+5.89%) - out_f32(cublas): ['-17.842391', '0.18722232'], time:2.087247ms, swizzle: NOOP, TFLOPS: 32.92 - out_f32_th: ['-17.842391', '0.18722232'], time:1.845526ms, swizzle: NOOP, TFLOPS: 37.24 (+7.59%) + out_f32x4(t8x8sk): ['70.6019897', '26.1625347'], time:2.428984ms, swizzle: NOOP, TFLOPS: 28.29 (+0.00%) + out_f32x4(t8x8bcf): ['70.6019897', '26.1625347'], time:2.112817ms, swizzle: NOOP, TFLOPS: 32.53 (+14.96%) + out_f32x4(t8x8dbuf): ['70.6019897', '26.1625347'], time:1.877713ms, swizzle: NOOP, TFLOPS: 36.60 (+12.52%) + out_f32(cublas): ['70.6019897', '26.1625347'], time:2.229022ms, swizzle: NOOP, TFLOPS: 30.83 + out_f32_th: ['70.6019897', '26.1625347'], time:1.778435ms, swizzle: NOOP, TFLOPS: 38.64 (+5.58%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-17.849796', '0.19758633'], time:2.149534ms, swizzle: NOOP, TFLOPS: 31.97 - out_tf32(mma2x4+warp2x4+stage2): ['-17.849796', '0.19758633'], time:1.663148ms, swizzle: NOOP, TFLOPS: 41.32 (+10.97%) - out_tf32(mma2x4+...+stage3+dsmem): ['-17.849796', '0.19758633'], time:1.823186ms, swizzle: NOOP, TFLOPS: 37.69 - out_tf32(mma2x4+...+stage2+dsmem): ['-17.849796', '0.19758633'], time:1.662409ms, swizzle: NOOP, TFLOPS: 41.34 (+0.04%) -out_tf32(mma2x4+...+stage3+swizzle): ['-17.849796', '0.19758633'], time:2.001917ms, swizzle: 1024, TFLOPS: 34.33 -out_tf32(mma2x4+...+stage2+swizzle): ['-17.849796', '0.19758633'], time:1.634597ms, swizzle: 1024, TFLOPS: 42.04 (+1.70%) - out_tf32(...+stage3+dsmem+swizzle): ['-17.849796', '0.19758633'], time:1.806831ms, swizzle: 1024, TFLOPS: 38.03 - out_tf32(...+stage2+dsmem+swizzle): ['-17.849796', '0.19758633'], time:1.635658ms, swizzle: 1024, TFLOPS: 42.01 - out_tf32(cublas+tf32): ['-17.849796', '0.19758633'], time:1.539385ms, swizzle: NOOP, TFLOPS: 44.64 (+6.19%) + out_tf32(mma2x4+warp2x4+stage3): ['70.5943985', '26.1725273'], time:2.035927ms, swizzle: NOOP, TFLOPS: 33.75 + out_tf32(mma2x4+warp2x4+stage2): ['70.5943985', '26.1725273'], time:1.670312ms, swizzle: NOOP, TFLOPS: 41.14 (+6.47%) + out_tf32(mma2x4+...+stage3+dsmem): ['70.5943985', '26.1725273'], time:1.820373ms, swizzle: NOOP, TFLOPS: 37.75 + out_tf32(mma2x4+...+stage2+dsmem): ['70.5943985', '26.1725273'], time:1.646137ms, swizzle: NOOP, TFLOPS: 41.75 (+1.47%) +out_tf32(mma2x4+...+stage3+swizzle): ['70.5943985', '26.1725273'], time:2.027678ms, swizzle: 512 , TFLOPS: 33.89 +out_tf32(mma2x4+...+stage2+swizzle): ['70.5943985', '26.1725273'], time:1.640319ms, swizzle: 512 , TFLOPS: 41.89 (+0.35%) + out_tf32(...+stage3+dsmem+swizzle): ['70.5943985', '26.1725273'], time:1.807355ms, swizzle: 512 , TFLOPS: 38.02 + out_tf32(...+stage2+dsmem+swizzle): ['70.5943985', '26.1725273'], time:1.627850ms, swizzle: 512 , TFLOPS: 42.21 (+0.77%) + out_tf32(cublas+tf32): ['70.5943985', '26.1725273'], time:7.086372ms, swizzle: NOOP, TFLOPS: 9.70 ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=4096, N=4096, K=4096 - out_f32(naive): ['-24.547933', '26.0833282'], time:40.70725ms, swizzle: NOOP, TFLOPS: 3.38 (+0.00%) - out_f32x4(t8x8sk): ['-24.547933', '26.0833282'], time:4.871845ms, swizzle: NOOP, TFLOPS: 28.21 (+735.56%) - out_f32x4(t8x8bcf): ['-24.547933', '26.0833282'], time:4.412031ms, swizzle: NOOP, TFLOPS: 31.15 (+10.42%) - out_f32x4(t8x8dbuf): ['-24.547933', '26.0833282'], time:4.048168ms, swizzle: NOOP, TFLOPS: 33.95 (+8.99%) - out_f32(cublas): ['-24.547933', '26.0833282'], time:4.019129ms, swizzle: NOOP, TFLOPS: 34.20 (+0.72%) - out_f32_th: ['-24.547933', '26.0833282'], time:3.687226ms, swizzle: NOOP, TFLOPS: 37.27 (+9.00%) + out_f32x4(t8x8sk): ['151.780014', '4.5990448 '], time:4.822254ms, swizzle: NOOP, TFLOPS: 28.50 (+0.00%) + out_f32x4(t8x8bcf): ['151.780014', '4.5990448 '], time:4.319739ms, swizzle: NOOP, TFLOPS: 31.82 (+11.63%) + out_f32x4(t8x8dbuf): ['151.780014', '4.5990448 '], time:3.906702ms, swizzle: NOOP, TFLOPS: 35.18 (+10.57%) + out_f32(cublas): ['151.780014', '4.5990448 '], time:4.850530ms, swizzle: NOOP, TFLOPS: 28.33 + out_f32_th: ['151.780014', '4.5990448 '], time:3.584909ms, swizzle: NOOP, TFLOPS: 38.34 (+8.98%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-24.539142', '26.0933208'], time:4.418766ms, swizzle: NOOP, TFLOPS: 31.10 - out_tf32(mma2x4+warp2x4+stage2): ['-24.539142', '26.0933208'], time:3.449714ms, swizzle: NOOP, TFLOPS: 39.84 (+6.88%) - out_tf32(mma2x4+...+stage3+dsmem): ['-24.539142', '26.0933208'], time:3.769552ms, swizzle: NOOP, TFLOPS: 36.46 - out_tf32(mma2x4+...+stage2+dsmem): ['-24.539142', '26.0933208'], time:3.442633ms, swizzle: NOOP, TFLOPS: 39.92 (+0.21%) -out_tf32(mma2x4+...+stage3+swizzle): ['-24.539142', '26.0933208'], time:4.073047ms, swizzle: 1024, TFLOPS: 33.74 -out_tf32(mma2x4+...+stage2+swizzle): ['-24.539142', '26.0933208'], time:3.320538ms, swizzle: 1024, TFLOPS: 41.39 (+3.68%) - out_tf32(...+stage3+dsmem+swizzle): ['-24.539142', '26.0933208'], time:3.651261ms, swizzle: 1024, TFLOPS: 37.64 - out_tf32(...+stage2+dsmem+swizzle): ['-24.539142', '26.0933208'], time:3.319275ms, swizzle: 1024, TFLOPS: 41.41 (+0.04%) - out_tf32(cublas+tf32): ['-24.539142', '26.0933208'], time:2.791321ms, swizzle: NOOP, TFLOPS: 49.24 (+18.91%) + out_tf32(mma2x4+warp2x4+stage3): ['151.794143', '4.5965395 '], time:4.346919ms, swizzle: NOOP, TFLOPS: 31.62 + out_tf32(mma2x4+warp2x4+stage2): ['151.794143', '4.5965395 '], time:3.493309ms, swizzle: NOOP, TFLOPS: 39.34 (+2.62%) + out_tf32(mma2x4+...+stage3+dsmem): ['151.794143', '4.5965395 '], time:3.765821ms, swizzle: NOOP, TFLOPS: 36.50 + out_tf32(mma2x4+...+stage2+dsmem): ['151.794143', '4.5965395 '], time:3.599095ms, swizzle: NOOP, TFLOPS: 38.19 +out_tf32(mma2x4+...+stage3+swizzle): ['151.794143', '4.5965395 '], time:4.048442ms, swizzle: 512 , TFLOPS: 33.95 +out_tf32(mma2x4+...+stage2+swizzle): ['151.794143', '4.5965395 '], time:3.320336ms, swizzle: 512 , TFLOPS: 41.39 (+5.21%) + out_tf32(...+stage3+dsmem+swizzle): ['151.794143', '4.5965395 '], time:3.658032ms, swizzle: 512 , TFLOPS: 37.57 + out_tf32(...+stage2+dsmem+swizzle): ['151.794143', '4.5965395 '], time:3.310155ms, swizzle: 512 , TFLOPS: 41.52 (+0.31%) + out_tf32(cublas+tf32): ['151.794143', '4.5965395 '], time:2.807903ms, swizzle: NOOP, TFLOPS: 48.95 (+17.89%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=4096, N=4096, K=8192 - out_f32(naive): ['47.3211364', '96.7818374'], time:124.1455ms, swizzle: NOOP, TFLOPS: 2.21 (+0.00%) - out_f32x4(t8x8sk): ['47.3211364', '96.7818374'], time:10.18377ms, swizzle: NOOP, TFLOPS: 26.99 (+1119.05%) - out_f32x4(t8x8bcf): ['47.3211364', '96.7818374'], time:8.965158ms, swizzle: NOOP, TFLOPS: 30.66 (+13.59%) - out_f32x4(t8x8dbuf): ['47.3211364', '96.7818374'], time:9.146523ms, swizzle: NOOP, TFLOPS: 30.05 - out_f32(cublas): ['47.3211364', '96.7818374'], time:7.824325ms, swizzle: NOOP, TFLOPS: 35.13 (+14.58%) - out_f32_th: ['47.3211364', '96.7818374'], time:7.979285ms, swizzle: NOOP, TFLOPS: 34.45 + out_f32x4(t8x8sk): ['118.496635', '44.2837791'], time:9.974384ms, swizzle: NOOP, TFLOPS: 27.56 (+0.00%) + out_f32x4(t8x8bcf): ['118.496635', '44.2837791'], time:8.764767ms, swizzle: NOOP, TFLOPS: 31.36 (+13.80%) + out_f32x4(t8x8dbuf): ['118.496635', '44.2837791'], time:8.941769ms, swizzle: NOOP, TFLOPS: 30.74 + out_f32(cublas): ['118.496635', '44.2837791'], time:7.849812ms, swizzle: NOOP, TFLOPS: 35.02 (+11.66%) + out_f32_th: ['118.496635', '44.2837791'], time:7.393693ms, swizzle: NOOP, TFLOPS: 37.18 (+6.17%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['47.3002510', '96.8180007'], time:8.486199ms, swizzle: NOOP, TFLOPS: 32.39 - out_tf32(mma2x4+warp2x4+stage2): ['47.3002510', '96.8180007'], time:6.938815ms, swizzle: NOOP, TFLOPS: 39.61 (+12.76%) - out_tf32(mma2x4+...+stage3+dsmem): ['47.3002510', '96.8180007'], time:7.443344ms, swizzle: NOOP, TFLOPS: 36.93 - out_tf32(mma2x4+...+stage2+dsmem): ['47.3002510', '96.8180007'], time:6.947302ms, swizzle: NOOP, TFLOPS: 39.57 -out_tf32(mma2x4+...+stage3+swizzle): ['47.3002510', '96.8180007'], time:8.268499ms, swizzle: 1024, TFLOPS: 33.24 -out_tf32(mma2x4+...+stage2+swizzle): ['47.3002510', '96.8180007'], time:6.863093ms, swizzle: 1024, TFLOPS: 40.05 (+1.10%) - out_tf32(...+stage3+dsmem+swizzle): ['47.3002510', '96.8180007'], time:7.466220ms, swizzle: 1024, TFLOPS: 36.82 - out_tf32(...+stage2+dsmem+swizzle): ['47.3002510', '96.8180007'], time:6.836020ms, swizzle: 1024, TFLOPS: 40.21 (+0.40%) - out_tf32(cublas+tf32): ['47.3002510', '96.8180007'], time:5.307173ms, swizzle: NOOP, TFLOPS: 51.79 (+28.81%) + out_tf32(mma2x4+warp2x4+stage3): ['118.526184', '44.2716636'], time:8.627605ms, swizzle: NOOP, TFLOPS: 31.86 + out_tf32(mma2x4+warp2x4+stage2): ['118.526184', '44.2716636'], time:6.934285ms, swizzle: NOOP, TFLOPS: 39.64 (+6.63%) + out_tf32(mma2x4+...+stage3+dsmem): ['118.526184', '44.2716636'], time:7.462024ms, swizzle: NOOP, TFLOPS: 36.84 + out_tf32(mma2x4+...+stage2+dsmem): ['118.526184', '44.2716636'], time:6.970906ms, swizzle: NOOP, TFLOPS: 39.43 +out_tf32(mma2x4+...+stage3+swizzle): ['118.526184', '44.2716636'], time:8.261394ms, swizzle: 512 , TFLOPS: 33.27 +out_tf32(mma2x4+...+stage2+swizzle): ['118.526184', '44.2716636'], time:6.864094ms, swizzle: 512 , TFLOPS: 40.05 (+1.02%) + out_tf32(...+stage3+dsmem+swizzle): ['118.526184', '44.2716636'], time:7.449316ms, swizzle: 512 , TFLOPS: 36.90 + out_tf32(...+stage2+dsmem+swizzle): ['118.526184', '44.2716636'], time:6.867933ms, swizzle: 512 , TFLOPS: 40.02 + out_tf32(cublas+tf32): ['118.526184', '44.2716636'], time:5.459380ms, swizzle: NOOP, TFLOPS: 50.35 (+25.73%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=4096, N=8192, K=2048 - out_f32(naive): ['-17.835220', '0.19710006'], time:40.66953ms, swizzle: NOOP, TFLOPS: 3.38 (+0.00%) - out_f32x4(t8x8sk): ['-17.835220', '0.19710006'], time:4.663670ms, swizzle: NOOP, TFLOPS: 29.47 (+772.05%) - out_f32x4(t8x8bcf): ['-17.835220', '0.19710006'], time:4.213857ms, swizzle: NOOP, TFLOPS: 32.62 (+10.67%) - out_f32x4(t8x8dbuf): ['-17.835220', '0.19710006'], time:3.852760ms, swizzle: NOOP, TFLOPS: 35.67 (+9.37%) - out_f32(cublas): ['-17.835220', '0.19710006'], time:3.993618ms, swizzle: NOOP, TFLOPS: 34.41 - out_f32_th: ['-17.835220', '0.19710006'], time:3.633618ms, swizzle: NOOP, TFLOPS: 37.82 (+6.03%) + out_f32x4(t8x8sk): ['70.5972366', '26.1622695'], time:4.638457ms, swizzle: NOOP, TFLOPS: 29.63 (+0.00%) + out_f32x4(t8x8bcf): ['70.5972366', '26.1622695'], time:4.083228ms, swizzle: NOOP, TFLOPS: 33.66 (+13.60%) + out_f32x4(t8x8dbuf): ['70.5972366', '26.1622695'], time:3.705859ms, swizzle: NOOP, TFLOPS: 37.09 (+10.18%) + out_f32(cublas): ['70.5972366', '26.1622695'], time:4.071259ms, swizzle: NOOP, TFLOPS: 33.76 + out_f32_th: ['70.5972366', '26.1622695'], time:3.648686ms, swizzle: NOOP, TFLOPS: 37.67 (+1.57%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-17.849796', '0.19758633'], time:3.958535ms, swizzle: NOOP, TFLOPS: 34.72 - out_tf32(mma2x4+warp2x4+stage2): ['-17.849796', '0.19758633'], time:3.184044ms, swizzle: NOOP, TFLOPS: 43.16 (+14.12%) - out_tf32(mma2x4+...+stage3+dsmem): ['-17.849796', '0.19758633'], time:3.465819ms, swizzle: NOOP, TFLOPS: 39.66 - out_tf32(mma2x4+...+stage2+dsmem): ['-17.849796', '0.19758633'], time:3.177452ms, swizzle: NOOP, TFLOPS: 43.25 (+0.21%) -out_tf32(mma2x4+...+stage3+swizzle): ['-17.849796', '0.19758633'], time:3.823959ms, swizzle: 2048, TFLOPS: 35.94 -out_tf32(mma2x4+...+stage2+swizzle): ['-17.849796', '0.19758633'], time:3.122901ms, swizzle: 2048, TFLOPS: 44.01 (+1.75%) - out_tf32(...+stage3+dsmem+swizzle): ['-17.849796', '0.19758633'], time:3.422784ms, swizzle: 2048, TFLOPS: 40.15 - out_tf32(...+stage2+dsmem+swizzle): ['-17.849796', '0.19758633'], time:3.124201ms, swizzle: 2048, TFLOPS: 43.99 - out_tf32(cublas+tf32): ['-17.849796', '0.19758633'], time:2.821981ms, swizzle: NOOP, TFLOPS: 48.70 (+10.66%) + out_tf32(mma2x4+warp2x4+stage3): ['70.5943985', '26.1725273'], time:3.987336ms, swizzle: NOOP, TFLOPS: 34.47 + out_tf32(mma2x4+warp2x4+stage2): ['70.5943985', '26.1725273'], time:3.204703ms, swizzle: NOOP, TFLOPS: 42.89 (+13.85%) + out_tf32(mma2x4+...+stage3+dsmem): ['70.5943985', '26.1725273'], time:3.465056ms, swizzle: NOOP, TFLOPS: 39.66 + out_tf32(mma2x4+...+stage2+dsmem): ['70.5943985', '26.1725273'], time:3.179168ms, swizzle: NOOP, TFLOPS: 43.23 (+0.80%) +out_tf32(mma2x4+...+stage3+swizzle): ['70.5943985', '26.1725273'], time:3.828763ms, swizzle: 1024, TFLOPS: 35.90 +out_tf32(mma2x4+...+stage2+swizzle): ['70.5943985', '26.1725273'], time:3.141665ms, swizzle: 1024, TFLOPS: 43.75 (+1.19%) + out_tf32(...+stage3+dsmem+swizzle): ['70.5943985', '26.1725273'], time:3.441977ms, swizzle: 1024, TFLOPS: 39.93 + out_tf32(...+stage2+dsmem+swizzle): ['70.5943985', '26.1725273'], time:3.152799ms, swizzle: 1024, TFLOPS: 43.59 + out_tf32(cublas+tf32): ['70.5943985', '26.1725273'], time:2.859544ms, swizzle: NOOP, TFLOPS: 48.06 (+9.87%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=4096, N=8192, K=4096 - out_f32(naive): ['-24.541957', '26.1021537'], time:134.3554ms, swizzle: NOOP, TFLOPS: 2.05 (+0.00%) - out_f32x4(t8x8sk): ['-24.541957', '26.1021537'], time:9.993898ms, swizzle: NOOP, TFLOPS: 27.50 (+1244.37%) - out_f32x4(t8x8bcf): ['-24.541957', '26.1021537'], time:9.019649ms, swizzle: NOOP, TFLOPS: 30.48 (+10.80%) - out_f32x4(t8x8dbuf): ['-24.541957', '26.1021537'], time:9.230816ms, swizzle: NOOP, TFLOPS: 29.78 - out_f32(cublas): ['-24.541957', '26.1021537'], time:7.709038ms, swizzle: NOOP, TFLOPS: 35.66 (+17.00%) - out_f32_th: ['-24.541957', '26.1021537'], time:7.547247ms, swizzle: NOOP, TFLOPS: 36.42 (+2.14%) + out_f32x4(t8x8sk): ['151.801406', '4.59161139'], time:9.912538ms, swizzle: NOOP, TFLOPS: 27.73 (+0.00%) + out_f32x4(t8x8bcf): ['151.801406', '4.59161139'], time:8.917999ms, swizzle: NOOP, TFLOPS: 30.82 (+11.15%) + out_f32x4(t8x8dbuf): ['151.801406', '4.59161139'], time:8.958077ms, swizzle: NOOP, TFLOPS: 30.68 + out_f32(cublas): ['151.801406', '4.59161139'], time:7.909870ms, swizzle: NOOP, TFLOPS: 34.75 (+12.75%) + out_f32_th: ['151.801406', '4.59161139'], time:7.236218ms, swizzle: NOOP, TFLOPS: 37.99 (+9.31%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-24.539142', '26.0933208'], time:7.894265ms, swizzle: NOOP, TFLOPS: 34.82 - out_tf32(mma2x4+warp2x4+stage2): ['-24.539142', '26.0933208'], time:6.565976ms, swizzle: NOOP, TFLOPS: 41.86 (+14.94%) - out_tf32(mma2x4+...+stage3+dsmem): ['-24.539142', '26.0933208'], time:6.902301ms, swizzle: NOOP, TFLOPS: 39.82 - out_tf32(mma2x4+...+stage2+dsmem): ['-24.539142', '26.0933208'], time:6.546056ms, swizzle: NOOP, TFLOPS: 41.99 (+0.30%) -out_tf32(mma2x4+...+stage3+swizzle): ['-24.539142', '26.0933208'], time:7.725191ms, swizzle: 2048, TFLOPS: 35.58 -out_tf32(mma2x4+...+stage2+swizzle): ['-24.539142', '26.0933208'], time:6.347680ms, swizzle: 2048, TFLOPS: 43.30 (+3.13%) - out_tf32(...+stage3+dsmem+swizzle): ['-24.539142', '26.0933208'], time:6.931185ms, swizzle: 2048, TFLOPS: 39.66 - out_tf32(...+stage2+dsmem+swizzle): ['-24.539142', '26.0933208'], time:6.350684ms, swizzle: 2048, TFLOPS: 43.28 - out_tf32(cublas+tf32): ['-24.539142', '26.0933208'], time:5.310356ms, swizzle: NOOP, TFLOPS: 51.76 (+19.53%) + out_tf32(mma2x4+warp2x4+stage3): ['151.794143', '4.5965395 '], time:7.893776ms, swizzle: NOOP, TFLOPS: 34.82 + out_tf32(mma2x4+warp2x4+stage2): ['151.794143', '4.5965395 '], time:6.559514ms, swizzle: NOOP, TFLOPS: 41.91 (+10.32%) + out_tf32(mma2x4+...+stage3+dsmem): ['151.794143', '4.5965395 '], time:6.930255ms, swizzle: NOOP, TFLOPS: 39.66 + out_tf32(mma2x4+...+stage2+dsmem): ['151.794143', '4.5965395 '], time:6.577444ms, swizzle: NOOP, TFLOPS: 41.79 +out_tf32(mma2x4+...+stage3+swizzle): ['151.794143', '4.5965395 '], time:7.675647ms, swizzle: 1024, TFLOPS: 35.81 +out_tf32(mma2x4+...+stage2+swizzle): ['151.794143', '4.5965395 '], time:6.308770ms, swizzle: 1024, TFLOPS: 43.57 (+3.97%) + out_tf32(...+stage3+dsmem+swizzle): ['151.794143', '4.5965395 '], time:6.884336ms, swizzle: 1024, TFLOPS: 39.93 + out_tf32(...+stage2+dsmem+swizzle): ['151.794143', '4.5965395 '], time:6.305503ms, swizzle: 1024, TFLOPS: 43.59 (+0.05%) + out_tf32(cublas+tf32): ['151.794143', '4.5965395 '], time:5.328726ms, swizzle: NOOP, TFLOPS: 51.58 (+18.33%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=4096, N=8192, K=8192 - out_f32(naive): ['47.3047409', '96.8061981'], time:275.4955ms, swizzle: NOOP, TFLOPS: 2.00 (+0.00%) - out_f32x4(t8x8sk): ['47.3047409', '96.8061981'], time:20.12772ms, swizzle: NOOP, TFLOPS: 27.31 (+1268.74%) - out_f32x4(t8x8bcf): ['47.3047409', '96.8061981'], time:18.67715ms, swizzle: NOOP, TFLOPS: 29.43 (+7.77%) - out_f32x4(t8x8dbuf): ['47.3047409', '96.8061981'], time:20.63833ms, swizzle: NOOP, TFLOPS: 26.64 - out_f32(cublas): ['47.3047409', '96.8061981'], time:15.28421ms, swizzle: NOOP, TFLOPS: 35.97 (+22.20%) - out_f32_th: ['47.3047409', '96.8061981'], time:14.75317ms, swizzle: NOOP, TFLOPS: 37.26 (+3.60%) + out_f32x4(t8x8sk): ['118.518661', '44.2836265'], time:20.20986ms, swizzle: NOOP, TFLOPS: 27.20 (+0.00%) + out_f32x4(t8x8bcf): ['118.518661', '44.2836265'], time:18.03719ms, swizzle: NOOP, TFLOPS: 30.48 (+12.05%) + out_f32x4(t8x8dbuf): ['118.518661', '44.2836265'], time:18.61379ms, swizzle: NOOP, TFLOPS: 29.53 + out_f32(cublas): ['118.518661', '44.2836265'], time:15.54746ms, swizzle: NOOP, TFLOPS: 35.36 (+16.01%) + out_f32_th: ['118.518661', '44.2836265'], time:15.30375ms, swizzle: NOOP, TFLOPS: 35.92 (+1.59%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['47.3002510', '96.8180007'], time:15.59357ms, swizzle: NOOP, TFLOPS: 35.26 - out_tf32(mma2x4+warp2x4+stage2): ['47.3002510', '96.8180007'], time:13.13108ms, swizzle: NOOP, TFLOPS: 41.87 (+12.35%) - out_tf32(mma2x4+...+stage3+dsmem): ['47.3002510', '96.8180007'], time:13.90327ms, swizzle: NOOP, TFLOPS: 39.54 - out_tf32(mma2x4+...+stage2+dsmem): ['47.3002510', '96.8180007'], time:13.19309ms, swizzle: NOOP, TFLOPS: 41.67 -out_tf32(mma2x4+...+stage3+swizzle): ['47.3002510', '96.8180007'], time:15.48467ms, swizzle: 2048, TFLOPS: 35.50 -out_tf32(mma2x4+...+stage2+swizzle): ['47.3002510', '96.8180007'], time:12.81875ms, swizzle: 2048, TFLOPS: 42.89 (+2.44%) - out_tf32(...+stage3+dsmem+swizzle): ['47.3002510', '96.8180007'], time:13.90204ms, swizzle: 2048, TFLOPS: 39.54 - out_tf32(...+stage2+dsmem+swizzle): ['47.3002510', '96.8180007'], time:12.77471ms, swizzle: 2048, TFLOPS: 43.03 (+0.34%) - out_tf32(cublas+tf32): ['47.3002510', '96.8180007'], time:10.32357ms, swizzle: NOOP, TFLOPS: 53.25 (+23.74%) + out_tf32(mma2x4+warp2x4+stage3): ['118.526184', '44.2716636'], time:15.66731ms, swizzle: NOOP, TFLOPS: 35.09 + out_tf32(mma2x4+warp2x4+stage2): ['118.526184', '44.2716636'], time:13.19141ms, swizzle: NOOP, TFLOPS: 41.68 (+16.01%) + out_tf32(mma2x4+...+stage3+dsmem): ['118.526184', '44.2716636'], time:13.83848ms, swizzle: NOOP, TFLOPS: 39.73 + out_tf32(mma2x4+...+stage2+dsmem): ['118.526184', '44.2716636'], time:13.15524ms, swizzle: NOOP, TFLOPS: 41.79 (+0.27%) +out_tf32(mma2x4+...+stage3+swizzle): ['118.526184', '44.2716636'], time:15.49148ms, swizzle: 1024, TFLOPS: 35.49 +out_tf32(mma2x4+...+stage2+swizzle): ['118.526184', '44.2716636'], time:12.80868ms, swizzle: 1024, TFLOPS: 42.92 (+2.71%) + out_tf32(...+stage3+dsmem+swizzle): ['118.526184', '44.2716636'], time:13.90929ms, swizzle: 1024, TFLOPS: 39.52 + out_tf32(...+stage2+dsmem+swizzle): ['118.526184', '44.2716636'], time:12.78388ms, swizzle: 1024, TFLOPS: 43.00 (+0.19%) + out_tf32(cublas+tf32): ['118.526184', '44.2716636'], time:10.33768ms, swizzle: NOOP, TFLOPS: 53.18 (+23.66%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=4096, N=16384, K=2048 - out_f32(naive): ['-17.835220', '0.19710006'], time:138.1242ms, swizzle: NOOP, TFLOPS: 1.99 (+0.00%) - out_f32x4(t8x8sk): ['-17.835220', '0.19710006'], time:10.01632ms, swizzle: NOOP, TFLOPS: 27.44 (+1278.99%) - out_f32x4(t8x8bcf): ['-17.835220', '0.19710006'], time:9.498941ms, swizzle: NOOP, TFLOPS: 28.94 (+5.45%) - out_f32x4(t8x8dbuf): ['-17.835220', '0.19710006'], time:9.595859ms, swizzle: NOOP, TFLOPS: 28.65 - out_f32(cublas): ['-17.835220', '0.19710006'], time:7.673835ms, swizzle: NOOP, TFLOPS: 35.82 (+23.78%) - out_f32_th: ['-17.835220', '0.19710006'], time:7.615864ms, swizzle: NOOP, TFLOPS: 36.09 (+0.76%) + out_f32x4(t8x8sk): ['70.5972366', '26.1622695'], time:9.941315ms, swizzle: NOOP, TFLOPS: 27.65 (+0.00%) + out_f32x4(t8x8bcf): ['70.5972366', '26.1622695'], time:9.267258ms, swizzle: NOOP, TFLOPS: 29.66 (+7.27%) + out_f32x4(t8x8dbuf): ['70.5972366', '26.1622695'], time:9.232449ms, swizzle: NOOP, TFLOPS: 29.77 (+0.38%) + out_f32(cublas): ['70.5972366', '26.1622695'], time:7.846927ms, swizzle: NOOP, TFLOPS: 35.03 (+17.66%) + out_f32_th: ['70.5972366', '26.1622695'], time:7.085800ms, swizzle: NOOP, TFLOPS: 38.79 (+10.74%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-17.849796', '0.19758633'], time:7.673478ms, swizzle: NOOP, TFLOPS: 35.82 - out_tf32(mma2x4+warp2x4+stage2): ['-17.849796', '0.19758633'], time:6.537437ms, swizzle: NOOP, TFLOPS: 42.05 (+16.50%) - out_tf32(mma2x4+...+stage3+dsmem): ['-17.849796', '0.19758633'], time:6.784737ms, swizzle: NOOP, TFLOPS: 40.51 - out_tf32(mma2x4+...+stage2+dsmem): ['-17.849796', '0.19758633'], time:6.545460ms, swizzle: NOOP, TFLOPS: 42.00 -out_tf32(mma2x4+...+stage3+swizzle): ['-17.849796', '0.19758633'], time:7.543981ms, swizzle: 4096, TFLOPS: 36.44 -out_tf32(mma2x4+...+stage2+swizzle): ['-17.849796', '0.19758633'], time:6.199836ms, swizzle: 4096, TFLOPS: 44.34 (+5.45%) - out_tf32(...+stage3+dsmem+swizzle): ['-17.849796', '0.19758633'], time:6.745231ms, swizzle: 4096, TFLOPS: 40.75 - out_tf32(...+stage2+dsmem+swizzle): ['-17.849796', '0.19758633'], time:6.165897ms, swizzle: 4096, TFLOPS: 44.58 (+0.55%) - out_tf32(cublas+tf32): ['-17.849796', '0.19758633'], time:5.148077ms, swizzle: NOOP, TFLOPS: 53.39 (+19.77%) + out_tf32(mma2x4+warp2x4+stage3): ['70.5943985', '26.1725273'], time:7.701039ms, swizzle: NOOP, TFLOPS: 35.69 + out_tf32(mma2x4+warp2x4+stage2): ['70.5943985', '26.1725273'], time:6.537389ms, swizzle: NOOP, TFLOPS: 42.05 (+8.39%) + out_tf32(mma2x4+...+stage3+dsmem): ['70.5943985', '26.1725273'], time:6.712508ms, swizzle: NOOP, TFLOPS: 40.95 + out_tf32(mma2x4+...+stage2+dsmem): ['70.5943985', '26.1725273'], time:6.550049ms, swizzle: NOOP, TFLOPS: 41.97 +out_tf32(mma2x4+...+stage3+swizzle): ['70.5943985', '26.1725273'], time:7.554650ms, swizzle: 2048, TFLOPS: 36.39 +out_tf32(mma2x4+...+stage2+swizzle): ['70.5943985', '26.1725273'], time:6.168079ms, swizzle: 2048, TFLOPS: 44.56 (+5.99%) + out_tf32(...+stage3+dsmem+swizzle): ['70.5943985', '26.1725273'], time:6.722187ms, swizzle: 2048, TFLOPS: 40.89 + out_tf32(...+stage2+dsmem+swizzle): ['70.5943985', '26.1725273'], time:6.171321ms, swizzle: 2048, TFLOPS: 44.54 + out_tf32(cublas+tf32): ['70.5943985', '26.1725273'], time:5.131006ms, swizzle: NOOP, TFLOPS: 53.57 (+20.21%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=4096, N=16384, K=4096 - out_f32(naive): ['-24.556724', '26.1026535'], time:275.8962ms, swizzle: NOOP, TFLOPS: 1.99 (+0.00%) - out_f32x4(t8x8sk): ['-24.556724', '26.1026535'], time:20.26476ms, swizzle: NOOP, TFLOPS: 27.13 (+1261.46%) - out_f32x4(t8x8bcf): ['-24.556724', '26.1026535'], time:19.30289ms, swizzle: NOOP, TFLOPS: 28.48 (+4.98%) - out_f32x4(t8x8dbuf): ['-24.556724', '26.1026535'], time:20.73652ms, swizzle: NOOP, TFLOPS: 26.51 - out_f32(cublas): ['-24.556724', '26.1026535'], time:14.44500ms, swizzle: NOOP, TFLOPS: 38.06 (+33.63%) - out_f32_th: ['-24.556724', '26.1026535'], time:14.17189ms, swizzle: NOOP, TFLOPS: 38.79 (+1.93%) + out_f32x4(t8x8sk): ['151.799118', '4.6021018 '], time:20.19996ms, swizzle: NOOP, TFLOPS: 27.22 (+0.00%) + out_f32x4(t8x8bcf): ['151.799118', '4.6021018 '], time:18.53487ms, swizzle: NOOP, TFLOPS: 29.66 (+8.98%) + out_f32x4(t8x8dbuf): ['151.799118', '4.6021018 '], time:18.93479ms, swizzle: NOOP, TFLOPS: 29.03 + out_f32(cublas): ['151.799118', '4.6021018 '], time:14.90321ms, swizzle: NOOP, TFLOPS: 36.89 (+24.37%) + out_f32_th: ['151.799118', '4.6021018 '], time:14.38026ms, swizzle: NOOP, TFLOPS: 38.23 (+3.64%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-24.539142', '26.0933208'], time:15.15815ms, swizzle: NOOP, TFLOPS: 36.27 - out_tf32(mma2x4+warp2x4+stage2): ['-24.539142', '26.0933208'], time:13.11252ms, swizzle: NOOP, TFLOPS: 41.93 (+8.08%) - out_tf32(mma2x4+...+stage3+dsmem): ['-24.539142', '26.0933208'], time:13.83591ms, swizzle: NOOP, TFLOPS: 39.73 - out_tf32(mma2x4+...+stage2+dsmem): ['-24.539142', '26.0933208'], time:13.17880ms, swizzle: NOOP, TFLOPS: 41.72 -out_tf32(mma2x4+...+stage3+swizzle): ['-24.539142', '26.0933208'], time:15.04755ms, swizzle: 4096, TFLOPS: 36.53 -out_tf32(mma2x4+...+stage2+swizzle): ['-24.539142', '26.0933208'], time:12.35028ms, swizzle: 4096, TFLOPS: 44.51 (+6.17%) - out_tf32(...+stage3+dsmem+swizzle): ['-24.539142', '26.0933208'], time:13.40752ms, swizzle: 4096, TFLOPS: 41.00 - out_tf32(...+stage2+dsmem+swizzle): ['-24.539142', '26.0933208'], time:12.35129ms, swizzle: 4096, TFLOPS: 44.51 - out_tf32(cublas+tf32): ['-24.539142', '26.0933208'], time:10.01133ms, swizzle: NOOP, TFLOPS: 54.91 (+23.36%) + out_tf32(mma2x4+warp2x4+stage3): ['151.794143', '4.5965395 '], time:15.34090ms, swizzle: NOOP, TFLOPS: 35.84 + out_tf32(mma2x4+warp2x4+stage2): ['151.794143', '4.5965395 '], time:12.95042ms, swizzle: NOOP, TFLOPS: 42.45 (+11.04%) + out_tf32(mma2x4+...+stage3+dsmem): ['151.794143', '4.5965395 '], time:13.73360ms, swizzle: NOOP, TFLOPS: 40.03 + out_tf32(mma2x4+...+stage2+dsmem): ['151.794143', '4.5965395 '], time:12.93442ms, swizzle: NOOP, TFLOPS: 42.50 (+0.12%) +out_tf32(mma2x4+...+stage3+swizzle): ['151.794143', '4.5965395 '], time:15.03224ms, swizzle: 2048, TFLOPS: 36.57 +out_tf32(mma2x4+...+stage2+swizzle): ['151.794143', '4.5965395 '], time:12.34993ms, swizzle: 2048, TFLOPS: 44.51 (+4.73%) + out_tf32(...+stage3+dsmem+swizzle): ['151.794143', '4.5965395 '], time:13.40029ms, swizzle: 2048, TFLOPS: 41.03 + out_tf32(...+stage2+dsmem+swizzle): ['151.794143', '4.5965395 '], time:12.32724ms, swizzle: 2048, TFLOPS: 44.60 (+0.18%) + out_tf32(cublas+tf32): ['151.794143', '4.5965395 '], time:9.960341ms, swizzle: NOOP, TFLOPS: 55.19 (+23.76%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=4096, N=16384, K=8192 - out_f32(naive): ['47.3072891', '96.7974395'], time:551.3394ms, swizzle: NOOP, TFLOPS: 1.99 (+0.00%) - out_f32x4(t8x8sk): ['47.3072891', '96.7974395'], time:41.22277ms, swizzle: NOOP, TFLOPS: 26.67 (+1237.46%) - out_f32x4(t8x8bcf): ['47.3072891', '96.7974395'], time:39.89914ms, swizzle: NOOP, TFLOPS: 27.56 (+3.32%) - out_f32x4(t8x8dbuf): ['47.3072891', '96.7974395'], time:40.29097ms, swizzle: NOOP, TFLOPS: 27.29 - out_f32(cublas): ['47.3072891', '96.7974395'], time:29.63916ms, swizzle: NOOP, TFLOPS: 37.10 (+34.62%) - out_f32_th: ['47.3072891', '96.7974395'], time:29.41981ms, swizzle: NOOP, TFLOPS: 37.37 (+0.75%) + out_f32x4(t8x8sk): ['118.513626', '44.2889137'], time:40.22870ms, swizzle: NOOP, TFLOPS: 27.33 (+0.00%) + out_f32x4(t8x8bcf): ['118.513626', '44.2889137'], time:39.04280ms, swizzle: NOOP, TFLOPS: 28.16 (+3.04%) + out_f32x4(t8x8dbuf): ['118.513626', '44.2889137'], time:39.80977ms, swizzle: NOOP, TFLOPS: 27.62 + out_f32(cublas): ['118.513626', '44.2889137'], time:28.38425ms, swizzle: NOOP, TFLOPS: 38.74 (+37.55%) + out_f32_th: ['118.513626', '44.2889137'], time:29.08875ms, swizzle: NOOP, TFLOPS: 37.80 --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['47.3002510', '96.8180007'], time:30.10461ms, swizzle: NOOP, TFLOPS: 36.52 - out_tf32(mma2x4+warp2x4+stage2): ['47.3002510', '96.8180007'], time:26.53632ms, swizzle: NOOP, TFLOPS: 41.43 (+10.87%) - out_tf32(mma2x4+...+stage3+dsmem): ['47.3002510', '96.8180007'], time:27.68478ms, swizzle: NOOP, TFLOPS: 39.72 - out_tf32(mma2x4+...+stage2+dsmem): ['47.3002510', '96.8180007'], time:26.57709ms, swizzle: NOOP, TFLOPS: 41.37 -out_tf32(mma2x4+...+stage3+swizzle): ['47.3002510', '96.8180007'], time:30.01861ms, swizzle: 4096, TFLOPS: 36.63 -out_tf32(mma2x4+...+stage2+swizzle): ['47.3002510', '96.8180007'], time:25.24836ms, swizzle: 4096, TFLOPS: 43.55 (+5.10%) - out_tf32(...+stage3+dsmem+swizzle): ['47.3002510', '96.8180007'], time:27.08832ms, swizzle: 4096, TFLOPS: 40.59 - out_tf32(...+stage2+dsmem+swizzle): ['47.3002510', '96.8180007'], time:25.11584ms, swizzle: 4096, TFLOPS: 43.78 (+0.53%) - out_tf32(cublas+tf32): ['47.3002510', '96.8180007'], time:19.69352ms, swizzle: NOOP, TFLOPS: 55.83 (+27.53%) + out_tf32(mma2x4+warp2x4+stage3): ['118.526184', '44.2716636'], time:30.07037ms, swizzle: NOOP, TFLOPS: 36.56 + out_tf32(mma2x4+warp2x4+stage2): ['118.526184', '44.2716636'], time:26.02388ms, swizzle: NOOP, TFLOPS: 42.25 (+9.07%) + out_tf32(mma2x4+...+stage3+dsmem): ['118.526184', '44.2716636'], time:27.45041ms, swizzle: NOOP, TFLOPS: 40.05 + out_tf32(mma2x4+...+stage2+dsmem): ['118.526184', '44.2716636'], time:26.32236ms, swizzle: NOOP, TFLOPS: 41.77 +out_tf32(mma2x4+...+stage3+swizzle): ['118.526184', '44.2716636'], time:30.09891ms, swizzle: 2048, TFLOPS: 36.53 +out_tf32(mma2x4+...+stage2+swizzle): ['118.526184', '44.2716636'], time:24.76131ms, swizzle: 2048, TFLOPS: 44.40 (+5.10%) + out_tf32(...+stage3+dsmem+swizzle): ['118.526184', '44.2716636'], time:26.82106ms, swizzle: 2048, TFLOPS: 40.99 + out_tf32(...+stage2+dsmem+swizzle): ['118.526184', '44.2716636'], time:24.67982ms, swizzle: 2048, TFLOPS: 44.55 (+0.33%) + out_tf32(cublas+tf32): ['118.526184', '44.2716636'], time:19.58444ms, swizzle: NOOP, TFLOPS: 56.14 (+26.02%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=8192, N=4096, K=2048 - out_f32(naive): ['-17.849985', '0.19760081'], time:40.67556ms, swizzle: NOOP, TFLOPS: 3.38 (+0.00%) - out_f32x4(t8x8sk): ['-17.849985', '0.19760081'], time:4.759192ms, swizzle: NOOP, TFLOPS: 28.88 (+754.67%) - out_f32x4(t8x8bcf): ['-17.849985', '0.19760081'], time:4.249489ms, swizzle: NOOP, TFLOPS: 32.34 (+11.99%) - out_f32x4(t8x8dbuf): ['-17.849985', '0.19760081'], time:3.854548ms, swizzle: NOOP, TFLOPS: 35.66 (+10.25%) - out_f32(cublas): ['-17.849985', '0.19760081'], time:4.017460ms, swizzle: NOOP, TFLOPS: 34.21 - out_f32_th: ['-17.849985', '0.19760081'], time:3.689861ms, swizzle: NOOP, TFLOPS: 37.25 (+4.46%) + out_f32x4(t8x8sk): ['70.5949554', '26.1727619'], time:4.644012ms, swizzle: NOOP, TFLOPS: 29.59 (+0.00%) + out_f32x4(t8x8bcf): ['70.5949554', '26.1727619'], time:4.165029ms, swizzle: NOOP, TFLOPS: 33.00 (+11.50%) + out_f32x4(t8x8dbuf): ['70.5949554', '26.1727619'], time:3.532195ms, swizzle: NOOP, TFLOPS: 38.91 (+17.92%) + out_f32(cublas): ['70.5949554', '26.1727619'], time:4.056715ms, swizzle: NOOP, TFLOPS: 33.88 + out_f32_th: ['70.5949554', '26.1727619'], time:3.668260ms, swizzle: NOOP, TFLOPS: 37.47 --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-17.849796', '0.19758633'], time:4.017901ms, swizzle: NOOP, TFLOPS: 34.21 - out_tf32(mma2x4+warp2x4+stage2): ['-17.849796', '0.19758633'], time:3.204596ms, swizzle: NOOP, TFLOPS: 42.89 (+15.14%) - out_tf32(mma2x4+...+stage3+dsmem): ['-17.849796', '0.19758633'], time:3.486108ms, swizzle: NOOP, TFLOPS: 39.42 - out_tf32(mma2x4+...+stage2+dsmem): ['-17.849796', '0.19758633'], time:3.196144ms, swizzle: NOOP, TFLOPS: 43.00 (+0.26%) -out_tf32(mma2x4+...+stage3+swizzle): ['-17.849796', '0.19758633'], time:3.813004ms, swizzle: 1024, TFLOPS: 36.04 -out_tf32(mma2x4+...+stage2+swizzle): ['-17.849796', '0.19758633'], time:3.110313ms, swizzle: 1024, TFLOPS: 44.19 (+2.76%) - out_tf32(...+stage3+dsmem+swizzle): ['-17.849796', '0.19758633'], time:3.408885ms, swizzle: 1024, TFLOPS: 40.32 - out_tf32(...+stage2+dsmem+swizzle): ['-17.849796', '0.19758633'], time:3.111791ms, swizzle: 1024, TFLOPS: 44.17 - out_tf32(cublas+tf32): ['-17.849796', '0.19758633'], time:2.788853ms, swizzle: NOOP, TFLOPS: 49.28 (+11.53%) + out_tf32(mma2x4+warp2x4+stage3): ['70.5943985', '26.1725273'], time:4.008388ms, swizzle: NOOP, TFLOPS: 34.29 + out_tf32(mma2x4+warp2x4+stage2): ['70.5943985', '26.1725273'], time:3.218698ms, swizzle: NOOP, TFLOPS: 42.70 (+9.74%) + out_tf32(mma2x4+...+stage3+dsmem): ['70.5943985', '26.1725273'], time:3.489041ms, swizzle: NOOP, TFLOPS: 39.39 + out_tf32(mma2x4+...+stage2+dsmem): ['70.5943985', '26.1725273'], time:3.196096ms, swizzle: NOOP, TFLOPS: 43.00 (+0.71%) +out_tf32(mma2x4+...+stage3+swizzle): ['70.5943985', '26.1725273'], time:3.782248ms, swizzle: 512 , TFLOPS: 36.34 +out_tf32(mma2x4+...+stage2+swizzle): ['70.5943985', '26.1725273'], time:3.096580ms, swizzle: 512 , TFLOPS: 44.38 (+3.21%) + out_tf32(...+stage3+dsmem+swizzle): ['70.5943985', '26.1725273'], time:3.394317ms, swizzle: 512 , TFLOPS: 40.49 + out_tf32(...+stage2+dsmem+swizzle): ['70.5943985', '26.1725273'], time:3.095269ms, swizzle: 512 , TFLOPS: 44.40 (+0.04%) + out_tf32(cublas+tf32): ['70.5943985', '26.1725273'], time:11.76311ms, swizzle: NOOP, TFLOPS: 11.68 ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=8192, N=4096, K=4096 - out_f32(naive): ['-24.539411', '26.0933971'], time:81.24710ms, swizzle: NOOP, TFLOPS: 3.38 (+0.00%) - out_f32x4(t8x8sk): ['-24.539411', '26.0933971'], time:9.419822ms, swizzle: NOOP, TFLOPS: 29.18 (+762.51%) - out_f32x4(t8x8bcf): ['-24.539411', '26.0933971'], time:8.507907ms, swizzle: NOOP, TFLOPS: 32.31 (+10.72%) - out_f32x4(t8x8dbuf): ['-24.539411', '26.0933971'], time:7.820093ms, swizzle: NOOP, TFLOPS: 35.15 (+8.80%) - out_f32(cublas): ['-24.539411', '26.0933971'], time:7.591652ms, swizzle: NOOP, TFLOPS: 36.21 (+3.01%) - out_f32_th: ['-24.539411', '26.0933971'], time:7.503700ms, swizzle: NOOP, TFLOPS: 36.63 (+1.17%) + out_f32x4(t8x8sk): ['151.796371', '4.59689951'], time:9.283566ms, swizzle: NOOP, TFLOPS: 29.61 (+0.00%) + out_f32x4(t8x8bcf): ['151.796371', '4.59689951'], time:8.359241ms, swizzle: NOOP, TFLOPS: 32.88 (+11.06%) + out_f32x4(t8x8dbuf): ['151.796371', '4.59689951'], time:7.493996ms, swizzle: NOOP, TFLOPS: 36.68 (+11.55%) + out_f32(cublas): ['151.796371', '4.59689951'], time:7.483124ms, swizzle: NOOP, TFLOPS: 36.73 (+0.15%) + out_f32_th: ['151.796371', '4.59689951'], time:7.139444ms, swizzle: NOOP, TFLOPS: 38.50 (+4.81%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-24.539142', '26.0933208'], time:7.938122ms, swizzle: NOOP, TFLOPS: 34.63 - out_tf32(mma2x4+warp2x4+stage2): ['-24.539142', '26.0933208'], time:6.458127ms, swizzle: NOOP, TFLOPS: 42.56 (+16.19%) - out_tf32(mma2x4+...+stage3+dsmem): ['-24.539142', '26.0933208'], time:6.996679ms, swizzle: NOOP, TFLOPS: 39.29 - out_tf32(mma2x4+...+stage2+dsmem): ['-24.539142', '26.0933208'], time:6.418275ms, swizzle: NOOP, TFLOPS: 42.83 (+0.62%) -out_tf32(mma2x4+...+stage3+swizzle): ['-24.539142', '26.0933208'], time:7.730925ms, swizzle: 1024, TFLOPS: 35.56 -out_tf32(mma2x4+...+stage2+swizzle): ['-24.539142', '26.0933208'], time:6.363022ms, swizzle: 1024, TFLOPS: 43.20 (+0.87%) - out_tf32(...+stage3+dsmem+swizzle): ['-24.539142', '26.0933208'], time:6.948149ms, swizzle: 1024, TFLOPS: 39.56 - out_tf32(...+stage2+dsmem+swizzle): ['-24.539142', '26.0933208'], time:6.365275ms, swizzle: 1024, TFLOPS: 43.18 - out_tf32(cublas+tf32): ['-24.539142', '26.0933208'], time:5.335128ms, swizzle: NOOP, TFLOPS: 51.52 (+19.27%) + out_tf32(mma2x4+warp2x4+stage3): ['151.794143', '4.5965395 '], time:7.942914ms, swizzle: NOOP, TFLOPS: 34.61 + out_tf32(mma2x4+warp2x4+stage2): ['151.794143', '4.5965395 '], time:6.454420ms, swizzle: NOOP, TFLOPS: 42.59 (+10.61%) + out_tf32(mma2x4+...+stage3+dsmem): ['151.794143', '4.5965395 '], time:7.018256ms, swizzle: NOOP, TFLOPS: 39.17 + out_tf32(mma2x4+...+stage2+dsmem): ['151.794143', '4.5965395 '], time:6.443977ms, swizzle: NOOP, TFLOPS: 42.66 (+0.16%) +out_tf32(mma2x4+...+stage3+swizzle): ['151.794143', '4.5965395 '], time:7.723641ms, swizzle: 512 , TFLOPS: 35.59 +out_tf32(mma2x4+...+stage2+swizzle): ['151.794143', '4.5965395 '], time:6.369042ms, swizzle: 512 , TFLOPS: 43.16 (+1.18%) + out_tf32(...+stage3+dsmem+swizzle): ['151.794143', '4.5965395 '], time:6.931543ms, swizzle: 512 , TFLOPS: 39.66 + out_tf32(...+stage2+dsmem+swizzle): ['151.794143', '4.5965395 '], time:6.361842ms, swizzle: 512 , TFLOPS: 43.21 (+0.11%) + out_tf32(cublas+tf32): ['151.794143', '4.5965395 '], time:5.284237ms, swizzle: NOOP, TFLOPS: 52.02 (+20.39%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=8192, N=4096, K=8192 - out_f32(naive): ['47.2999496', '96.8197784'], time:247.6329ms, swizzle: NOOP, TFLOPS: 2.22 (+0.00%) - out_f32x4(t8x8sk): ['47.2999496', '96.8197784'], time:19.65559ms, swizzle: NOOP, TFLOPS: 27.97 (+1159.86%) - out_f32x4(t8x8bcf): ['47.2999496', '96.8197784'], time:17.52810ms, swizzle: NOOP, TFLOPS: 31.36 (+12.14%) - out_f32x4(t8x8dbuf): ['47.2999496', '96.8197784'], time:18.90896ms, swizzle: NOOP, TFLOPS: 29.07 - out_f32(cublas): ['47.2999496', '96.8197784'], time:15.03305ms, swizzle: NOOP, TFLOPS: 36.57 (+16.60%) - out_f32_th: ['47.2999496', '96.8197784'], time:14.72257ms, swizzle: NOOP, TFLOPS: 37.34 (+2.11%) + out_f32x4(t8x8sk): ['118.532104', '44.2729606'], time:19.66500ms, swizzle: NOOP, TFLOPS: 27.96 (+0.00%) + out_f32x4(t8x8bcf): ['118.532104', '44.2729606'], time:17.24970ms, swizzle: NOOP, TFLOPS: 31.87 (+14.00%) + out_f32x4(t8x8dbuf): ['118.532104', '44.2729606'], time:17.30856ms, swizzle: NOOP, TFLOPS: 31.76 + out_f32(cublas): ['118.532104', '44.2729606'], time:15.01247ms, swizzle: NOOP, TFLOPS: 36.62 (+14.90%) + out_f32_th: ['118.532104', '44.2729606'], time:14.77088ms, swizzle: NOOP, TFLOPS: 37.22 (+1.64%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['47.3002510', '96.8180007'], time:15.59674ms, swizzle: NOOP, TFLOPS: 35.25 - out_tf32(mma2x4+warp2x4+stage2): ['47.3002510', '96.8180007'], time:13.05602ms, swizzle: NOOP, TFLOPS: 42.11 (+12.76%) - out_tf32(mma2x4+...+stage3+dsmem): ['47.3002510', '96.8180007'], time:13.85312ms, swizzle: NOOP, TFLOPS: 39.68 - out_tf32(mma2x4+...+stage2+dsmem): ['47.3002510', '96.8180007'], time:13.07342ms, swizzle: NOOP, TFLOPS: 42.05 -out_tf32(mma2x4+...+stage3+swizzle): ['47.3002510', '96.8180007'], time:15.54510ms, swizzle: 1024, TFLOPS: 35.37 -out_tf32(mma2x4+...+stage2+swizzle): ['47.3002510', '96.8180007'], time:12.80124ms, swizzle: 1024, TFLOPS: 42.95 (+1.99%) - out_tf32(...+stage3+dsmem+swizzle): ['47.3002510', '96.8180007'], time:13.91153ms, swizzle: 1024, TFLOPS: 39.52 - out_tf32(...+stage2+dsmem+swizzle): ['47.3002510', '96.8180007'], time:12.78195ms, swizzle: 1024, TFLOPS: 43.01 (+0.15%) - out_tf32(cublas+tf32): ['47.3002510', '96.8180007'], time:10.32341ms, swizzle: NOOP, TFLOPS: 53.25 (+23.82%) + out_tf32(mma2x4+warp2x4+stage3): ['118.526184', '44.2716636'], time:15.61958ms, swizzle: NOOP, TFLOPS: 35.20 + out_tf32(mma2x4+warp2x4+stage2): ['118.526184', '44.2716636'], time:13.11204ms, swizzle: NOOP, TFLOPS: 41.93 (+12.65%) + out_tf32(mma2x4+...+stage3+dsmem): ['118.526184', '44.2716636'], time:13.86370ms, swizzle: NOOP, TFLOPS: 39.65 + out_tf32(mma2x4+...+stage2+dsmem): ['118.526184', '44.2716636'], time:13.01887ms, swizzle: NOOP, TFLOPS: 42.23 (+0.72%) +out_tf32(mma2x4+...+stage3+swizzle): ['118.526184', '44.2716636'], time:15.49036ms, swizzle: 512 , TFLOPS: 35.49 +out_tf32(mma2x4+...+stage2+swizzle): ['118.526184', '44.2716636'], time:12.93551ms, swizzle: 512 , TFLOPS: 42.50 (+0.64%) + out_tf32(...+stage3+dsmem+swizzle): ['118.526184', '44.2716636'], time:13.91084ms, swizzle: 512 , TFLOPS: 39.52 + out_tf32(...+stage2+dsmem+swizzle): ['118.526184', '44.2716636'], time:12.87522ms, swizzle: 512 , TFLOPS: 42.70 (+0.47%) + out_tf32(cublas+tf32): ['118.526184', '44.2716636'], time:10.32779ms, swizzle: NOOP, TFLOPS: 53.23 (+24.67%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=8192, N=8192, K=2048 - out_f32(naive): ['-17.849985', '0.19760081'], time:81.26928ms, swizzle: NOOP, TFLOPS: 3.38 (+0.00%) - out_f32x4(t8x8sk): ['-17.849985', '0.19760081'], time:9.237766ms, swizzle: NOOP, TFLOPS: 29.76 (+779.75%) - out_f32x4(t8x8bcf): ['-17.849985', '0.19760081'], time:8.254611ms, swizzle: NOOP, TFLOPS: 33.30 (+11.91%) - out_f32x4(t8x8dbuf): ['-17.849985', '0.19760081'], time:7.502532ms, swizzle: NOOP, TFLOPS: 36.64 (+10.02%) - out_f32(cublas): ['-17.849985', '0.19760081'], time:8.107531ms, swizzle: NOOP, TFLOPS: 33.90 - out_f32_th: ['-17.849985', '0.19760081'], time:7.478880ms, swizzle: NOOP, TFLOPS: 36.75 (+0.32%) + out_f32x4(t8x8sk): ['70.5949554', '26.1727619'], time:9.005260ms, swizzle: NOOP, TFLOPS: 30.52 (+0.00%) + out_f32x4(t8x8bcf): ['70.5949554', '26.1727619'], time:8.109664ms, swizzle: NOOP, TFLOPS: 33.90 (+11.04%) + out_f32x4(t8x8dbuf): ['70.5949554', '26.1727619'], time:7.237076ms, swizzle: NOOP, TFLOPS: 37.98 (+12.06%) + out_f32(cublas): ['70.5949554', '26.1727619'], time:7.283616ms, swizzle: NOOP, TFLOPS: 37.74 + out_f32_th: ['70.5949554', '26.1727619'], time:7.025599ms, swizzle: NOOP, TFLOPS: 39.13 (+3.01%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-17.849796', '0.19758633'], time:7.743370ms, swizzle: NOOP, TFLOPS: 35.50 - out_tf32(mma2x4+warp2x4+stage2): ['-17.849796', '0.19758633'], time:6.154835ms, swizzle: NOOP, TFLOPS: 44.66 (+21.51%) - out_tf32(mma2x4+...+stage3+dsmem): ['-17.849796', '0.19758633'], time:6.668007ms, swizzle: NOOP, TFLOPS: 41.22 - out_tf32(mma2x4+...+stage2+dsmem): ['-17.849796', '0.19758633'], time:6.114292ms, swizzle: NOOP, TFLOPS: 44.96 (+0.66%) -out_tf32(mma2x4+...+stage3+swizzle): ['-17.849796', '0.19758633'], time:7.382285ms, swizzle: 2048, TFLOPS: 37.23 -out_tf32(mma2x4+...+stage2+swizzle): ['-17.849796', '0.19758633'], time:6.063973ms, swizzle: 2048, TFLOPS: 45.33 (+0.83%) - out_tf32(...+stage3+dsmem+swizzle): ['-17.849796', '0.19758633'], time:6.617772ms, swizzle: 2048, TFLOPS: 41.54 - out_tf32(...+stage2+dsmem+swizzle): ['-17.849796', '0.19758633'], time:6.061935ms, swizzle: 2048, TFLOPS: 45.34 (+0.03%) - out_tf32(cublas+tf32): ['-17.849796', '0.19758633'], time:5.143392ms, swizzle: NOOP, TFLOPS: 53.44 (+17.86%) + out_tf32(mma2x4+warp2x4+stage3): ['70.5943985', '26.1725273'], time:7.638692ms, swizzle: NOOP, TFLOPS: 35.98 + out_tf32(mma2x4+warp2x4+stage2): ['70.5943985', '26.1725273'], time:6.153583ms, swizzle: NOOP, TFLOPS: 44.67 (+14.17%) + out_tf32(mma2x4+...+stage3+dsmem): ['70.5943985', '26.1725273'], time:6.675100ms, swizzle: NOOP, TFLOPS: 41.18 + out_tf32(mma2x4+...+stage2+dsmem): ['70.5943985', '26.1725273'], time:6.140279ms, swizzle: NOOP, TFLOPS: 44.77 (+0.22%) +out_tf32(mma2x4+...+stage3+swizzle): ['70.5943985', '26.1725273'], time:7.350254ms, swizzle: 1024, TFLOPS: 37.40 +out_tf32(mma2x4+...+stage2+swizzle): ['70.5943985', '26.1725273'], time:6.009721ms, swizzle: 1024, TFLOPS: 45.74 (+2.17%) + out_tf32(...+stage3+dsmem+swizzle): ['70.5943985', '26.1725273'], time:6.560659ms, swizzle: 1024, TFLOPS: 41.90 + out_tf32(...+stage2+dsmem+swizzle): ['70.5943985', '26.1725273'], time:6.008577ms, swizzle: 1024, TFLOPS: 45.75 (+0.02%) + out_tf32(cublas+tf32): ['70.5943985', '26.1725273'], time:5.121445ms, swizzle: NOOP, TFLOPS: 53.67 (+17.32%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=8192, N=8192, K=4096 - out_f32(naive): ['-24.539411', '26.0933971'], time:268.3647ms, swizzle: NOOP, TFLOPS: 2.05 (+0.00%) - out_f32x4(t8x8sk): ['-24.539411', '26.0933971'], time:19.59303ms, swizzle: NOOP, TFLOPS: 28.06 (+1269.69%) - out_f32x4(t8x8bcf): ['-24.539411', '26.0933971'], time:17.70466ms, swizzle: NOOP, TFLOPS: 31.05 (+10.67%) - out_f32x4(t8x8dbuf): ['-24.539411', '26.0933971'], time:19.52338ms, swizzle: NOOP, TFLOPS: 28.16 - out_f32(cublas): ['-24.539411', '26.0933971'], time:14.43643ms, swizzle: NOOP, TFLOPS: 38.08 (+22.64%) - out_f32_th: ['-24.539411', '26.0933971'], time:14.19519ms, swizzle: NOOP, TFLOPS: 38.73 (+1.70%) + out_f32x4(t8x8sk): ['151.796371', '4.59689951'], time:19.40293ms, swizzle: NOOP, TFLOPS: 28.33 (+0.00%) + out_f32x4(t8x8bcf): ['151.796371', '4.59689951'], time:17.21770ms, swizzle: NOOP, TFLOPS: 31.93 (+12.69%) + out_f32x4(t8x8dbuf): ['151.796371', '4.59689951'], time:17.95308ms, swizzle: NOOP, TFLOPS: 30.62 + out_f32(cublas): ['151.796371', '4.59689951'], time:14.42518ms, swizzle: NOOP, TFLOPS: 38.11 (+19.36%) + out_f32_th: ['151.796371', '4.59689951'], time:14.29438ms, swizzle: NOOP, TFLOPS: 38.46 (+0.92%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-24.539142', '26.0933208'], time:14.91366ms, swizzle: NOOP, TFLOPS: 36.86 - out_tf32(mma2x4+warp2x4+stage2): ['-24.539142', '26.0933208'], time:12.80041ms, swizzle: NOOP, TFLOPS: 42.95 (+10.90%) - out_tf32(mma2x4+...+stage3+dsmem): ['-24.539142', '26.0933208'], time:13.52418ms, swizzle: NOOP, TFLOPS: 40.65 - out_tf32(mma2x4+...+stage2+dsmem): ['-24.539142', '26.0933208'], time:12.89166ms, swizzle: NOOP, TFLOPS: 42.64 -out_tf32(mma2x4+...+stage3+swizzle): ['-24.539142', '26.0933208'], time:14.86917ms, swizzle: 2048, TFLOPS: 36.97 -out_tf32(mma2x4+...+stage2+swizzle): ['-24.539142', '26.0933208'], time:12.12794ms, swizzle: 2048, TFLOPS: 45.33 (+5.54%) - out_tf32(...+stage3+dsmem+swizzle): ['-24.539142', '26.0933208'], time:13.22101ms, swizzle: 2048, TFLOPS: 41.58 - out_tf32(...+stage2+dsmem+swizzle): ['-24.539142', '26.0933208'], time:12.12646ms, swizzle: 2048, TFLOPS: 45.34 (+0.01%) - out_tf32(cublas+tf32): ['-24.539142', '26.0933208'], time:9.958040ms, swizzle: NOOP, TFLOPS: 55.21 (+21.78%) + out_tf32(mma2x4+warp2x4+stage3): ['151.794143', '4.5965395 '], time:14.90476ms, swizzle: NOOP, TFLOPS: 36.88 + out_tf32(mma2x4+warp2x4+stage2): ['151.794143', '4.5965395 '], time:12.51502ms, swizzle: NOOP, TFLOPS: 43.93 (+14.22%) + out_tf32(mma2x4+...+stage3+dsmem): ['151.794143', '4.5965395 '], time:13.19789ms, swizzle: NOOP, TFLOPS: 41.65 + out_tf32(mma2x4+...+stage2+dsmem): ['151.794143', '4.5965395 '], time:12.53654ms, swizzle: NOOP, TFLOPS: 43.85 +out_tf32(mma2x4+...+stage3+swizzle): ['151.794143', '4.5965395 '], time:14.80431ms, swizzle: 1024, TFLOPS: 37.13 +out_tf32(mma2x4+...+stage2+swizzle): ['151.794143', '4.5965395 '], time:12.12592ms, swizzle: 1024, TFLOPS: 45.34 (+3.21%) + out_tf32(...+stage3+dsmem+swizzle): ['151.794143', '4.5965395 '], time:13.21063ms, swizzle: 1024, TFLOPS: 41.61 + out_tf32(...+stage2+dsmem+swizzle): ['151.794143', '4.5965395 '], time:12.12511ms, swizzle: 1024, TFLOPS: 45.34 (+0.01%) + out_tf32(cublas+tf32): ['151.794143', '4.5965395 '], time:10.02106ms, swizzle: NOOP, TFLOPS: 54.86 (+21.00%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=8192, N=8192, K=8192 - out_f32(naive): ['47.2999496', '96.8197784'], time:550.5259ms, swizzle: NOOP, TFLOPS: 2.00 (+0.00%) - out_f32x4(t8x8sk): ['47.2999496', '96.8197784'], time:39.90060ms, swizzle: NOOP, TFLOPS: 27.56 (+1279.74%) - out_f32x4(t8x8bcf): ['47.2999496', '96.8197784'], time:36.95698ms, swizzle: NOOP, TFLOPS: 29.75 (+7.96%) - out_f32x4(t8x8dbuf): ['47.2999496', '96.8197784'], time:38.06703ms, swizzle: NOOP, TFLOPS: 28.88 - out_f32(cublas): ['47.2999496', '96.8197784'], time:28.85241ms, swizzle: NOOP, TFLOPS: 38.11 (+28.09%) - out_f32_th: ['47.2999496', '96.8197784'], time:29.13621ms, swizzle: NOOP, TFLOPS: 37.74 + out_f32x4(t8x8sk): ['118.532104', '44.2729606'], time:39.05200ms, swizzle: NOOP, TFLOPS: 28.16 (+0.00%) + out_f32x4(t8x8bcf): ['118.532104', '44.2729606'], time:36.05434ms, swizzle: NOOP, TFLOPS: 30.50 (+8.31%) + out_f32x4(t8x8dbuf): ['118.532104', '44.2729606'], time:36.42346ms, swizzle: NOOP, TFLOPS: 30.19 + out_f32(cublas): ['118.532104', '44.2729606'], time:28.22470ms, swizzle: NOOP, TFLOPS: 38.96 (+27.74%) + out_f32_th: ['118.532104', '44.2729606'], time:28.45404ms, swizzle: NOOP, TFLOPS: 38.64 --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['47.3002510', '96.8180007'], time:29.58662ms, swizzle: NOOP, TFLOPS: 37.16 - out_tf32(mma2x4+warp2x4+stage2): ['47.3002510', '96.8180007'], time:25.65428ms, swizzle: NOOP, TFLOPS: 42.86 (+12.47%) - out_tf32(mma2x4+...+stage3+dsmem): ['47.3002510', '96.8180007'], time:27.24572ms, swizzle: NOOP, TFLOPS: 40.36 - out_tf32(mma2x4+...+stage2+dsmem): ['47.3002510', '96.8180007'], time:25.65881ms, swizzle: NOOP, TFLOPS: 42.85 -out_tf32(mma2x4+...+stage3+swizzle): ['47.3002510', '96.8180007'], time:29.62752ms, swizzle: 2048, TFLOPS: 37.11 -out_tf32(mma2x4+...+stage2+swizzle): ['47.3002510', '96.8180007'], time:24.43090ms, swizzle: 2048, TFLOPS: 45.00 (+5.01%) - out_tf32(...+stage3+dsmem+swizzle): ['47.3002510', '96.8180007'], time:26.46713ms, swizzle: 2048, TFLOPS: 41.54 - out_tf32(...+stage2+dsmem+swizzle): ['47.3002510', '96.8180007'], time:24.40419ms, swizzle: 2048, TFLOPS: 45.05 (+0.11%) - out_tf32(cublas+tf32): ['47.3002510', '96.8180007'], time:19.60293ms, swizzle: NOOP, TFLOPS: 56.09 (+24.49%) + out_tf32(mma2x4+warp2x4+stage3): ['118.526184', '44.2716636'], time:29.65857ms, swizzle: NOOP, TFLOPS: 37.07 + out_tf32(mma2x4+warp2x4+stage2): ['118.526184', '44.2716636'], time:25.09703ms, swizzle: NOOP, TFLOPS: 43.81 (+12.46%) + out_tf32(mma2x4+...+stage3+dsmem): ['118.526184', '44.2716636'], time:26.67160ms, swizzle: NOOP, TFLOPS: 41.22 + out_tf32(mma2x4+...+stage2+dsmem): ['118.526184', '44.2716636'], time:25.22740ms, swizzle: NOOP, TFLOPS: 43.58 +out_tf32(mma2x4+...+stage3+swizzle): ['118.526184', '44.2716636'], time:29.67340ms, swizzle: 1024, TFLOPS: 37.05 +out_tf32(mma2x4+...+stage2+swizzle): ['118.526184', '44.2716636'], time:24.31735ms, swizzle: 1024, TFLOPS: 45.22 (+3.21%) + out_tf32(...+stage3+dsmem+swizzle): ['118.526184', '44.2716636'], time:26.41408ms, swizzle: 1024, TFLOPS: 41.63 + out_tf32(...+stage2+dsmem+swizzle): ['118.526184', '44.2716636'], time:24.30074ms, swizzle: 1024, TFLOPS: 45.25 (+0.07%) + out_tf32(cublas+tf32): ['118.526184', '44.2716636'], time:19.56663ms, swizzle: NOOP, TFLOPS: 56.19 (+24.19%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=8192, N=16384, K=2048 - out_f32(naive): ['-17.849985', '0.19760081'], time:276.2140ms, swizzle: NOOP, TFLOPS: 1.99 (+0.00%) - out_f32x4(t8x8sk): ['-17.849985', '0.19760081'], time:20.09291ms, swizzle: NOOP, TFLOPS: 27.36 (+1274.68%) - out_f32x4(t8x8bcf): ['-17.849985', '0.19760081'], time:18.52561ms, swizzle: NOOP, TFLOPS: 29.68 (+8.46%) - out_f32x4(t8x8dbuf): ['-17.849985', '0.19760081'], time:19.83964ms, swizzle: NOOP, TFLOPS: 27.71 - out_f32(cublas): ['-17.849985', '0.19760081'], time:14.76491ms, swizzle: NOOP, TFLOPS: 37.23 (+25.47%) - out_f32_th: ['-17.849985', '0.19760081'], time:14.15612ms, swizzle: NOOP, TFLOPS: 38.84 (+4.30%) + out_f32x4(t8x8sk): ['70.5949554', '26.1727619'], time:19.93403ms, swizzle: NOOP, TFLOPS: 27.58 (+0.00%) + out_f32x4(t8x8bcf): ['70.5949554', '26.1727619'], time:17.85275ms, swizzle: NOOP, TFLOPS: 30.79 (+11.66%) + out_f32x4(t8x8dbuf): ['70.5949554', '26.1727619'], time:17.60568ms, swizzle: NOOP, TFLOPS: 31.23 (+1.40%) + out_f32(cublas): ['70.5949554', '26.1727619'], time:14.66460ms, swizzle: NOOP, TFLOPS: 37.49 (+20.06%) + out_f32_th: ['70.5949554', '26.1727619'], time:14.66336ms, swizzle: NOOP, TFLOPS: 37.49 (+0.01%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-17.849796', '0.19758633'], time:14.75453ms, swizzle: NOOP, TFLOPS: 37.26 - out_tf32(mma2x4+warp2x4+stage2): ['-17.849796', '0.19758633'], time:13.11275ms, swizzle: NOOP, TFLOPS: 41.93 (+7.96%) - out_tf32(mma2x4+...+stage3+dsmem): ['-17.849796', '0.19758633'], time:13.43187ms, swizzle: NOOP, TFLOPS: 40.93 - out_tf32(mma2x4+...+stage2+dsmem): ['-17.849796', '0.19758633'], time:13.04426ms, swizzle: NOOP, TFLOPS: 42.15 (+0.53%) -out_tf32(mma2x4+...+stage3+swizzle): ['-17.849796', '0.19758633'], time:14.68685ms, swizzle: 4096, TFLOPS: 37.43 -out_tf32(mma2x4+...+stage2+swizzle): ['-17.849796', '0.19758633'], time:11.95122ms, swizzle: 4096, TFLOPS: 46.00 (+9.15%) - out_tf32(...+stage3+dsmem+swizzle): ['-17.849796', '0.19758633'], time:13.05289ms, swizzle: 4096, TFLOPS: 42.12 - out_tf32(...+stage2+dsmem+swizzle): ['-17.849796', '0.19758633'], time:11.98161ms, swizzle: 4096, TFLOPS: 45.88 - out_tf32(cublas+tf32): ['-17.849796', '0.19758633'], time:9.834122ms, swizzle: NOOP, TFLOPS: 55.90 (+21.53%) + out_tf32(mma2x4+warp2x4+stage3): ['70.5943985', '26.1725273'], time:14.75033ms, swizzle: NOOP, TFLOPS: 37.27 + out_tf32(mma2x4+warp2x4+stage2): ['70.5943985', '26.1725273'], time:12.68918ms, swizzle: NOOP, TFLOPS: 43.32 (+15.56%) + out_tf32(mma2x4+...+stage3+dsmem): ['70.5943985', '26.1725273'], time:13.28039ms, swizzle: NOOP, TFLOPS: 41.40 + out_tf32(mma2x4+...+stage2+dsmem): ['70.5943985', '26.1725273'], time:12.78223ms, swizzle: NOOP, TFLOPS: 43.01 +out_tf32(mma2x4+...+stage3+swizzle): ['70.5943985', '26.1725273'], time:14.66119ms, swizzle: 2048, TFLOPS: 37.50 +out_tf32(mma2x4+...+stage2+swizzle): ['70.5943985', '26.1725273'], time:11.99231ms, swizzle: 2048, TFLOPS: 45.84 (+5.81%) + out_tf32(...+stage3+dsmem+swizzle): ['70.5943985', '26.1725273'], time:13.03169ms, swizzle: 2048, TFLOPS: 42.19 + out_tf32(...+stage2+dsmem+swizzle): ['70.5943985', '26.1725273'], time:11.96327ms, swizzle: 2048, TFLOPS: 45.95 (+0.24%) + out_tf32(cublas+tf32): ['70.5943985', '26.1725273'], time:9.859824ms, swizzle: NOOP, TFLOPS: 55.76 (+21.33%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=8192, N=16384, K=4096 - out_f32(naive): ['-24.539411', '26.0933971'], time:551.5206ms, swizzle: NOOP, TFLOPS: 1.99 (+0.00%) - out_f32x4(t8x8sk): ['-24.539411', '26.0933971'], time:40.85628ms, swizzle: NOOP, TFLOPS: 26.91 (+1249.90%) - out_f32x4(t8x8bcf): ['-24.539411', '26.0933971'], time:38.69991ms, swizzle: NOOP, TFLOPS: 28.41 (+5.57%) - out_f32x4(t8x8dbuf): ['-24.539411', '26.0933971'], time:39.29961ms, swizzle: NOOP, TFLOPS: 27.98 - out_f32(cublas): ['-24.539411', '26.0933971'], time:28.43469ms, swizzle: NOOP, TFLOPS: 38.67 (+36.10%) - out_f32_th: ['-24.539411', '26.0933971'], time:28.36043ms, swizzle: NOOP, TFLOPS: 38.77 (+0.26%) + out_f32x4(t8x8sk): ['151.796371', '4.59689951'], time:40.03288ms, swizzle: NOOP, TFLOPS: 27.47 (+0.00%) + out_f32x4(t8x8bcf): ['151.796371', '4.59689951'], time:39.52372ms, swizzle: NOOP, TFLOPS: 27.82 (+1.29%) + out_f32x4(t8x8dbuf): ['151.796371', '4.59689951'], time:37.59534ms, swizzle: NOOP, TFLOPS: 29.25 (+5.13%) + out_f32(cublas): ['151.796371', '4.59689951'], time:27.83019ms, swizzle: NOOP, TFLOPS: 39.51 (+35.09%) + out_f32_th: ['151.796371', '4.59689951'], time:27.95956ms, swizzle: NOOP, TFLOPS: 39.33 --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-24.539142', '26.0933208'], time:29.58275ms, swizzle: NOOP, TFLOPS: 37.17 - out_tf32(mma2x4+warp2x4+stage2): ['-24.539142', '26.0933208'], time:25.94059ms, swizzle: NOOP, TFLOPS: 42.39 (+9.33%) - out_tf32(mma2x4+...+stage3+dsmem): ['-24.539142', '26.0933208'], time:27.37162ms, swizzle: NOOP, TFLOPS: 40.17 - out_tf32(mma2x4+...+stage2+dsmem): ['-24.539142', '26.0933208'], time:25.87283ms, swizzle: NOOP, TFLOPS: 42.50 (+0.26%) -out_tf32(mma2x4+...+stage3+swizzle): ['-24.539142', '26.0933208'], time:29.19225ms, swizzle: 4096, TFLOPS: 37.66 -out_tf32(mma2x4+...+stage2+swizzle): ['-24.539142', '26.0933208'], time:23.91030ms, swizzle: 4096, TFLOPS: 45.98 (+8.21%) - out_tf32(...+stage3+dsmem+swizzle): ['-24.539142', '26.0933208'], time:25.98127ms, swizzle: 4096, TFLOPS: 42.32 - out_tf32(...+stage2+dsmem+swizzle): ['-24.539142', '26.0933208'], time:23.91967ms, swizzle: 4096, TFLOPS: 45.97 - out_tf32(cublas+tf32): ['-24.539142', '26.0933208'], time:19.25871ms, swizzle: NOOP, TFLOPS: 57.09 (+24.15%) + out_tf32(mma2x4+warp2x4+stage3): ['151.794143', '4.5965395 '], time:29.30724ms, swizzle: NOOP, TFLOPS: 37.52 + out_tf32(mma2x4+warp2x4+stage2): ['151.794143', '4.5965395 '], time:25.27904ms, swizzle: NOOP, TFLOPS: 43.49 (+10.09%) + out_tf32(mma2x4+...+stage3+dsmem): ['151.794143', '4.5965395 '], time:27.31575ms, swizzle: NOOP, TFLOPS: 40.25 + out_tf32(mma2x4+...+stage2+dsmem): ['151.794143', '4.5965395 '], time:25.58822ms, swizzle: NOOP, TFLOPS: 42.97 +out_tf32(mma2x4+...+stage3+swizzle): ['151.794143', '4.5965395 '], time:29.27069ms, swizzle: 2048, TFLOPS: 37.56 +out_tf32(mma2x4+...+stage2+swizzle): ['151.794143', '4.5965395 '], time:23.81775ms, swizzle: 2048, TFLOPS: 46.16 (+6.14%) + out_tf32(...+stage3+dsmem+swizzle): ['151.794143', '4.5965395 '], time:26.00069ms, swizzle: 2048, TFLOPS: 42.29 + out_tf32(...+stage2+dsmem+swizzle): ['151.794143', '4.5965395 '], time:23.87239ms, swizzle: 2048, TFLOPS: 46.06 + out_tf32(cublas+tf32): ['151.794143', '4.5965395 '], time:19.24333ms, swizzle: NOOP, TFLOPS: 57.14 (+23.77%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=8192, N=16384, K=8192 - out_f32(naive): ['47.2999496', '96.8197784'], time:1102.202ms, swizzle: NOOP, TFLOPS: 2.00 (+0.00%) - out_f32x4(t8x8sk): ['47.2999496', '96.8197784'], time:82.22703ms, swizzle: NOOP, TFLOPS: 26.74 (+1240.44%) - out_f32x4(t8x8bcf): ['47.2999496', '96.8197784'], time:77.98941ms, swizzle: NOOP, TFLOPS: 28.20 (+5.43%) - out_f32x4(t8x8dbuf): ['47.2999496', '96.8197784'], time:78.90355ms, swizzle: NOOP, TFLOPS: 27.87 - out_f32(cublas): ['47.2999496', '96.8197784'], time:58.00436ms, swizzle: NOOP, TFLOPS: 37.91 (+34.45%) - out_f32_th: ['47.2999496', '96.8197784'], time:58.51061ms, swizzle: NOOP, TFLOPS: 37.58 + out_f32x4(t8x8sk): ['118.532104', '44.2729606'], time:81.30698ms, swizzle: NOOP, TFLOPS: 27.05 (+0.00%) + out_f32x4(t8x8bcf): ['118.532104', '44.2729606'], time:75.78270ms, swizzle: NOOP, TFLOPS: 29.02 (+7.29%) + out_f32x4(t8x8dbuf): ['118.532104', '44.2729606'], time:75.56617ms, swizzle: NOOP, TFLOPS: 29.10 (+0.29%) + out_f32(cublas): ['118.532104', '44.2729606'], time:56.42166ms, swizzle: NOOP, TFLOPS: 38.97 (+33.93%) + out_f32_th: ['118.532104', '44.2729606'], time:57.50610ms, swizzle: NOOP, TFLOPS: 38.24 --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['47.3002510', '96.8180007'], time:59.37192ms, swizzle: NOOP, TFLOPS: 37.04 - out_tf32(mma2x4+warp2x4+stage2): ['47.3002510', '96.8180007'], time:51.85855ms, swizzle: NOOP, TFLOPS: 42.40 (+11.85%) - out_tf32(mma2x4+...+stage3+dsmem): ['47.3002510', '96.8180007'], time:55.02256ms, swizzle: NOOP, TFLOPS: 39.97 - out_tf32(mma2x4+...+stage2+dsmem): ['47.3002510', '96.8180007'], time:52.14641ms, swizzle: NOOP, TFLOPS: 42.17 -out_tf32(mma2x4+...+stage3+swizzle): ['47.3002510', '96.8180007'], time:58.22221ms, swizzle: 4096, TFLOPS: 37.77 -out_tf32(mma2x4+...+stage2+swizzle): ['47.3002510', '96.8180007'], time:49.53384ms, swizzle: 4096, TFLOPS: 44.39 (+4.69%) - out_tf32(...+stage3+dsmem+swizzle): ['47.3002510', '96.8180007'], time:53.26046ms, swizzle: 4096, TFLOPS: 41.29 - out_tf32(...+stage2+dsmem+swizzle): ['47.3002510', '96.8180007'], time:49.65736ms, swizzle: 4096, TFLOPS: 44.28 - out_tf32(cublas+tf32): ['47.3002510', '96.8180007'], time:38.19205ms, swizzle: NOOP, TFLOPS: 57.58 (+29.70%) + out_tf32(mma2x4+warp2x4+stage3): ['118.526184', '44.2716636'], time:58.45718ms, swizzle: NOOP, TFLOPS: 37.62 + out_tf32(mma2x4+warp2x4+stage2): ['118.526184', '44.2716636'], time:51.36411ms, swizzle: NOOP, TFLOPS: 42.81 (+9.85%) + out_tf32(mma2x4+...+stage3+dsmem): ['118.526184', '44.2716636'], time:53.86862ms, swizzle: NOOP, TFLOPS: 40.82 + out_tf32(mma2x4+...+stage2+dsmem): ['118.526184', '44.2716636'], time:51.22380ms, swizzle: NOOP, TFLOPS: 42.93 (+0.27%) +out_tf32(mma2x4+...+stage3+swizzle): ['118.526184', '44.2716636'], time:58.32481ms, swizzle: 2048, TFLOPS: 37.70 +out_tf32(mma2x4+...+stage2+swizzle): ['118.526184', '44.2716636'], time:47.85780ms, swizzle: 2048, TFLOPS: 45.95 (+7.03%) + out_tf32(...+stage3+dsmem+swizzle): ['118.526184', '44.2716636'], time:51.81453ms, swizzle: 2048, TFLOPS: 42.44 + out_tf32(...+stage2+dsmem+swizzle): ['118.526184', '44.2716636'], time:47.76165ms, swizzle: 2048, TFLOPS: 46.04 (+0.20%) + out_tf32(cublas+tf32): ['118.526184', '44.2716636'], time:38.08858ms, swizzle: NOOP, TFLOPS: 57.73 (+25.40%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=16384, N=4096, K=2048 - out_f32(naive): ['-17.849985', '0.19760081'], time:81.29155ms, swizzle: NOOP, TFLOPS: 3.38 (+0.00%) - out_f32x4(t8x8sk): ['-17.849985', '0.19760081'], time:9.530687ms, swizzle: NOOP, TFLOPS: 28.84 (+752.95%) - out_f32x4(t8x8bcf): ['-17.849985', '0.19760081'], time:8.524906ms, swizzle: NOOP, TFLOPS: 32.24 (+11.80%) - out_f32x4(t8x8dbuf): ['-17.849985', '0.19760081'], time:8.015465ms, swizzle: NOOP, TFLOPS: 34.29 (+6.36%) - out_f32(cublas): ['-17.849985', '0.19760081'], time:8.247447ms, swizzle: NOOP, TFLOPS: 33.33 - out_f32_th: ['-17.849985', '0.19760081'], time:7.800579ms, swizzle: NOOP, TFLOPS: 35.24 (+2.75%) + out_f32x4(t8x8sk): ['70.5949554', '26.1727619'], time:9.190845ms, swizzle: NOOP, TFLOPS: 29.91 (+0.00%) + out_f32x4(t8x8bcf): ['70.5949554', '26.1727619'], time:8.345413ms, swizzle: NOOP, TFLOPS: 32.94 (+10.13%) + out_f32x4(t8x8dbuf): ['70.5949554', '26.1727619'], time:7.679963ms, swizzle: NOOP, TFLOPS: 35.79 (+8.66%) + out_f32(cublas): ['70.5949554', '26.1727619'], time:7.500529ms, swizzle: NOOP, TFLOPS: 36.65 (+2.39%) + out_f32_th: ['70.5949554', '26.1727619'], time:7.146787ms, swizzle: NOOP, TFLOPS: 38.46 (+4.95%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-17.849796', '0.19758633'], time:7.802319ms, swizzle: NOOP, TFLOPS: 35.23 - out_tf32(mma2x4+warp2x4+stage2): ['-17.849796', '0.19758633'], time:6.260669ms, swizzle: NOOP, TFLOPS: 43.91 (+24.60%) - out_tf32(mma2x4+...+stage3+dsmem): ['-17.849796', '0.19758633'], time:6.773519ms, swizzle: NOOP, TFLOPS: 40.58 - out_tf32(mma2x4+...+stage2+dsmem): ['-17.849796', '0.19758633'], time:6.221926ms, swizzle: NOOP, TFLOPS: 44.18 (+0.62%) -out_tf32(mma2x4+...+stage3+swizzle): ['-17.849796', '0.19758633'], time:7.493686ms, swizzle: 1024, TFLOPS: 36.68 -out_tf32(mma2x4+...+stage2+swizzle): ['-17.849796', '0.19758633'], time:6.252801ms, swizzle: 1024, TFLOPS: 43.96 - out_tf32(...+stage3+dsmem+swizzle): ['-17.849796', '0.19758633'], time:6.773734ms, swizzle: 1024, TFLOPS: 40.58 - out_tf32(...+stage2+dsmem+swizzle): ['-17.849796', '0.19758633'], time:6.259787ms, swizzle: 1024, TFLOPS: 43.91 - out_tf32(cublas+tf32): ['-17.849796', '0.19758633'], time:5.200731ms, swizzle: NOOP, TFLOPS: 52.85 (+19.64%) + out_tf32(mma2x4+warp2x4+stage3): ['70.5943985', '26.1725273'], time:7.968235ms, swizzle: NOOP, TFLOPS: 34.50 + out_tf32(mma2x4+warp2x4+stage2): ['70.5943985', '26.1725273'], time:6.254506ms, swizzle: NOOP, TFLOPS: 43.95 (+14.27%) + out_tf32(mma2x4+...+stage3+dsmem): ['70.5943985', '26.1725273'], time:6.782460ms, swizzle: NOOP, TFLOPS: 40.53 + out_tf32(mma2x4+...+stage2+dsmem): ['70.5943985', '26.1725273'], time:6.247973ms, swizzle: NOOP, TFLOPS: 43.99 (+0.10%) +out_tf32(mma2x4+...+stage3+swizzle): ['70.5943985', '26.1725273'], time:7.488203ms, swizzle: 512 , TFLOPS: 36.71 +out_tf32(mma2x4+...+stage2+swizzle): ['70.5943985', '26.1725273'], time:6.200075ms, swizzle: 512 , TFLOPS: 44.33 (+0.77%) + out_tf32(...+stage3+dsmem+swizzle): ['70.5943985', '26.1725273'], time:6.759619ms, swizzle: 512 , TFLOPS: 40.66 + out_tf32(...+stage2+dsmem+swizzle): ['70.5943985', '26.1725273'], time:6.231451ms, swizzle: 512 , TFLOPS: 44.11 + out_tf32(cublas+tf32): ['70.5943985', '26.1725273'], time:5.184912ms, swizzle: NOOP, TFLOPS: 53.01 (+19.58%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=16384, N=4096, K=4096 - out_f32(naive): ['-24.539411', '26.0933971'], time:162.3996ms, swizzle: NOOP, TFLOPS: 3.39 (+0.00%) - out_f32x4(t8x8sk): ['-24.539411', '26.0933971'], time:18.98237ms, swizzle: NOOP, TFLOPS: 28.96 (+755.53%) - out_f32x4(t8x8bcf): ['-24.539411', '26.0933971'], time:17.01517ms, swizzle: NOOP, TFLOPS: 32.31 (+11.56%) - out_f32x4(t8x8dbuf): ['-24.539411', '26.0933971'], time:17.42953ms, swizzle: NOOP, TFLOPS: 31.54 - out_f32(cublas): ['-24.539411', '26.0933971'], time:14.61760ms, swizzle: NOOP, TFLOPS: 37.61 (+16.40%) - out_f32_th: ['-24.539411', '26.0933971'], time:14.48466ms, swizzle: NOOP, TFLOPS: 37.95 (+0.92%) + out_f32x4(t8x8sk): ['151.796371', '4.59689951'], time:18.67318ms, swizzle: NOOP, TFLOPS: 29.44 (+0.00%) + out_f32x4(t8x8bcf): ['151.796371', '4.59689951'], time:16.58837ms, swizzle: NOOP, TFLOPS: 33.14 (+12.57%) + out_f32x4(t8x8dbuf): ['151.796371', '4.59689951'], time:16.44637ms, swizzle: NOOP, TFLOPS: 33.43 (+0.86%) + out_f32(cublas): ['151.796371', '4.59689951'], time:14.57281ms, swizzle: NOOP, TFLOPS: 37.72 (+12.86%) + out_f32_th: ['151.796371', '4.59689951'], time:14.51504ms, swizzle: NOOP, TFLOPS: 37.87 (+0.40%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-24.539142', '26.0933208'], time:15.16782ms, swizzle: NOOP, TFLOPS: 36.24 - out_tf32(mma2x4+warp2x4+stage2): ['-24.539142', '26.0933208'], time:12.41897ms, swizzle: NOOP, TFLOPS: 44.27 (+16.63%) - out_tf32(mma2x4+...+stage3+dsmem): ['-24.539142', '26.0933208'], time:13.44157ms, swizzle: NOOP, TFLOPS: 40.90 - out_tf32(mma2x4+...+stage2+dsmem): ['-24.539142', '26.0933208'], time:12.39446ms, swizzle: NOOP, TFLOPS: 44.35 (+0.20%) -out_tf32(mma2x4+...+stage3+swizzle): ['-24.539142', '26.0933208'], time:14.96917ms, swizzle: 1024, TFLOPS: 36.73 -out_tf32(mma2x4+...+stage2+swizzle): ['-24.539142', '26.0933208'], time:12.42594ms, swizzle: 1024, TFLOPS: 44.24 - out_tf32(...+stage3+dsmem+swizzle): ['-24.539142', '26.0933208'], time:13.48555ms, swizzle: 1024, TFLOPS: 40.77 - out_tf32(...+stage2+dsmem+swizzle): ['-24.539142', '26.0933208'], time:12.44277ms, swizzle: 1024, TFLOPS: 44.18 - out_tf32(cublas+tf32): ['-24.539142', '26.0933208'], time:10.07603ms, swizzle: NOOP, TFLOPS: 54.56 (+23.01%) + out_tf32(mma2x4+warp2x4+stage3): ['151.794143', '4.5965395 '], time:15.47667ms, swizzle: NOOP, TFLOPS: 35.52 + out_tf32(mma2x4+warp2x4+stage2): ['151.794143', '4.5965395 '], time:12.47291ms, swizzle: NOOP, TFLOPS: 44.08 (+16.37%) + out_tf32(mma2x4+...+stage3+dsmem): ['151.794143', '4.5965395 '], time:13.44106ms, swizzle: NOOP, TFLOPS: 40.90 + out_tf32(mma2x4+...+stage2+dsmem): ['151.794143', '4.5965395 '], time:12.39275ms, swizzle: NOOP, TFLOPS: 44.36 (+0.65%) +out_tf32(mma2x4+...+stage3+swizzle): ['151.794143', '4.5965395 '], time:14.96281ms, swizzle: 512 , TFLOPS: 36.74 +out_tf32(mma2x4+...+stage2+swizzle): ['151.794143', '4.5965395 '], time:12.40277ms, swizzle: 512 , TFLOPS: 44.33 + out_tf32(...+stage3+dsmem+swizzle): ['151.794143', '4.5965395 '], time:13.47801ms, swizzle: 512 , TFLOPS: 40.79 + out_tf32(...+stage2+dsmem+swizzle): ['151.794143', '4.5965395 '], time:12.43972ms, swizzle: 512 , TFLOPS: 44.19 + out_tf32(cublas+tf32): ['151.794143', '4.5965395 '], time:10.85467ms, swizzle: NOOP, TFLOPS: 50.65 (+14.17%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=16384, N=4096, K=8192 - out_f32(naive): ['47.2999496', '96.8197784'], time:494.3340ms, swizzle: NOOP, TFLOPS: 2.22 (+0.00%) - out_f32x4(t8x8sk): ['47.2999496', '96.8197784'], time:39.43489ms, swizzle: NOOP, TFLOPS: 27.88 (+1153.54%) - out_f32x4(t8x8bcf): ['47.2999496', '96.8197784'], time:35.69089ms, swizzle: NOOP, TFLOPS: 30.81 (+10.49%) - out_f32x4(t8x8dbuf): ['47.2999496', '96.8197784'], time:37.27245ms, swizzle: NOOP, TFLOPS: 29.50 - out_f32(cublas): ['47.2999496', '96.8197784'], time:29.58321ms, swizzle: NOOP, TFLOPS: 37.17 (+20.65%) - out_f32_th: ['47.2999496', '96.8197784'], time:29.77937ms, swizzle: NOOP, TFLOPS: 36.92 + out_f32x4(t8x8sk): ['118.532104', '44.2729606'], time:38.72056ms, swizzle: NOOP, TFLOPS: 28.40 (+0.00%) + out_f32x4(t8x8bcf): ['118.532104', '44.2729606'], time:34.69905ms, swizzle: NOOP, TFLOPS: 31.69 (+11.59%) + out_f32x4(t8x8dbuf): ['118.532104', '44.2729606'], time:36.12399ms, swizzle: NOOP, TFLOPS: 30.44 + out_f32(cublas): ['118.532104', '44.2729606'], time:28.58903ms, swizzle: NOOP, TFLOPS: 38.46 (+21.37%) + out_f32_th: ['118.532104', '44.2729606'], time:28.67548ms, swizzle: NOOP, TFLOPS: 38.34 --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['47.3002510', '96.8180007'], time:30.08049ms, swizzle: NOOP, TFLOPS: 36.55 - out_tf32(mma2x4+warp2x4+stage2): ['47.3002510', '96.8180007'], time:25.29913ms, swizzle: NOOP, TFLOPS: 43.46 (+16.93%) - out_tf32(mma2x4+...+stage3+dsmem): ['47.3002510', '96.8180007'], time:27.23990ms, swizzle: NOOP, TFLOPS: 40.36 - out_tf32(mma2x4+...+stage2+dsmem): ['47.3002510', '96.8180007'], time:25.82558ms, swizzle: NOOP, TFLOPS: 42.57 -out_tf32(mma2x4+...+stage3+swizzle): ['47.3002510', '96.8180007'], time:30.02738ms, swizzle: 1024, TFLOPS: 36.62 -out_tf32(mma2x4+...+stage2+swizzle): ['47.3002510', '96.8180007'], time:24.85141ms, swizzle: 1024, TFLOPS: 44.24 (+1.80%) - out_tf32(...+stage3+dsmem+swizzle): ['47.3002510', '96.8180007'], time:26.89465ms, swizzle: 1024, TFLOPS: 40.88 - out_tf32(...+stage2+dsmem+swizzle): ['47.3002510', '96.8180007'], time:24.83090ms, swizzle: 1024, TFLOPS: 44.28 (+0.08%) - out_tf32(cublas+tf32): ['47.3002510', '96.8180007'], time:19.69830ms, swizzle: NOOP, TFLOPS: 55.82 (+26.06%) + out_tf32(mma2x4+warp2x4+stage3): ['118.526184', '44.2716636'], time:30.18698ms, swizzle: NOOP, TFLOPS: 36.42 + out_tf32(mma2x4+warp2x4+stage2): ['118.526184', '44.2716636'], time:25.13649ms, swizzle: NOOP, TFLOPS: 43.74 (+13.74%) + out_tf32(mma2x4+...+stage3+dsmem): ['118.526184', '44.2716636'], time:26.80773ms, swizzle: NOOP, TFLOPS: 41.01 + out_tf32(mma2x4+...+stage2+dsmem): ['118.526184', '44.2716636'], time:25.33824ms, swizzle: NOOP, TFLOPS: 43.39 +out_tf32(mma2x4+...+stage3+swizzle): ['118.526184', '44.2716636'], time:30.06722ms, swizzle: 512 , TFLOPS: 36.57 +out_tf32(mma2x4+...+stage2+swizzle): ['118.526184', '44.2716636'], time:25.04580ms, swizzle: 512 , TFLOPS: 43.90 (+0.36%) + out_tf32(...+stage3+dsmem+swizzle): ['118.526184', '44.2716636'], time:26.84135ms, swizzle: 512 , TFLOPS: 40.96 + out_tf32(...+stage2+dsmem+swizzle): ['118.526184', '44.2716636'], time:24.95355ms, swizzle: 512 , TFLOPS: 44.06 (+0.37%) + out_tf32(cublas+tf32): ['118.526184', '44.2716636'], time:19.74155ms, swizzle: NOOP, TFLOPS: 55.70 (+26.40%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=16384, N=8192, K=2048 - out_f32(naive): ['-17.849985', '0.19760081'], time:162.4853ms, swizzle: NOOP, TFLOPS: 3.38 (+0.00%) - out_f32x4(t8x8sk): ['-17.849985', '0.19760081'], time:18.61557ms, swizzle: NOOP, TFLOPS: 29.53 (+772.85%) - out_f32x4(t8x8bcf): ['-17.849985', '0.19760081'], time:16.65081ms, swizzle: NOOP, TFLOPS: 33.02 (+11.80%) - out_f32x4(t8x8dbuf): ['-17.849985', '0.19760081'], time:16.60894ms, swizzle: NOOP, TFLOPS: 33.10 (+0.25%) - out_f32(cublas): ['-17.849985', '0.19760081'], time:14.59673ms, swizzle: NOOP, TFLOPS: 37.66 (+13.79%) - out_f32_th: ['-17.849985', '0.19760081'], time:14.25113ms, swizzle: NOOP, TFLOPS: 38.58 (+2.43%) + out_f32x4(t8x8sk): ['70.5949554', '26.1727619'], time:18.36364ms, swizzle: NOOP, TFLOPS: 29.94 (+0.00%) + out_f32x4(t8x8bcf): ['70.5949554', '26.1727619'], time:16.34912ms, swizzle: NOOP, TFLOPS: 33.63 (+12.32%) + out_f32x4(t8x8dbuf): ['70.5949554', '26.1727619'], time:14.82284ms, swizzle: NOOP, TFLOPS: 37.09 (+10.30%) + out_f32(cublas): ['70.5949554', '26.1727619'], time:14.45541ms, swizzle: NOOP, TFLOPS: 38.03 (+2.54%) + out_f32_th: ['70.5949554', '26.1727619'], time:14.56203ms, swizzle: NOOP, TFLOPS: 37.75 --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-17.849796', '0.19758633'], time:14.85270ms, swizzle: NOOP, TFLOPS: 37.01 - out_tf32(mma2x4+warp2x4+stage2): ['-17.849796', '0.19758633'], time:12.02559ms, swizzle: NOOP, TFLOPS: 45.72 (+18.51%) - out_tf32(mma2x4+...+stage3+dsmem): ['-17.849796', '0.19758633'], time:13.04501ms, swizzle: NOOP, TFLOPS: 42.14 - out_tf32(mma2x4+...+stage2+dsmem): ['-17.849796', '0.19758633'], time:12.00811ms, swizzle: NOOP, TFLOPS: 45.78 (+0.15%) -out_tf32(mma2x4+...+stage3+swizzle): ['-17.849796', '0.19758633'], time:14.59468ms, swizzle: 2048, TFLOPS: 37.67 -out_tf32(mma2x4+...+stage2+swizzle): ['-17.849796', '0.19758633'], time:12.04820ms, swizzle: 2048, TFLOPS: 45.63 - out_tf32(...+stage3+dsmem+swizzle): ['-17.849796', '0.19758633'], time:13.12594ms, swizzle: 2048, TFLOPS: 41.88 - out_tf32(...+stage2+dsmem+swizzle): ['-17.849796', '0.19758633'], time:12.04621ms, swizzle: 2048, TFLOPS: 45.64 - out_tf32(cublas+tf32): ['-17.849796', '0.19758633'], time:9.895133ms, swizzle: NOOP, TFLOPS: 55.56 (+21.35%) + out_tf32(mma2x4+warp2x4+stage3): ['70.5943985', '26.1725273'], time:14.91312ms, swizzle: NOOP, TFLOPS: 36.86 + out_tf32(mma2x4+warp2x4+stage2): ['70.5943985', '26.1725273'], time:12.08066ms, swizzle: NOOP, TFLOPS: 45.51 (+19.66%) + out_tf32(mma2x4+...+stage3+dsmem): ['70.5943985', '26.1725273'], time:13.07072ms, swizzle: NOOP, TFLOPS: 42.06 + out_tf32(mma2x4+...+stage2+dsmem): ['70.5943985', '26.1725273'], time:12.01016ms, swizzle: NOOP, TFLOPS: 45.77 (+0.59%) +out_tf32(mma2x4+...+stage3+swizzle): ['70.5943985', '26.1725273'], time:14.58058ms, swizzle: 1024, TFLOPS: 37.70 +out_tf32(mma2x4+...+stage2+swizzle): ['70.5943985', '26.1725273'], time:12.01112ms, swizzle: 1024, TFLOPS: 45.77 + out_tf32(...+stage3+dsmem+swizzle): ['70.5943985', '26.1725273'], time:13.10825ms, swizzle: 1024, TFLOPS: 41.94 + out_tf32(...+stage2+dsmem+swizzle): ['70.5943985', '26.1725273'], time:12.03854ms, swizzle: 1024, TFLOPS: 45.67 + out_tf32(cublas+tf32): ['70.5943985', '26.1725273'], time:9.944319ms, swizzle: NOOP, TFLOPS: 55.28 (+20.77%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=16384, N=8192, K=4096 - out_f32(naive): ['-24.539411', '26.0933971'], time:536.4258ms, swizzle: NOOP, TFLOPS: 2.05 (+0.00%) - out_f32x4(t8x8sk): ['-24.539411', '26.0933971'], time:40.05281ms, swizzle: NOOP, TFLOPS: 27.45 (+1239.30%) - out_f32x4(t8x8bcf): ['-24.539411', '26.0933971'], time:35.56506ms, swizzle: NOOP, TFLOPS: 30.92 (+12.62%) - out_f32x4(t8x8dbuf): ['-24.539411', '26.0933971'], time:38.55193ms, swizzle: NOOP, TFLOPS: 28.52 - out_f32(cublas): ['-24.539411', '26.0933971'], time:28.43211ms, swizzle: NOOP, TFLOPS: 38.67 (+25.09%) - out_f32_th: ['-24.539411', '26.0933971'], time:28.49795ms, swizzle: NOOP, TFLOPS: 38.58 + out_f32x4(t8x8sk): ['151.796371', '4.59689951'], time:39.44745ms, swizzle: NOOP, TFLOPS: 27.87 (+0.00%) + out_f32x4(t8x8bcf): ['151.796371', '4.59689951'], time:35.19003ms, swizzle: NOOP, TFLOPS: 31.24 (+12.10%) + out_f32x4(t8x8dbuf): ['151.796371', '4.59689951'], time:36.57977ms, swizzle: NOOP, TFLOPS: 30.06 + out_f32(cublas): ['151.796371', '4.59689951'], time:27.93822ms, swizzle: NOOP, TFLOPS: 39.36 (+25.96%) + out_f32_th: ['151.796371', '4.59689951'], time:27.93700ms, swizzle: NOOP, TFLOPS: 39.36 (+0.00%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-24.539142', '26.0933208'], time:29.34467ms, swizzle: NOOP, TFLOPS: 37.47 - out_tf32(mma2x4+warp2x4+stage2): ['-24.539142', '26.0933208'], time:25.17921ms, swizzle: NOOP, TFLOPS: 43.67 (+12.92%) - out_tf32(mma2x4+...+stage3+dsmem): ['-24.539142', '26.0933208'], time:26.83230ms, swizzle: NOOP, TFLOPS: 40.98 - out_tf32(mma2x4+...+stage2+dsmem): ['-24.539142', '26.0933208'], time:25.15255ms, swizzle: NOOP, TFLOPS: 43.71 (+0.11%) -out_tf32(mma2x4+...+stage3+swizzle): ['-24.539142', '26.0933208'], time:29.20446ms, swizzle: 2048, TFLOPS: 37.65 -out_tf32(mma2x4+...+stage2+swizzle): ['-24.539142', '26.0933208'], time:23.84365ms, swizzle: 2048, TFLOPS: 46.11 (+5.49%) - out_tf32(...+stage3+dsmem+swizzle): ['-24.539142', '26.0933208'], time:26.03495ms, swizzle: 2048, TFLOPS: 42.23 - out_tf32(...+stage2+dsmem+swizzle): ['-24.539142', '26.0933208'], time:23.87169ms, swizzle: 2048, TFLOPS: 46.06 - out_tf32(cublas+tf32): ['-24.539142', '26.0933208'], time:19.30768ms, swizzle: NOOP, TFLOPS: 56.95 (+23.49%) + out_tf32(mma2x4+warp2x4+stage3): ['151.794143', '4.5965395 '], time:29.24573ms, swizzle: NOOP, TFLOPS: 37.60 + out_tf32(mma2x4+warp2x4+stage2): ['151.794143', '4.5965395 '], time:24.57020ms, swizzle: NOOP, TFLOPS: 44.75 (+13.70%) + out_tf32(mma2x4+...+stage3+dsmem): ['151.794143', '4.5965395 '], time:26.55055ms, swizzle: NOOP, TFLOPS: 41.41 + out_tf32(mma2x4+...+stage2+dsmem): ['151.794143', '4.5965395 '], time:24.88572ms, swizzle: NOOP, TFLOPS: 44.18 +out_tf32(mma2x4+...+stage3+swizzle): ['151.794143', '4.5965395 '], time:29.28466ms, swizzle: 1024, TFLOPS: 37.55 +out_tf32(mma2x4+...+stage2+swizzle): ['151.794143', '4.5965395 '], time:23.89683ms, swizzle: 1024, TFLOPS: 46.01 (+2.82%) + out_tf32(...+stage3+dsmem+swizzle): ['151.794143', '4.5965395 '], time:26.11415ms, swizzle: 1024, TFLOPS: 42.10 + out_tf32(...+stage2+dsmem+swizzle): ['151.794143', '4.5965395 '], time:23.87890ms, swizzle: 1024, TFLOPS: 46.05 (+0.08%) + out_tf32(cublas+tf32): ['151.794143', '4.5965395 '], time:19.27731ms, swizzle: NOOP, TFLOPS: 57.04 (+23.87%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=16384, N=8192, K=8192 - out_f32(naive): ['47.2999496', '96.8197784'], time:1100.691ms, swizzle: NOOP, TFLOPS: 2.00 (+0.00%) - out_f32x4(t8x8sk): ['47.2999496', '96.8197784'], time:79.86506ms, swizzle: NOOP, TFLOPS: 27.53 (+1278.19%) - out_f32x4(t8x8bcf): ['47.2999496', '96.8197784'], time:74.10305ms, swizzle: NOOP, TFLOPS: 29.68 (+7.78%) - out_f32x4(t8x8dbuf): ['47.2999496', '96.8197784'], time:74.76978ms, swizzle: NOOP, TFLOPS: 29.41 - out_f32(cublas): ['47.2999496', '96.8197784'], time:57.91260ms, swizzle: NOOP, TFLOPS: 37.97 (+27.96%) - out_f32_th: ['47.2999496', '96.8197784'], time:58.26066ms, swizzle: NOOP, TFLOPS: 37.74 + out_f32x4(t8x8sk): ['118.532104', '44.2729606'], time:79.11319ms, swizzle: NOOP, TFLOPS: 27.80 (+0.00%) + out_f32x4(t8x8bcf): ['118.532104', '44.2729606'], time:70.98405ms, swizzle: NOOP, TFLOPS: 30.98 (+11.45%) + out_f32x4(t8x8dbuf): ['118.532104', '44.2729606'], time:71.76809ms, swizzle: NOOP, TFLOPS: 30.64 + out_f32(cublas): ['118.532104', '44.2729606'], time:55.91969ms, swizzle: NOOP, TFLOPS: 39.32 (+26.94%) + out_f32_th: ['118.532104', '44.2729606'], time:56.78405ms, swizzle: NOOP, TFLOPS: 38.73 --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['47.3002510', '96.8180007'], time:58.64821ms, swizzle: NOOP, TFLOPS: 37.50 - out_tf32(mma2x4+warp2x4+stage2): ['47.3002510', '96.8180007'], time:49.91295ms, swizzle: NOOP, TFLOPS: 44.06 (+16.03%) - out_tf32(mma2x4+...+stage3+dsmem): ['47.3002510', '96.8180007'], time:54.05123ms, swizzle: NOOP, TFLOPS: 40.68 - out_tf32(mma2x4+...+stage2+dsmem): ['47.3002510', '96.8180007'], time:50.64512ms, swizzle: NOOP, TFLOPS: 43.42 -out_tf32(mma2x4+...+stage3+swizzle): ['47.3002510', '96.8180007'], time:58.26839ms, swizzle: 2048, TFLOPS: 37.74 -out_tf32(mma2x4+...+stage2+swizzle): ['47.3002510', '96.8180007'], time:48.46489ms, swizzle: 2048, TFLOPS: 45.37 (+2.99%) - out_tf32(...+stage3+dsmem+swizzle): ['47.3002510', '96.8180007'], time:51.88893ms, swizzle: 2048, TFLOPS: 42.38 - out_tf32(...+stage2+dsmem+swizzle): ['47.3002510', '96.8180007'], time:48.42858ms, swizzle: 2048, TFLOPS: 45.41 (+0.07%) - out_tf32(cublas+tf32): ['47.3002510', '96.8180007'], time:38.56168ms, swizzle: NOOP, TFLOPS: 57.03 (+25.59%) + out_tf32(mma2x4+warp2x4+stage3): ['118.526184', '44.2716636'], time:58.23874ms, swizzle: NOOP, TFLOPS: 37.76 + out_tf32(mma2x4+warp2x4+stage2): ['118.526184', '44.2716636'], time:49.20217ms, swizzle: NOOP, TFLOPS: 44.69 (+13.65%) + out_tf32(mma2x4+...+stage3+dsmem): ['118.526184', '44.2716636'], time:53.33271ms, swizzle: NOOP, TFLOPS: 41.23 + out_tf32(mma2x4+...+stage2+dsmem): ['118.526184', '44.2716636'], time:49.59840ms, swizzle: NOOP, TFLOPS: 44.34 +out_tf32(mma2x4+...+stage3+swizzle): ['118.526184', '44.2716636'], time:58.33761ms, swizzle: 1024, TFLOPS: 37.69 +out_tf32(mma2x4+...+stage2+swizzle): ['118.526184', '44.2716636'], time:47.81997ms, swizzle: 1024, TFLOPS: 45.99 (+2.89%) + out_tf32(...+stage3+dsmem+swizzle): ['118.526184', '44.2716636'], time:51.88267ms, swizzle: 1024, TFLOPS: 42.38 + out_tf32(...+stage2+dsmem+swizzle): ['118.526184', '44.2716636'], time:47.80828ms, swizzle: 1024, TFLOPS: 46.00 (+0.02%) + out_tf32(cublas+tf32): ['118.526184', '44.2716636'], time:38.11509ms, swizzle: NOOP, TFLOPS: 57.69 (+25.43%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=16384, N=16384, K=2048 - out_f32(naive): ['-17.849985', '0.19760081'], time:552.1761ms, swizzle: NOOP, TFLOPS: 1.99 (+0.00%) - out_f32x4(t8x8sk): ['-17.849985', '0.19760081'], time:41.13322ms, swizzle: NOOP, TFLOPS: 26.73 (+1242.41%) - out_f32x4(t8x8bcf): ['-17.849985', '0.19760081'], time:39.94113ms, swizzle: NOOP, TFLOPS: 27.53 (+2.98%) - out_f32x4(t8x8dbuf): ['-17.849985', '0.19760081'], time:39.03629ms, swizzle: NOOP, TFLOPS: 28.17 (+2.32%) - out_f32(cublas): ['-17.849985', '0.19760081'], time:29.52983ms, swizzle: NOOP, TFLOPS: 37.23 (+32.19%) - out_f32_th: ['-17.849985', '0.19760081'], time:29.55764ms, swizzle: NOOP, TFLOPS: 37.20 + out_f32x4(t8x8sk): ['70.5949554', '26.1727619'], time:40.08102ms, swizzle: NOOP, TFLOPS: 27.43 (+0.00%) + out_f32x4(t8x8bcf): ['70.5949554', '26.1727619'], time:39.66226ms, swizzle: NOOP, TFLOPS: 27.72 (+1.06%) + out_f32x4(t8x8dbuf): ['70.5949554', '26.1727619'], time:36.46554ms, swizzle: NOOP, TFLOPS: 30.15 (+8.77%) + out_f32(cublas): ['70.5949554', '26.1727619'], time:28.34019ms, swizzle: NOOP, TFLOPS: 38.80 (+28.67%) + out_f32_th: ['70.5949554', '26.1727619'], time:28.30972ms, swizzle: NOOP, TFLOPS: 38.84 (+0.11%) --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-17.849796', '0.19758633'], time:29.32828ms, swizzle: NOOP, TFLOPS: 37.49 (+0.69%) - out_tf32(mma2x4+warp2x4+stage2): ['-17.849796', '0.19758633'], time:26.04321ms, swizzle: NOOP, TFLOPS: 42.22 (+12.61%) - out_tf32(mma2x4+...+stage3+dsmem): ['-17.849796', '0.19758633'], time:27.00719ms, swizzle: NOOP, TFLOPS: 40.71 - out_tf32(mma2x4+...+stage2+dsmem): ['-17.849796', '0.19758633'], time:26.04371ms, swizzle: NOOP, TFLOPS: 42.22 -out_tf32(mma2x4+...+stage3+swizzle): ['-17.849796', '0.19758633'], time:28.73388ms, swizzle: 4096, TFLOPS: 38.27 -out_tf32(mma2x4+...+stage2+swizzle): ['-17.849796', '0.19758633'], time:23.58162ms, swizzle: 4096, TFLOPS: 46.63 (+10.44%) - out_tf32(...+stage3+dsmem+swizzle): ['-17.849796', '0.19758633'], time:25.63526ms, swizzle: 4096, TFLOPS: 42.89 - out_tf32(...+stage2+dsmem+swizzle): ['-17.849796', '0.19758633'], time:23.56603ms, swizzle: 4096, TFLOPS: 46.66 (+0.07%) - out_tf32(cublas+tf32): ['-17.849796', '0.19758633'], time:19.51431ms, swizzle: NOOP, TFLOPS: 56.34 (+20.76%) + out_tf32(mma2x4+warp2x4+stage3): ['70.5943985', '26.1725273'], time:28.73399ms, swizzle: NOOP, TFLOPS: 38.27 + out_tf32(mma2x4+warp2x4+stage2): ['70.5943985', '26.1725273'], time:25.33073ms, swizzle: NOOP, TFLOPS: 43.41 (+11.76%) + out_tf32(mma2x4+...+stage3+dsmem): ['70.5943985', '26.1725273'], time:26.69138ms, swizzle: NOOP, TFLOPS: 41.19 + out_tf32(mma2x4+...+stage2+dsmem): ['70.5943985', '26.1725273'], time:25.41232ms, swizzle: NOOP, TFLOPS: 43.27 +out_tf32(mma2x4+...+stage3+swizzle): ['70.5943985', '26.1725273'], time:28.79602ms, swizzle: 2048, TFLOPS: 38.18 +out_tf32(mma2x4+...+stage2+swizzle): ['70.5943985', '26.1725273'], time:23.39887ms, swizzle: 2048, TFLOPS: 46.99 (+8.26%) + out_tf32(...+stage3+dsmem+swizzle): ['70.5943985', '26.1725273'], time:25.56235ms, swizzle: 2048, TFLOPS: 43.01 + out_tf32(...+stage2+dsmem+swizzle): ['70.5943985', '26.1725273'], time:23.46084ms, swizzle: 2048, TFLOPS: 46.87 + out_tf32(cublas+tf32): ['70.5943985', '26.1725273'], time:19.40128ms, swizzle: NOOP, TFLOPS: 56.67 (+20.60%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=16384, N=16384, K=4096 - out_f32(naive): ['-24.539411', '26.0933971'], time:1103.022ms, swizzle: NOOP, TFLOPS: 1.99 (+0.00%) - out_f32x4(t8x8sk): ['-24.539411', '26.0933971'], time:83.92670ms, swizzle: NOOP, TFLOPS: 26.20 (+1214.27%) - out_f32x4(t8x8bcf): ['-24.539411', '26.0933971'], time:80.32040ms, swizzle: NOOP, TFLOPS: 27.38 (+4.49%) - out_f32x4(t8x8dbuf): ['-24.539411', '26.0933971'], time:81.26325ms, swizzle: NOOP, TFLOPS: 27.06 - out_f32(cublas): ['-24.539411', '26.0933971'], time:58.41444ms, swizzle: NOOP, TFLOPS: 37.65 (+37.50%) - out_f32_th: ['-24.539411', '26.0933971'], time:58.71068ms, swizzle: NOOP, TFLOPS: 37.46 + out_f32x4(t8x8sk): ['151.796371', '4.59689951'], time:81.40509ms, swizzle: NOOP, TFLOPS: 27.01 (+0.00%) + out_f32x4(t8x8bcf): ['151.796371', '4.59689951'], time:75.39424ms, swizzle: NOOP, TFLOPS: 29.17 (+7.97%) + out_f32x4(t8x8dbuf): ['151.796371', '4.59689951'], time:75.67217ms, swizzle: NOOP, TFLOPS: 29.06 + out_f32(cublas): ['151.796371', '4.59689951'], time:55.54578ms, swizzle: NOOP, TFLOPS: 39.59 (+35.73%) + out_f32_th: ['151.796371', '4.59689951'], time:56.35116ms, swizzle: NOOP, TFLOPS: 39.02 --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['-24.539142', '26.0933208'], time:59.40877ms, swizzle: NOOP, TFLOPS: 37.02 - out_tf32(mma2x4+warp2x4+stage2): ['-24.539142', '26.0933208'], time:51.33775ms, swizzle: NOOP, TFLOPS: 42.83 (+13.78%) - out_tf32(mma2x4+...+stage3+dsmem): ['-24.539142', '26.0933208'], time:54.76785ms, swizzle: NOOP, TFLOPS: 40.15 - out_tf32(mma2x4+...+stage2+dsmem): ['-24.539142', '26.0933208'], time:51.49431ms, swizzle: NOOP, TFLOPS: 42.70 -out_tf32(mma2x4+...+stage3+swizzle): ['-24.539142', '26.0933208'], time:57.20764ms, swizzle: 4096, TFLOPS: 38.44 -out_tf32(mma2x4+...+stage2+swizzle): ['-24.539142', '26.0933208'], time:47.47045ms, swizzle: 4096, TFLOPS: 46.32 (+8.15%) - out_tf32(...+stage3+dsmem+swizzle): ['-24.539142', '26.0933208'], time:50.96282ms, swizzle: 4096, TFLOPS: 43.15 - out_tf32(...+stage2+dsmem+swizzle): ['-24.539142', '26.0933208'], time:47.44813ms, swizzle: 4096, TFLOPS: 46.35 (+0.05%) - out_tf32(cublas+tf32): ['-24.539142', '26.0933208'], time:38.30858ms, swizzle: NOOP, TFLOPS: 57.40 (+23.86%) + out_tf32(mma2x4+warp2x4+stage3): ['151.794143', '4.5965395 '], time:57.64467ms, swizzle: NOOP, TFLOPS: 38.15 + out_tf32(mma2x4+warp2x4+stage2): ['151.794143', '4.5965395 '], time:50.40433ms, swizzle: NOOP, TFLOPS: 43.63 (+10.20%) + out_tf32(mma2x4+...+stage3+dsmem): ['151.794143', '4.5965395 '], time:53.50663ms, swizzle: NOOP, TFLOPS: 41.10 + out_tf32(mma2x4+...+stage2+dsmem): ['151.794143', '4.5965395 '], time:50.22649ms, swizzle: NOOP, TFLOPS: 43.78 (+0.35%) +out_tf32(mma2x4+...+stage3+swizzle): ['151.794143', '4.5965395 '], time:57.27660ms, swizzle: 2048, TFLOPS: 38.39 +out_tf32(mma2x4+...+stage2+swizzle): ['151.794143', '4.5965395 '], time:46.61462ms, swizzle: 2048, TFLOPS: 47.17 (+7.75%) + out_tf32(...+stage3+dsmem+swizzle): ['151.794143', '4.5965395 '], time:50.91807ms, swizzle: 2048, TFLOPS: 43.19 + out_tf32(...+stage2+dsmem+swizzle): ['151.794143', '4.5965395 '], time:46.73092ms, swizzle: 2048, TFLOPS: 47.06 + out_tf32(cublas+tf32): ['151.794143', '4.5965395 '], time:38.29209ms, swizzle: NOOP, TFLOPS: 57.43 (+21.73%) ---------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=16384, N=16384, K=8192 - out_f32(naive): ['47.2999496', '96.8197784'], time:2203.968ms, swizzle: NOOP, TFLOPS: 2.00 (+0.00%) - out_f32x4(t8x8sk): ['47.2999496', '96.8197784'], time:164.6066ms, swizzle: NOOP, TFLOPS: 26.72 (+1238.93%) - out_f32x4(t8x8bcf): ['47.2999496', '96.8197784'], time:156.5503ms, swizzle: NOOP, TFLOPS: 28.09 (+5.15%) - out_f32x4(t8x8dbuf): ['47.2999496', '96.8197784'], time:157.3980ms, swizzle: NOOP, TFLOPS: 27.94 - out_f32(cublas): ['47.2999496', '96.8197784'], time:115.1302ms, swizzle: NOOP, TFLOPS: 38.20 (+35.98%) - out_f32_th: ['47.2999496', '96.8197784'], time:115.7525ms, swizzle: NOOP, TFLOPS: 38.00 + out_f32x4(t8x8sk): ['118.532104', '44.2729606'], time:162.8879ms, swizzle: NOOP, TFLOPS: 27.00 (+0.00%) + out_f32x4(t8x8bcf): ['118.532104', '44.2729606'], time:151.1848ms, swizzle: NOOP, TFLOPS: 29.09 (+7.74%) + out_f32x4(t8x8dbuf): ['118.532104', '44.2729606'], time:151.3025ms, swizzle: NOOP, TFLOPS: 29.07 + out_f32(cublas): ['118.532104', '44.2729606'], time:112.4181ms, swizzle: NOOP, TFLOPS: 39.12 (+34.48%) + out_f32_th: ['118.532104', '44.2729606'], time:112.4917ms, swizzle: NOOP, TFLOPS: 39.10 --------------------------------------------------------------WMMA---------------------------------------------------------------- - out_tf32(mma2x4+warp2x4+stage3): ['47.3002510', '96.8180007'], time:117.6291ms, swizzle: NOOP, TFLOPS: 37.39 - out_tf32(mma2x4+warp2x4+stage2): ['47.3002510', '96.8180007'], time:102.4631ms, swizzle: NOOP, TFLOPS: 42.92 (+12.36%) - out_tf32(mma2x4+...+stage3+dsmem): ['47.3002510', '96.8180007'], time:108.2939ms, swizzle: NOOP, TFLOPS: 40.61 - out_tf32(mma2x4+...+stage2+dsmem): ['47.3002510', '96.8180007'], time:104.3870ms, swizzle: NOOP, TFLOPS: 42.13 -out_tf32(mma2x4+...+stage3+swizzle): ['47.3002510', '96.8180007'], time:114.1043ms, swizzle: 4096, TFLOPS: 38.54 -out_tf32(mma2x4+...+stage2+swizzle): ['47.3002510', '96.8180007'], time:98.40396ms, swizzle: 4096, TFLOPS: 44.69 (+4.13%) - out_tf32(...+stage3+dsmem+swizzle): ['47.3002510', '96.8180007'], time:106.3529ms, swizzle: 4096, TFLOPS: 41.35 - out_tf32(...+stage2+dsmem+swizzle): ['47.3002510', '96.8180007'], time:97.71902ms, swizzle: 4096, TFLOPS: 45.01 (+0.70%) - out_tf32(cublas+tf32): ['47.3002510', '96.8180007'], time:75.97206ms, swizzle: NOOP, TFLOPS: 57.89 (+28.62%) + out_tf32(mma2x4+warp2x4+stage3): ['118.526184', '44.2716636'], time:115.7331ms, swizzle: NOOP, TFLOPS: 38.00 + out_tf32(mma2x4+warp2x4+stage2): ['118.526184', '44.2716636'], time:100.3637ms, swizzle: NOOP, TFLOPS: 43.82 (+12.01%) + out_tf32(mma2x4+...+stage3+dsmem): ['118.526184', '44.2716636'], time:106.3712ms, swizzle: NOOP, TFLOPS: 41.35 + out_tf32(mma2x4+...+stage2+dsmem): ['118.526184', '44.2716636'], time:102.4972ms, swizzle: NOOP, TFLOPS: 42.91 +out_tf32(mma2x4+...+stage3+swizzle): ['118.526184', '44.2716636'], time:114.2313ms, swizzle: 2048, TFLOPS: 38.50 +out_tf32(mma2x4+...+stage2+swizzle): ['118.526184', '44.2716636'], time:93.91186ms, swizzle: 2048, TFLOPS: 46.83 (+6.87%) + out_tf32(...+stage3+dsmem+swizzle): ['118.526184', '44.2716636'], time:101.5390ms, swizzle: 2048, TFLOPS: 43.31 + out_tf32(...+stage2+dsmem+swizzle): ['118.526184', '44.2716636'], time:93.69635ms, swizzle: 2048, TFLOPS: 46.94 (+0.23%) + out_tf32(cublas+tf32): ['118.526184', '44.2716636'], time:75.96850ms, swizzle: NOOP, TFLOPS: 57.89 (+23.34%) ---------------------------------------------------------------------------------------------------------------------------------- ``` diff --git a/sgemm/sgemm.py b/sgemm/sgemm.py index 5facc1eb..c8fda5b9 100644 --- a/sgemm/sgemm.py +++ b/sgemm/sgemm.py @@ -29,7 +29,7 @@ def run_benchmark(perf_func: callable, tag: str, out: Optional[torch.Tensor] = None, stages: int = -1, swizzle: bool = False, swizzle_stride: int = 1, - warmup: int = 2, iters: int = 50, + warmup: int = 2, iters: int = 20, show_all: bool = False): global MAX_TFLOPS @@ -40,11 +40,11 @@ def run_benchmark(perf_func: callable, if (a.size(0) > 1024 or a.size(1) >= 1024 or b.size(1) > 1024): - iters = 20 + iters = 10 if swizzle: # make swizzle stride as N/4 and multiples of 256 - swizzle_stride = int((int(N / 4) // 256) * 256) + swizzle_stride = int((int(N / 8) // 256) * 256) swizzle_stride = swizzle_stride if swizzle_stride >= 256 else 1 swizzle = swizzle if swizzle_stride >= 256 else False else: @@ -127,7 +127,7 @@ def run_benchmark(perf_func: callable, torch.cuda.synchronize() # CUDA Cores FP32 - run_benchmark(lib.sgemm_naive_f32, a, b, "f32(naive)", c) + # run_benchmark(lib.sgemm_naive_f32, a, b, "f32(naive)", c) run_benchmark(lib.sgemm_t_8x8_sliced_k_f32x4, a, b, "f32x4(t8x8sk)", c) run_benchmark(lib.sgemm_t_8x8_sliced_k_f32x4_bcf, a, b, "f32x4(t8x8bcf)", c) run_benchmark(lib.sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf, a, b, "f32x4(t8x8dbuf)", c) diff --git a/sgemm/sgemm_wmma_tf32_stage.cu b/sgemm/sgemm_wmma_tf32_stage.cu index 6e609bb9..1294ad12 100644 --- a/sgemm/sgemm_wmma_tf32_stage.cu +++ b/sgemm/sgemm_wmma_tf32_stage.cu @@ -50,16 +50,24 @@ __global__ void f32x4_tf32x4_kernel(float* x, float* y, int N) { } // stage2/3/4 (stage2=double buffers+copy async) -// 1. 当使用的shared memory超过48 KB时,需要使用dynamic shared -// memory, 即extern __shared__ float smem[];这样声明一块动态 -// 共享内存,调用kernel时 需要指定动态共享内存大小,且smem的寻址 -// 方式需要按照一维数组来使用 2. 提高L2 Cache的局部性(Thread -// Block Swizzle): https://zhuanlan.zhihu.com/p/555339335 -template +// 1. When using shared memory exceeds 48 KB, dynamic shared memory needs to be used, +// i.e., declare a block of dynamic shared memory with extern shared half smem[];. +// When calling the kernel, the size of the dynamic shared memory needs to be specified, +// and smem addressing should be used in a one-dimensional array manner. +// 2. Improve L2 Cache locality (Thread Block Swizzle): https://zhuanlan.zhihu.com/p/555339335 +// 3. __launch_bounds__: avoid error 'too many resources required for launch' +// reference: https://blog.csdn.net/feng__shuai/article/details/124395023 +template __global__ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_kernel( float* A, float* B, float* C, int M, int N, int K) { // 256 threads(8 warps) per block. @@ -71,10 +79,7 @@ __global__ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_kernel( constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; // 16x4*2=128 constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; // 16x2*4=128 constexpr int BK = WMMA_K; // 8 - // s2: 2*128*(8+4)*4=12KB, 2*8*(128+4)*4=8.25KB, ~21KB - // s3: 3*128*(8+4)*4=18KB, 3*8*(128+4)*4=12.375KB, ~31KB - // s4: 4*128*(8+4)*4=24KB, 4*8*(128+4)*4=16.5KB, ~41KB - __shared__ float s_a[K_STAGE][BM][BK+OFFSET], s_b[K_STAGE][BK][BN+OFFSET]; + __shared__ float s_a[K_STAGE][BM][BK+A_PAD], s_b[K_STAGE][BK][BN+B_PAD]; // 要保证相同的warp下thread执行相同的指令 const int tid = threadIdx.y * blockDim.x + threadIdx.x; @@ -162,14 +167,14 @@ __global__ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_kernel( for (int i = 0; i < WARP_TILE_M; ++i) { // load 2 tiles -> reg, smem a -> frags a, warp_m 0~3 const int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M; - wmma::load_matrix_sync(A_frag[i], &s_a[smem_sel][warp_smem_a_m][0], BK+OFFSET); + wmma::load_matrix_sync(A_frag[i], &s_a[smem_sel][warp_smem_a_m][0], BK+A_PAD); } #pragma unroll for (int j = 0; j < WARP_TILE_N; ++j) { // load 4 tiles -> reg, smem b -> frags b, warp_n 0~2 const int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N; - wmma::load_matrix_sync(B_frag[j], &s_b[smem_sel][0][warp_smem_b_n], BN+OFFSET); + wmma::load_matrix_sync(B_frag[j], &s_b[smem_sel][0][warp_smem_b_n], BN+B_PAD); } #pragma unroll @@ -203,14 +208,14 @@ __global__ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_kernel( for (int i = 0; i < WARP_TILE_M; ++i) { // load 2 tiles -> reg, smem a -> frags a, warp_m 0~3 const int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M; - wmma::load_matrix_sync(A_frag[i], &s_a[stage_sel][warp_smem_a_m][0], BK+OFFSET); + wmma::load_matrix_sync(A_frag[i], &s_a[stage_sel][warp_smem_a_m][0], BK+A_PAD); } #pragma unroll for (int j = 0; j < WARP_TILE_N; ++j) { // load 4 tiles -> reg, smem b -> frags b, warp_n 0~2 const int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N; - wmma::load_matrix_sync(B_frag[j], &s_b[stage_sel][0][warp_smem_b_n], BN+OFFSET); + wmma::load_matrix_sync(B_frag[j], &s_b[stage_sel][0][warp_smem_b_n], BN+B_PAD); } #pragma unroll @@ -237,16 +242,24 @@ __global__ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_kernel( } // stage2/3/4 (stage2=double buffers+copy async) -// 1. 当使用的shared memory超过48 KB时,需要使用dynamic shared -// memory, 即extern __shared__ float smem[];这样声明一块动态 -// 共享内存,调用kernel时 需要指定动态共享内存大小,且smem的寻址 -// 方式需要按照一维数组来使用 2. 提高L2 Cache的局部性(Thread -// Block Swizzle): https://zhuanlan.zhihu.com/p/555339335 -template +// 1. When using shared memory exceeds 48 KB, dynamic shared memory needs to be used, +// i.e., declare a block of dynamic shared memory with extern shared half smem[];. +// When calling the kernel, the size of the dynamic shared memory needs to be specified, +// and smem addressing should be used in a one-dimensional array manner. +// 2. Improve L2 Cache locality (Thread Block Swizzle): https://zhuanlan.zhihu.com/p/555339335 +// 3. __launch_bounds__: avoid error 'too many resources required for launch' +// reference: https://blog.csdn.net/feng__shuai/article/details/124395023 +template __global__ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_dsmem_kernel( float* A, float* B, float* C, int M, int N, int K) { // 256 threads(8 warps) per block. @@ -263,9 +276,9 @@ __global__ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_dsmem_kernel( // s4: 4*128*(8+4)*4=24KB, 4*8*(128+4)*4=16.5KB, ~41KB extern __shared__ float smem[]; float* s_a = smem; - float* s_b = smem + K_STAGE * BM * (BK + OFFSET); - constexpr int s_a_stage_offset = BM * (BK + OFFSET); - constexpr int s_b_stage_offset = BK * (BN + OFFSET); + float* s_b = smem + K_STAGE * BM * (BK + A_PAD); + constexpr int s_a_stage_offset = BM * (BK + A_PAD); + constexpr int s_b_stage_offset = BK * (BN + B_PAD); // 要保证相同的warp下thread执行相同的指令 const int tid = threadIdx.y * blockDim.x + threadIdx.x; @@ -311,14 +324,14 @@ __global__ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_dsmem_kernel( uint32_t load_smem_a_ptr = ( smem_a_base_ptr + (k * s_a_stage_offset + - load_smem_a_m * (BK + OFFSET) + + load_smem_a_m * (BK + A_PAD) + load_smem_a_k) * sizeof(float) ); CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); uint32_t load_smem_b_ptr = ( smem_b_base_ptr + (k * s_b_stage_offset + - load_smem_b_k * (BN + OFFSET) + + load_smem_b_k * (BN + B_PAD) + load_smem_b_n) * sizeof(float) ); CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); @@ -346,14 +359,14 @@ __global__ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_dsmem_kernel( // load stage 2, k start from 2 uint32_t load_smem_a_ptr = ( smem_a_base_ptr + (smem_sel_next * s_a_stage_offset + - load_smem_a_m * (BK + OFFSET) + + load_smem_a_m * (BK + A_PAD) + load_smem_a_k) * sizeof(float) ); CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); uint32_t load_smem_b_ptr = ( smem_b_base_ptr + (smem_sel_next * s_b_stage_offset + - load_smem_b_k * (BN + OFFSET) + + load_smem_b_k * (BN + B_PAD) + load_smem_b_n) * sizeof(float) ); CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); @@ -370,9 +383,9 @@ __global__ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_dsmem_kernel( // load 2 tiles -> reg, smem a -> frags a, warp_m 0~3 int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M; float* load_smem_a_frag_ptr = (s_a + smem_sel * s_a_stage_offset + - warp_smem_a_m * (BK + OFFSET) + warp_smem_a_m * (BK + A_PAD) + 0); // BK=WMMA_K=8 - wmma::load_matrix_sync(A_frag[i], load_smem_a_frag_ptr, BK + OFFSET); + wmma::load_matrix_sync(A_frag[i], load_smem_a_frag_ptr, BK + A_PAD); } #pragma unroll @@ -380,9 +393,9 @@ __global__ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_dsmem_kernel( // load 4 tiles -> reg, smem b -> frags b, warp_n 0~2 int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N; float* load_smem_b_frag_ptr = (s_b + smem_sel * s_b_stage_offset + - 0 * (BN + OFFSET) + + 0 * (BN + B_PAD) + warp_smem_b_n); // BK=WMMA_K=8 - wmma::load_matrix_sync(B_frag[j], load_smem_b_frag_ptr, BN + OFFSET); + wmma::load_matrix_sync(B_frag[j], load_smem_b_frag_ptr, BN + B_PAD); } #pragma unroll @@ -417,9 +430,9 @@ __global__ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_dsmem_kernel( // load 2 tiles -> reg, smem a -> frags a, warp_m 0~3 int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M; float* load_smem_a_frag_ptr = (s_a + stage_sel * s_a_stage_offset + - warp_smem_a_m * (BK + OFFSET) + warp_smem_a_m * (BK + A_PAD) + 0); // BK=WMMA_K=8 - wmma::load_matrix_sync(A_frag[i], load_smem_a_frag_ptr, BK + OFFSET); + wmma::load_matrix_sync(A_frag[i], load_smem_a_frag_ptr, BK + A_PAD); } #pragma unroll @@ -427,9 +440,9 @@ __global__ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_dsmem_kernel( // load 4 tiles -> reg, smem b -> frags b, warp_n 0~2 int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N; float* load_smem_b_frag_ptr = (s_b + stage_sel * s_b_stage_offset + - 0 * (BN + OFFSET) + + 0 * (BN + B_PAD) + warp_smem_b_n); // BK=WMMA_K=8 - wmma::load_matrix_sync(B_frag[j], load_smem_b_frag_ptr, BN + OFFSET); + wmma::load_matrix_sync(B_frag[j], load_smem_b_frag_ptr, BN + B_PAD); } #pragma unroll @@ -481,8 +494,8 @@ if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ N_SWIZZLE); \ sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_kernel< \ WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ - WARP_TILE_M, WARP_TILE_N, (stages), OFFSET, true><<< \ - grid, block>>>( \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, \ + (stages), true><<>>( \ reinterpret_cast(a.data_ptr()), \ reinterpret_cast(b.data_ptr()), \ reinterpret_cast(c.data_ptr()), \ @@ -496,8 +509,8 @@ if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ dim3 grid(div_ceil(N, BN), div_ceil(M, BM)); \ sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_kernel< \ WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ - WARP_TILE_M, WARP_TILE_N, (stages), OFFSET, false><<< \ - grid, block>>>( \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, \ + (stages), false><<>>( \ reinterpret_cast(a.data_ptr()), \ reinterpret_cast(b.data_ptr()), \ reinterpret_cast(c.data_ptr()), \ @@ -509,12 +522,12 @@ if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ #define LAUNCH_16168_STAGE_SWIZZLE_DSMEM_KERNEL(stages, stride) \ { \ const int smem_max_size = ( \ - (stages) * BM * (BK + OFFSET) * sizeof(float) + \ - (stages) * BK * (BN + OFFSET) * sizeof(float)); \ + (stages) * BM * (BK + A_PAD) * sizeof(float) + \ + (stages) * BK * (BN + B_PAD) * sizeof(float)); \ cudaFuncSetAttribute( \ sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_dsmem_kernel< \ WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ - WARP_TILE_M, WARP_TILE_N, (stages), OFFSET, true>, \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true>, \ cudaFuncAttributeMaxDynamicSharedMemorySize, \ 98304); \ const int N_SWIZZLE = (N + (stride) - 1) / (stride); \ @@ -524,7 +537,7 @@ if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ N_SWIZZLE); \ sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_dsmem_kernel< \ WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ - WARP_TILE_M, WARP_TILE_N, (stages), OFFSET, true><<< \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true><<< \ grid, block, smem_max_size>>>( \ reinterpret_cast(a.data_ptr()), \ reinterpret_cast(b.data_ptr()), \ @@ -536,19 +549,19 @@ if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ #define LAUNCH_16168_STAGE_NO_SWIZZLE_DSMEM_KERNEL(stages) \ { \ const int smem_max_size = ( \ - (stages) * BM * (BK + OFFSET) * sizeof(float) + \ - (stages) * BK * (BN + OFFSET) * sizeof(float)); \ + (stages) * BM * (BK + A_PAD) * sizeof(float) + \ + (stages) * BK * (BN + B_PAD) * sizeof(float)); \ cudaFuncSetAttribute( \ sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_dsmem_kernel< \ WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ - WARP_TILE_M, WARP_TILE_N, (stages), OFFSET, false>, \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), false>,\ cudaFuncAttributeMaxDynamicSharedMemorySize, \ 98304); \ dim3 block(NUM_THREADS); \ dim3 grid(div_ceil(N, BN), div_ceil(M, BM)); \ sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_dsmem_kernel< \ WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, \ - WARP_TILE_M, WARP_TILE_N, (stages), OFFSET, false><<< \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), false><<<\ grid, block, smem_max_size>>>( \ reinterpret_cast(a.data_ptr()), \ reinterpret_cast(b.data_ptr()), \ @@ -576,7 +589,6 @@ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages( const int Nb = K * N; constexpr int T = 256; - // TODO: multi streams for a and b. f32x4_tf32x4_kernel<<<((Na + T * 4 - 1)/(T * 4)), T>>>( reinterpret_cast(a.data_ptr()), reinterpret_cast(a.data_ptr()), @@ -594,24 +606,32 @@ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages( constexpr int WMMA_TILE_N = 2; constexpr int WARP_TILE_M = 2; constexpr int WARP_TILE_N = 4; - constexpr int OFFSET = 0; + // s_a 2 ways bank conflicts within warp, after pad 4 -> 2 ways bank conflicts. + // s_b 8 ways bank conflicts within warp, after pad 4 -> 4 ways bank conflicts. + // so, the best padding policy for s_a and s_b is A_PAD=0, B_PAD=0/4/8. + // B_PAD consume 16x~ less smem than A_PAD, 8xB_PAD vs 128xA_PAD. + constexpr int A_PAD = 0; + constexpr int B_PAD = 0; constexpr int NUM_THREADS= ( WMMA_TILE_M * WMMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; - constexpr int BK = WMMA_K; + constexpr int BK = WMMA_K; + // s2: 2*128*(8)*4=8KB, 2*8*(128+0~4)*4=8.25KB, 12~13KB + // s3: 3*128*(8)*4=12KB, 3*8*(128+0~4)*4=12.375KB, 24~25KB + // s4: 4*128*(8)*4=16KB, 4*8*(128+0~4)*4=16.5KB, 32~33KB if (swizzle) { assert(swizzle_stride % 256 == 0); switch (stages) { - case 2: // ~21KB + case 2: LAUNCH_16168_STAGE_SWIZZLE_KERNEL(2, swizzle_stride); break; - case 3: // ~31KB + case 3: LAUNCH_16168_STAGE_SWIZZLE_KERNEL(3, swizzle_stride); break; - case 4: // ~41KB + case 4: LAUNCH_16168_STAGE_SWIZZLE_KERNEL(4, swizzle_stride); break; default: @@ -655,7 +675,6 @@ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_dsmem( const int Nb = K * N; constexpr int T = 256; - // TODO: multi streams for a and b. f32x4_tf32x4_kernel<<<((Na + T * 4 - 1)/(T * 4)), T>>>( reinterpret_cast(a.data_ptr()), reinterpret_cast(a.data_ptr()), @@ -673,27 +692,35 @@ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_dsmem( constexpr int WMMA_TILE_N = 2; constexpr int WARP_TILE_M = 2; constexpr int WARP_TILE_N = 4; - constexpr int OFFSET = 0; + // s_a 2 ways bank conflicts within warp, after pad 4 -> 2 ways bank conflicts. + // s_b 8 ways bank conflicts within warp, after pad 4 -> 4 ways bank conflicts. + // so, the best padding policy for s_a and s_b is A_PAD=0, B_PAD=0/4/8. + // B_PAD consume 16x~ less smem than A_PAD, 8xB_PAD vs 128xA_PAD. + constexpr int A_PAD = 0; + constexpr int B_PAD = 0; constexpr int NUM_THREADS= ( WMMA_TILE_M * WMMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; - constexpr int BK = WMMA_K; + constexpr int BK = WMMA_K; + // s2: 2*128*(8)*4=8KB, 2*8*(128+0~4)*4=8.25KB, 12~13KB + // s3: 3*128*(8)*4=12KB, 3*8*(128+0~4)*4=12.375KB, 24~25KB + // s4: 4*128*(8)*4=16KB, 4*8*(128+0~4)*4=16.5KB, 32~33KB if (swizzle) { assert(swizzle_stride % 256 == 0); switch (stages) { - case 2: // ~21KB + case 2: LAUNCH_16168_STAGE_SWIZZLE_DSMEM_KERNEL(2, swizzle_stride); break; - case 3: // ~31KB + case 3: LAUNCH_16168_STAGE_SWIZZLE_DSMEM_KERNEL(3, swizzle_stride); break; - case 4: // ~41KB + case 4: LAUNCH_16168_STAGE_SWIZZLE_DSMEM_KERNEL(4, swizzle_stride); break; - case 5: // ~52KB + case 5: LAUNCH_16168_STAGE_SWIZZLE_DSMEM_KERNEL(5, swizzle_stride); break; default: