[Docs] rename mat_transpose -> mat-transpose (#93)
* Update sgemm_wmma_tf32_stage.cu

* Update sgemm.py

* Update README.md

* Update sgemm_wmma_tf32_stage.cu

* Update hgemm_wmma_stage.cu

* Update hgemm.cu

* Update hgemm.py

* Update hgemm.py

* rename mat_transpose->mat-transpose

* update hgemm benchmark

* update hgemm benchmark
DefTruth authored Oct 19, 2024
1 parent 2f854e8 commit 523a610
Showing 11 changed files with 922 additions and 522 deletions.
10 changes: 5 additions & 5 deletions README.md
@@ -62,11 +62,11 @@
| ✔️ [embedding_f16x2](./embedding/embedding.cu)|f16|/|[link](./embedding/)|⭐️|
| ✔️ [embedding_f16x8](./embedding/embedding.cu)|f16|/|[link](./embedding/)|⭐️|
| ✔️ [embedding_f16x8_pack](./embedding/embedding.cu)|f16|/|[link](./embedding/)|⭐️⭐️|
| ✔️ [mat_trans_f32_col2row{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️|
| ✔️ [mat_trans_f32_row2col{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️|
| ✔️ [mat_trans_f32_diagonal2d](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️⭐️|
| ✔️ [mat_trans_f32x4_col2row{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️⭐️|
| ✔️ [mat_trans_f32x4_row2col{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️⭐️|
| ✔️ [mat_trans_f32_col2row{2d}](./mat-transpose/mat_transpose.cu)|f32|/|[link](./mat-transpose/)|⭐️|
| ✔️ [mat_trans_f32_row2col{2d}](./mat-transpose/mat_transpose.cu)|f32|/|[link](./mat-transpose/)|⭐️|
| ✔️ [mat_trans_f32_diagonal2d](./mat-transpose/mat_transpose.cu)|f32|/|[link](./mat-transpose/)|⭐️⭐️|
| ✔️ [mat_trans_f32x4_col2row{2d}](./mat-transpose/mat_transpose.cu)|f32|/|[link](./mat-transpose/)|⭐️⭐️|
| ✔️ [mat_trans_f32x4_row2col{2d}](./mat-transpose/mat_transpose.cu)|f32|/|[link](./mat-transpose/)|⭐️⭐️|
| ✔️ [warp_reduce_[all]](./reduce/reduce.cu)|all|all|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_f32_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_f32x4_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
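
The renamed mat-transpose entries above all perform the same basic operation: read a row-major M×N f32 matrix and write its transpose. For reference only, a naive kernel of that shape might look like the sketch below; the kernel name, signature, and 2D launch are illustrative assumptions, not the actual code in mat-transpose/mat_transpose.cu.

```cuda
#include <cuda_runtime.h>

// Naive f32 transpose sketch (illustration only; the repo's mat_transpose.cu
// kernels are the reference implementations).
// x is row-major M x N, y is row-major N x M.
__global__ void mat_trans_f32_naive_sketch(const float *__restrict__ x,
                                           float *__restrict__ y,
                                           int M, int N) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x; // 0..N-1
  const int row = blockIdx.y * blockDim.y + threadIdx.y; // 0..M-1
  if (row < M && col < N) {
    // Element (row, col) of x lands at (col, row) of the N x M output.
    y[col * M + row] = x[row * N + col];
  }
}

// Example launch for an M x N matrix:
//   dim3 block(32, 8);
//   dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
//   mat_trans_f32_naive_sketch<<<grid, block>>>(d_x, d_y, M, N);
```

The col2row/row2col suffixes in the table presumably distinguish which side of the access (load or store) stays coalesced, and the f32x4 variants presumably vectorize the accesses with float4; see mat_transpose.cu for the actual implementations.
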
3 changes: 3 additions & 0 deletions hgemm/hgemm.cu
@@ -1237,6 +1237,8 @@ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem(torch::Tensor a, torch::Te
int stages, bool swizzle, int swizzle_stride);
void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c,
int stages, bool swizzle, int swizzle_stride);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c,
int stages, bool swizzle, int swizzle_stride);


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -1284,5 +1286,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages)
TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem)
TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem)
TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem)
}
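
The new warp2x4x2 launcher is declared and then registered in PYBIND11_MODULE through the same TORCH_BINDING_COMMON_EXTENSION pattern as the existing kernels. Below is a self-contained sketch of how such a binding could look; the macro body is an assumption for illustration (the repo defines its own helper), and only the function name and signature come from the diff above.

```cuda
#include <torch/extension.h>

// Launcher presumably defined in hgemm_wmma_stage.cu; signature taken from the diff above.
void hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem(
    torch::Tensor a, torch::Tensor b, torch::Tensor c,
    int stages, bool swizzle, int swizzle_stride);

// Assumed helper: register a function under its own name with pybind11.
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) m.def(STRINGFY(func), &func, STRINGFY(func));

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // After this, Python code loaded via torch.utils.cpp_extension.load can call
  // lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem(a, b, c, stages, swizzle, swizzle_stride).
  TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem)
}
```
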

78 changes: 44 additions & 34 deletions hgemm/hgemm.py
@@ -3,6 +3,7 @@
from torch.utils.cpp_extension import load
from functools import partial
from typing import Optional
import argparse

torch.set_grad_enabled(False)

@@ -95,15 +96,24 @@ def run_benchmark(perf_func: callable,
else:
improve = 0
MAX_TFLOPS = TFLOPS
print(f"{out_info:>40}: {out_val}, time:{mean_time}ms, "
print(f"{out_info:>35}: {out_val}, time:{mean_time}ms, "
f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}(+{improve:.2f}%)")
else:
print(f"{out_info:>40}: {out_val}, time:{mean_time}ms, "
print(f"{out_info:>35}: {out_val}, time:{mean_time}ms, "
f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}")
if show_all: print(out)
return out, mean_time


def get_args():
parser = argparse.ArgumentParser(description="hgemm benchmark")
parser.add_argument("--enable-mma-all", "-ma", action="store_true")
parser.add_argument("--enable-wmma-all", "-wa", action="store_true")
parser.add_argument("--enable-cuda-all", "-ca", action="store_true")
return parser.parse_args()


args = get_args()
Ms = [4096, 8192, 16384]
Ns = [4096, 8192, 16384]
Ks = [2048, 4096, 8192]
@@ -124,44 +134,44 @@ def run_benchmark(perf_func: callable,
c = C[:M, :N].contiguous()
torch.cuda.synchronize()

# CUDA Cores FP16
# run_benchmark(lib.hgemm_naive_f16, a, b, "f16(naive)", c)
# run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "f16x8pack(t8x8+bcf)", c)
if args.enable_cuda_all:
# CUDA Cores FP16
run_benchmark(lib.hgemm_naive_f16, a, b, "f16(naive)", c)
run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "f16x8pack(t8x8+bcf)", c)

run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf, a, b, "f16x8pack(t8x8+dbuf)", c)
run_benchmark(lib.hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf, a, b, "f16x8pack(t8x8+k16+dbuf)", c)

print("-" * 68 + "WMMA" + "-" * 58)
# run_benchmark(lib.hgemm_wmma_m16n16k16_naive, a, b, "f16wmma(naive)", c)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2, a, b, "f16wmma(mma4x2)", c)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4, a, b, "f16wmma(mma4x2+warp2x4)", c)
# run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async_offset, a, b, "f16wmma(mma2x4+warp2x4+dbuf)", c)

# Stages, dsmem
# run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+warp2x4+stage4)", c, stages=4)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+warp2x4+stage3)", c, stages=3)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+warp2x4+stage2)", c, stages=2)

# run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(mma2x4+...+stage4+dsmem)", c, stages=4)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(mma2x4+...+stage3+dsmem)", c, stages=3)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(mma2x4+...+stage2+dsmem)", c, stages=2)

# run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+...+stage4+dsmem)", c, stages=4)
# run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+...+stage3+dsmem)", c, stages=3)
# run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+...+stage2+dsmem)", c, stages=2)
# wmma api, stages, dsmem, swizzle
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2, a, b, "(mma4x2)", c)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4, a, b, "(mma4x2+warp2x4)", c)

# Thread block swizzle
# run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+...+stage4+swizzle)", c, stages=4, swizzle=True)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+...+stage3+swizzle)", c, stages=3, swizzle=True)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+...+stage2+swizzle)", c, stages=2, swizzle=True)

# run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(...+stage4+dsmem+swizzle)", c, stages=4, swizzle=True)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(...+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(...+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)

# run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+stage4+dsmem+swizzle)", c, stages=4, swizzle=True)
# run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
# run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
# prefer on NVIDIA L20 device.
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+warp2x4+stage3)", c, stages=3)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+warp2x4+stage2)", c, stages=2)

run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma2x4+...+stage3+dsmem)", c, stages=3)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma2x4+...+stage2+dsmem)", c, stages=2)

run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+...+stage3+swizzle)", c, stages=3, swizzle=True)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+...+stage2+swizzle)", c, stages=2, swizzle=True)

run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(...+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(...+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)

if args.enable_wmma_all:
# prefer on NVIDIA RTX 3080 Laptop 16GB GDDR6 device.
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+...+stage3+dsmem)", c, stages=3)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+...+stage2+dsmem)", c, stages=2)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+...+stage3+dsmem)", c, stages=3)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+...+stage2+dsmem)", c, stages=2)

run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)

run_benchmark(lib.hgemm_cublas_tensor_op, a, b, "f16(cublas)", c)
run_benchmark(partial(torch.matmul, out=c), a, b, "f16_th")
torch.cuda.synchronize()
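
With the new argparse flags, the benchmark gates the CUDA-core FP16 kernels behind --enable-cuda-all (-ca) and the mma4x4/warp4x4 and warp2x4x2 dsmem variants behind --enable-wmma-all (-wa); --enable-mma-all (-ma) is also defined in get_args(). A typical invocation would be something like `python3 hgemm.py -wa -ca`; the exact command line is an assumption, only the flag names come from get_args() above.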