Commit 50ac2cc

Summary:
This PR adds the sparsify overhead benchmark that was omitted from the ICLR workshop paper (https://arxiv.org/abs/2503.16672). The paper's benchmark has two parts: 1) sparsify operation overhead, and 2) sparse-GEMM kernel performance. Part 1) was missing from the original benchmark script, so this PR adds a sparsify-only benchmark comparing `torchao.sparse24_sm90_sparsify` against the `torch._cslt_compress` (cuSPARSELt) baseline.

Test plan: CI
1 parent: afe5cab · commit: 50ac2cc

File tree

1 file changed: +17 −0 lines


benchmarks/benchmark_e2e_fp8_sparse_linear.py

Lines changed: 17 additions & 0 deletions
@@ -40,6 +40,20 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192):
     input_tensor = torch.randn(num_tokens, hidden_size).to(torch.bfloat16).cuda()
     fp16_time = benchmark_microseconds(ffn_ref, input_tensor)
 
+    # Sparsify-only benchmarks
+    X_scale = torch.empty([num_tokens, 1], device="cuda", dtype=torch.float32)
+    ao_cusparse_time = benchmark_microseconds(
+        lambda: torch.ops.torchao.sparse24_sm90_sparsify(
+            input_tensor,
+            "cutlass",
+            "srelu",
+            "largest",
+            dtype=torch.float8_e4m3fn,
+            scale=X_scale,
+        )
+    )
+    cusparse_time = benchmark_microseconds(lambda: torch._cslt_compress(input_tensor))
+
     # bf16
     ffn_clone = (
         nn.Sequential(
@@ -117,7 +131,10 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192):
         "fp8_c_time (us)": fp8_c_time,
         "fp8_c_sparse_time (us)": fp8_c_sparse_time,
         "fp8_c_activation_sparse_time (us)": fp8_c_activation_sparse_time,
+        "ao_cusparse_time (us)": ao_cusparse_time,
+        "cusparse_compress_time (us)": cusparse_time,
         "speedup": fp8_c_time / fp8_c_activation_sparse_time,
+        "sparsify_speedup": cusparse_time / ao_cusparse_time,
     }
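The diff relies on a `benchmark_microseconds` helper defined earlier in the benchmark file. For readers without the full file, here is a minimal, hypothetical sketch of what such a helper does; the real torchao helper likely uses CUDA-event timing or `triton.testing.do_bench` with warmup and synchronization, so this `timeit`-based version is only an illustration of the contract (call the function repeatedly, return mean latency in microseconds):

```python
import timeit

def benchmark_microseconds(fn, *args, repeats=100):
    """Hypothetical sketch: run fn(*args) `repeats` times and return the
    mean wall-clock latency in microseconds. The actual torchao helper
    additionally handles CUDA synchronization and warmup."""
    fn(*args)  # one warmup call so lazy initialization is not timed
    total_s = timeit.timeit(lambda: fn(*args), number=repeats)
    return total_s / repeats * 1e6

# Usage: time a trivial CPU-side function.
t_us = benchmark_microseconds(sum, range(1000))
```

With a helper of this shape, the `sparsify_speedup` field in the diff is simply the ratio of the two measured latencies, `cusparse_time / ao_cusparse_time`, so values above 1.0 mean the torchao sparsify kernel is faster than the cuSPARSELt compress baseline.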
