From 45b3a3a7defeb8b6971ea8ddc67e081565569687 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Thu, 8 Jan 2026 13:57:21 +0800
Subject: [PATCH] [Refactor] Update main function signatures in example scripts
 to accept parameters directly (#1630)

* Modify the main functions in example_custom_compress.py, example_gemm_sp.py,
  and example_vertical_slash_sparse_attn.py to accept parameters directly
  instead of using argparse, for improved flexibility.
* Update the corresponding calls to main in the script execution sections.
* Ensure consistency in matrix dimensions and argument handling across examples.
---
 examples/gemm_sp/example_custom_compress.py | 46 ++++++++++--------
 examples/gemm_sp/example_gemm_sp.py         | 28 +++++------
 .../example_vertical_slash_sparse_attn.py   | 47 +++++++++----------
 examples/minference/test_vs_sparse_attn.py  |  2 +-
 4 files changed, 62 insertions(+), 61 deletions(-)

diff --git a/examples/gemm_sp/example_custom_compress.py b/examples/gemm_sp/example_custom_compress.py
index 7b93f2a77..351cabdf7 100644
--- a/examples/gemm_sp/example_custom_compress.py
+++ b/examples/gemm_sp/example_custom_compress.py
@@ -291,28 +291,17 @@ def kernel(
     return kernel
 
 
-def main():
-    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
-    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
-    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
-    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
-    parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor")
-    parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference")
-    parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
-    parser.add_argument("--cfg", type=str, choices=["4090"], default="4090")
-    args = parser.parse_args()
-    kernel = matmul_sp_fp16_custom_compress(
-        args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype], use_cutlass_layout=args.use_cutlass_layout
-    )
+def main(M=1024, N=1024, K=1024, use_cutlass_layout=False, use_torch_compressor=False, accum_dtype=T.float, cfg="4090"):
+    kernel = matmul_sp_fp16_custom_compress(M, N, K, accum_dtype, **DEFAULT_CONFIG[cfg][accum_dtype], use_cutlass_layout=use_cutlass_layout)
 
-    a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half)
-    b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half)
+    a = randn_semi_sparse(M, K, device="cuda", dtype=torch.half)
+    b = torch.randn(K, N, device="cuda", dtype=torch.half)
 
-    if args.use_torch_compressor:
-        assert not args.use_cutlass_layout, "torch sparse must be used with naive layout"
+    if use_torch_compressor:
+        assert not use_cutlass_layout, "torch sparse must be used with naive layout"
         a_sparse, e = torch_compress(a)
     else:
-        a_sparse, e = compress_kernel(args.m, args.k, 32, 32, T.float16, use_cutlass_layout=args.use_cutlass_layout)(a)
+        a_sparse, e = compress_kernel(M, K, 32, 32, T.float16, use_cutlass_layout=use_cutlass_layout)(a)
 
     c = kernel(a_sparse, e, b)
 
@@ -325,7 +314,7 @@ def main():
     latency = do_bench(lambda: kernel(a_sparse, e, b))
     ref_latency = do_bench(lambda: a @ b)
 
-    total_flops = 2 * args.m * args.n * args.k
+    total_flops = 2 * M * N * K
     tflops = total_flops / latency / 1e9
     ref_tflops = total_flops / ref_latency / 1e9
     print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s")
@@ -333,4 +322,21 @@ def main():
 
 
 if __name__ == "__main__":
"__main__": - main() + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") + parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor") + parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference") + parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") + parser.add_argument("--cfg", type=str, choices=["4090"], default="4090") + args = parser.parse_args() + main( + M=args.m, + N=args.n, + K=args.k, + use_cutlass_layout=args.use_cutlass_layout, + use_torch_compressor=args.use_torch_compressor, + accum_dtype=args.accum_dtype, + cfg=args.cfg, + ) diff --git a/examples/gemm_sp/example_gemm_sp.py b/examples/gemm_sp/example_gemm_sp.py index 10f524adb..e0026e30a 100644 --- a/examples/gemm_sp/example_gemm_sp.py +++ b/examples/gemm_sp/example_gemm_sp.py @@ -97,20 +97,13 @@ def gemm_sp_fp16( return gemm_sp_fp16 -def main(): - parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") - parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") - parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") - parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") - parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090") - args = parser.parse_args() - kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype]) +def main(M=1024, N=1024, K=1024, accum_dtype=T.float, cfg="h20"): + kernel = matmul_sp_fp16(M, N, K, accum_dtype, **DEFAULT_CONFIG[cfg][accum_dtype]) - a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half) - b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half) + a = randn_semi_sparse(M, K, device="cuda", dtype=torch.half) + b = torch.randn(K, N, device="cuda", dtype=torch.half) - a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]["block_K"], arch=arch) + a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[cfg][accum_dtype]["block_K"], arch=arch) c = kernel(a_sparse, e, b) ref_c = a @ b @@ -122,7 +115,7 @@ def main(): latency = do_bench(lambda: kernel(a_sparse, e, b)) ref_latency = do_bench(lambda: a @ b) - total_flops = 2 * args.m * args.n * args.k + total_flops = 2 * M * N * K tflops = total_flops / latency / 1e9 ref_tflops = total_flops / ref_latency / 1e9 print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s") @@ -130,4 +123,11 @@ def main(): if __name__ == "__main__": - main() + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") + parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") + parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090") + args 
+    main(M=args.m, N=args.n, K=args.k, accum_dtype=args.accum_dtype, cfg=args.cfg)
diff --git a/examples/minference/example_vertical_slash_sparse_attn.py b/examples/minference/example_vertical_slash_sparse_attn.py
index 91af8b454..526a2029e 100644
--- a/examples/minference/example_vertical_slash_sparse_attn.py
+++ b/examples/minference/example_vertical_slash_sparse_attn.py
@@ -560,21 +560,10 @@ def sum_all_diagonal_matrix(mat: torch.tensor):
     return sum_diags[:, :, 1:]
 
 
-def main(argv=None):
-    parser = argparse.ArgumentParser()
+def main(batch=1, heads=1, seq_len=4096, head_dim=64, vertical_size=1000, slash_size=200):
+    BATCH, N_HEADS, SEQ_LEN, D_HEAD = batch, heads, seq_len, head_dim
 
-    parser.add_argument("--batch", type=int, default=1)
-    parser.add_argument("--heads", type=int, default=1)
-    parser.add_argument("--seq_len", type=int, default=16384)
-    parser.add_argument("--head_dim", type=int, default=64)
-    parser.add_argument("--vertical_size", type=int, default=1000)
-    parser.add_argument("--slash_size", type=int, default=200)
-
-    args = parser.parse_args(argv)
-
-    BATCH, N_HEADS, SEQ_LEN, D_HEAD = args.batch, args.heads, args.seq_len, args.head_dim
-
-    vertical_size, slash_size = args.vertical_size, args.slash_size
+    vertical_size, slash_size = vertical_size, slash_size
 
     torch.manual_seed(0)
     q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
@@ -613,17 +602,8 @@ def main(argv=None):
     print(f"speedup: {triton_time / tilelang_time:.2f}x")
 
 
-def run_regression_perf(argv=None):
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--batch", type=int, default=1)
-    parser.add_argument("--heads", type=int, default=1)
-    parser.add_argument("--seq_len", type=int, default=16384)
-    parser.add_argument("--head_dim", type=int, default=64)
-    parser.add_argument("--vertical_size", type=int, default=1000)
-    parser.add_argument("--slash_size", type=int, default=200)
-    args = parser.parse_args(argv)
-    BATCH, N_HEADS, SEQ_LEN, D_HEAD = args.batch, args.heads, args.seq_len, args.head_dim
-    vertical_size, slash_size = args.vertical_size, args.slash_size
+def run_regression_perf(batch=1, heads=1, seq_len=16384, head_dim=64, vertical_size=1000, slash_size=200):
+    BATCH, N_HEADS, SEQ_LEN, D_HEAD = batch, heads, seq_len, head_dim
     torch.manual_seed(0)
     q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
     k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
@@ -687,4 +667,19 @@ def run_kernel_only():
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--batch", type=int, default=1)
+    parser.add_argument("--heads", type=int, default=1)
+    parser.add_argument("--seq_len", type=int, default=16384)
+    parser.add_argument("--head_dim", type=int, default=64)
+    parser.add_argument("--vertical_size", type=int, default=1000)
+    parser.add_argument("--slash_size", type=int, default=200)
+    args = parser.parse_args()
+    main(
+        batch=args.batch,
+        heads=args.heads,
+        seq_len=args.seq_len,
+        head_dim=args.head_dim,
+        vertical_size=args.vertical_size,
+        slash_size=args.slash_size,
+    )
diff --git a/examples/minference/test_vs_sparse_attn.py b/examples/minference/test_vs_sparse_attn.py
index f01df3808..9e6741dcf 100644
--- a/examples/minference/test_vs_sparse_attn.py
+++ b/examples/minference/test_vs_sparse_attn.py
@@ -5,7 +5,7 @@
 
 @tilelang.testing.requires_cuda
 def test_vs_sparse_attn():
-    example_vertical_slash_sparse_attn.main(argv=[])
+    example_vertical_slash_sparse_attn.main()
 
 
 if __name__ == "__main__":
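
With this change, the example entry points can be imported and called directly
rather than only through the command line, which is what the updated
test_vs_sparse_attn.py relies on. A minimal sketch of the new call pattern
(assuming the script runs from examples/minference so the module is importable;
main()'s keyword arguments and defaults are the ones introduced in this patch):

    import example_vertical_slash_sparse_attn

    # Sweep a few sequence lengths without touching sys.argv; all other
    # parameters (batch, heads, head_dim, vertical_size, slash_size) keep
    # the defaults declared in main().
    for seq_len in (2048, 4096, 8192):
        example_vertical_slash_sparse_attn.main(seq_len=seq_len)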