46 changes: 26 additions & 20 deletions examples/gemm_sp/example_custom_compress.py
@@ -291,28 +291,17 @@ def kernel(
return kernel


def main():
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor")
parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference")
parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090"], default="4090")
args = parser.parse_args()
kernel = matmul_sp_fp16_custom_compress(
args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype], use_cutlass_layout=args.use_cutlass_layout
)
def main(M=1024, N=1024, K=1024, use_cutlass_layout=False, use_torch_compressor=False, accum_dtype=T.float, cfg="4090"):
kernel = matmul_sp_fp16_custom_compress(M, N, K, accum_dtype, **DEFAULT_CONFIG[cfg][accum_dtype], use_cutlass_layout=use_cutlass_layout)

a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half)
b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half)
a = randn_semi_sparse(M, K, device="cuda", dtype=torch.half)
b = torch.randn(K, N, device="cuda", dtype=torch.half)

if args.use_torch_compressor:
assert not args.use_cutlass_layout, "torch sparse must be used with naive layout"
if use_torch_compressor:
assert not use_cutlass_layout, "torch sparse must be used with naive layout"
a_sparse, e = torch_compress(a)
else:
a_sparse, e = compress_kernel(args.m, args.k, 32, 32, T.float16, use_cutlass_layout=args.use_cutlass_layout)(a)
a_sparse, e = compress_kernel(M, K, 32, 32, T.float16, use_cutlass_layout=use_cutlass_layout)(a)

c = kernel(a_sparse, e, b)

@@ -325,12 +314,29 @@ def main():
latency = do_bench(lambda: kernel(a_sparse, e, b))
ref_latency = do_bench(lambda: a @ b)

total_flops = 2 * args.m * args.n * args.k
total_flops = 2 * M * N * K
tflops = total_flops / latency / 1e9
ref_tflops = total_flops / ref_latency / 1e9
print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s")
print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s")


if __name__ == "__main__":
main()
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor")
parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference")
parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090"], default="4090")
Comment on lines +330 to +332

⚠️ Potential issue | 🟡 Minor

Argparse type mismatch for accum_dtype.

The --accum_dtype argument is declared with type=str but choices=[T.float, T.float16] contains type objects, not strings. Additionally, the default is T.float (a type object). This configuration is inconsistent: if a user passes --accum_dtype float, argparse will compare the string "float" against T.float, which will fail validation.

Proposed fix
-    parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
+    parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")

Then map the string to the type in the main call:

+    dtype_map = {"float": T.float, "float16": T.float16}
     main(
         M=args.m,
         N=args.n,
         K=args.k,
         use_cutlass_layout=args.use_cutlass_layout,
         use_torch_compressor=args.use_torch_compressor,
-        accum_dtype=args.accum_dtype,
+        accum_dtype=dtype_map[args.accum_dtype],
         cfg=args.cfg,
     )

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents

In @examples/gemm_sp/example_custom_compress.py around lines 330-332, the argparse entry for accum_dtype is inconsistent: change parser.add_argument("--accum_dtype", type=str, choices=[T.float, T.float16], default=T.float, ...) to use string choices and a string default (e.g., type=str, choices=["float", "float16"], default="float"), and then map that string to the actual dtype before use (e.g., create a mapping like accum_dtype_map = {"float": T.float, "float16": T.float16} and set accum_dtype = accum_dtype_map[args.accum_dtype] where the script uses accum_dtype).
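
To make the failure mode concrete, here is a minimal, self-contained reproduction of the mismatch described above (a stand-in `_T` class is used as a placeholder for the real `T` namespace; illustrative only, not code from this PR):

# Hypothetical reproduction: argparse converts the CLI value with type=str and
# then checks membership in `choices`, so a string can never match a non-string
# choice object. `_T` is a placeholder for tilelang's `T` namespace.
import argparse


class _T:
    float = object()
    float16 = object()


parser = argparse.ArgumentParser()
parser.add_argument("--accum_dtype", type=str, default=_T.float, choices=[_T.float, _T.float16])

try:
    parser.parse_args(["--accum_dtype", "float"])
except SystemExit:
    # argparse rejects "float" because it is not one of the choice objects.
    print("validation failed, as the comment predicts")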

args = parser.parse_args()
main(
M=args.m,
N=args.n,
K=args.k,
use_cutlass_layout=args.use_cutlass_layout,
use_torch_compressor=args.use_torch_compressor,
accum_dtype=args.accum_dtype,
cfg=args.cfg,
)
28 changes: 14 additions & 14 deletions examples/gemm_sp/example_gemm_sp.py
@@ -97,20 +97,13 @@ def gemm_sp_fp16(
return gemm_sp_fp16


def main():
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090")
args = parser.parse_args()
kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype])
def main(M=1024, N=1024, K=1024, accum_dtype=T.float, cfg="h20"):
kernel = matmul_sp_fp16(M, N, K, accum_dtype, **DEFAULT_CONFIG[cfg][accum_dtype])

a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half)
b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half)
a = randn_semi_sparse(M, K, device="cuda", dtype=torch.half)
b = torch.randn(K, N, device="cuda", dtype=torch.half)

a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]["block_K"], arch=arch)
a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[cfg][accum_dtype]["block_K"], arch=arch)
c = kernel(a_sparse, e, b)

ref_c = a @ b
@@ -122,12 +115,19 @@ def main():
latency = do_bench(lambda: kernel(a_sparse, e, b))
ref_latency = do_bench(lambda: a @ b)

total_flops = 2 * args.m * args.n * args.k
total_flops = 2 * M * N * K
tflops = total_flops / latency / 1e9
ref_tflops = total_flops / ref_latency / 1e9
print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s")
print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s")


if __name__ == "__main__":
main()
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090")
Comment on lines +130 to +131

⚠️ Potential issue | 🟡 Minor

Argparse type mismatch for accum_dtype.

Same issue as in example_custom_compress.py: the --accum_dtype argument uses type=str with choices=[T.float, T.float16] (type objects), which will cause validation failures when CLI arguments are passed.

Proposed fix
-    parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
+    parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")

Then map in the call:

+    dtype_map = {"float": T.float, "float16": T.float16}
-    main(M=args.m, N=args.n, K=args.k, accum_dtype=args.accum_dtype, cfg=args.cfg)
+    main(M=args.m, N=args.n, K=args.k, accum_dtype=dtype_map[args.accum_dtype], cfg=args.cfg)

Committable suggestion skipped: line range outside the PR's diff.
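
For completeness, a sketch of how the corrected entry point for this example could look once the suggested changes are combined (assuming `argparse`, `T`, and `main` are already available in the script, as in this diff; `dtype_map` is an illustrative helper, not part of the PR):

# Illustrative sketch only: string choices on the CLI, mapped back to T.* before use.
# Assumes argparse, T, and main are already imported/defined in the example script.
dtype_map = {"float": T.float, "float16": T.float16}

parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument("--accum_dtype", type=str, default="float", choices=list(dtype_map), help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090")
args = parser.parse_args()

main(M=args.m, N=args.n, K=args.k, accum_dtype=dtype_map[args.accum_dtype], cfg=args.cfg)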

args = parser.parse_args()
main(M=args.m, N=args.n, K=args.k, accum_dtype=args.accum_dtype, cfg=args.cfg)
47 changes: 21 additions & 26 deletions examples/minference/example_vertical_slash_sparse_attn.py
@@ -560,21 +560,10 @@ def sum_all_diagonal_matrix(mat: torch.tensor):
return sum_diags[:, :, 1:]


def main(argv=None):
parser = argparse.ArgumentParser()
def main(batch=1, heads=1, seq_len=4096, head_dim=64, vertical_size=1000, slash_size=200):
BATCH, N_HEADS, SEQ_LEN, D_HEAD = batch, heads, seq_len, head_dim

parser.add_argument("--batch", type=int, default=1)
parser.add_argument("--heads", type=int, default=1)
parser.add_argument("--seq_len", type=int, default=16384)
parser.add_argument("--head_dim", type=int, default=64)
parser.add_argument("--vertical_size", type=int, default=1000)
parser.add_argument("--slash_size", type=int, default=200)

args = parser.parse_args(argv)

BATCH, N_HEADS, SEQ_LEN, D_HEAD = args.batch, args.heads, args.seq_len, args.head_dim

vertical_size, slash_size = args.vertical_size, args.slash_size
vertical_size, slash_size = vertical_size, slash_size

torch.manual_seed(0)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
@@ -613,17 +602,8 @@ def main(argv=None):
print(f"speedup: {triton_time / tilelang_time:.2f}x")


def run_regression_perf(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=1)
parser.add_argument("--heads", type=int, default=1)
parser.add_argument("--seq_len", type=int, default=16384)
parser.add_argument("--head_dim", type=int, default=64)
parser.add_argument("--vertical_size", type=int, default=1000)
parser.add_argument("--slash_size", type=int, default=200)
args = parser.parse_args(argv)
BATCH, N_HEADS, SEQ_LEN, D_HEAD = args.batch, args.heads, args.seq_len, args.head_dim
vertical_size, slash_size = args.vertical_size, args.slash_size
def run_regression_perf(batch=1, heads=1, seq_len=16384, head_dim=64, vertical_size=1000, slash_size=200):
BATCH, N_HEADS, SEQ_LEN, D_HEAD = batch, heads, seq_len, head_dim
torch.manual_seed(0)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
@@ -687,4 +667,19 @@ def run_kernel_only():


if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=1)
parser.add_argument("--heads", type=int, default=1)
parser.add_argument("--seq_len", type=int, default=16384)
parser.add_argument("--head_dim", type=int, default=64)
parser.add_argument("--vertical_size", type=int, default=1000)
parser.add_argument("--slash_size", type=int, default=200)
args = parser.parse_args()
main(
batch=args.batch,
heads=args.heads,
seq_len=args.seq_len,
head_dim=args.head_dim,
vertical_size=args.vertical_size,
slash_size=args.slash_size,
)
2 changes: 1 addition & 1 deletion examples/minference/test_vs_sparse_attn.py
@@ -5,7 +5,7 @@

@tilelang.testing.requires_cuda
def test_vs_sparse_attn():
example_vertical_slash_sparse_attn.main(argv=[])
example_vertical_slash_sparse_attn.main()


if __name__ == "__main__":