
Commit 0b3683b

Authored by botbw and LeiWang1999
[feat] support gemm_sp for ampere and ada arch (#691)
* [feat] add an example mma atom
* [fix] fix typo naming
* [feat] add a template to enable compilation
* [feat] add print util
* [WIP] pass on single block tile
* [feat] add sm80 metadata layout
* [chore] clean codebase
* [CI] format.sh
* [feat] add sm80 compress utils
* [bugfix] fix C fragment layout
* [refactor] use nvcc version instead of str
* [test] add test cases
* [chore] add a param check
* [chore] format a bit
* [chore] rename func to satisfy PEP 8 and appease gemini
* [chore] add check
* [feat] support sm75 layout && add assertion && chore
* [bug] fix illegal memory access when using two warps over N=32. This could be a missing check related to the cutlass 2.x implementation; the cutlass example cannot trigger it because it is bypassed by padding the input. For now it seems safe to increase the atom size and investigate in the future.
* [chore] add example
* [chore] format
* [example] update benchmark
* [bugfix] fix namespace and format
* [bugfix] fix incorrect param passing
* [refactor] update variable declaration for clarity in gemm_layouts and gemm_sp
* [Cleanup] remove unnecessary blank lines in metadata layout functions in gemm_sp.py
* [CI] fix arch
* [example] add torch sparse benchmark
* [misc] polish && add reference && apply review suggestions && format
* [CI] format with clang-tidy
* [Cleanup] format and align template struct definitions in half.hpp, common.h, and gemm_sp_sm80.h
* [Update] modify CUDA version requirements in test_gemm_sp_sm80 and mark cutlass subproject as dirty

---------
Co-authored-by: LeiWang1999 <[email protected]>
1 parent: f0d6669 · commit: 0b3683b
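For orientation, the new Ampere/Ada (sm80/sm89) path keeps the same 2:4 sparse GEMM interface as the existing Hopper path, but the sparsity-metadata operand E is sized and typed per architecture. The sketch below is not part of the commit; the helper name is illustrative and simply mirrors the ARCH_INFO table used by the updated benchmark and the new example:

import torch

# Same table as in the diffs below:
# compute capability -> (K elements covered per metadata entry, metadata dtype)
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}


def metadata_shape(M: int, K: int, arch: str):
    """Shape and dtype of the metadata tensor E for a 2:4-sparse (M, K) fp16 operand."""
    e_factor, e_dtype = ARCH_INFO[arch]
    return (M, K // e_factor), getattr(torch, e_dtype)


# A 2:4-sparse A of logical shape (M, K) is stored as A_sparse with shape (M, K // 2)
# (the kept values) plus E (which positions were kept):
print(metadata_shape(128, 256, "8.0"))  # ((128, 16), torch.int16)
print(metadata_shape(128, 256, "9.0"))  # ((128, 32), torch.uint8)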

File tree: 17 files changed, +1055 −94 lines


benchmark/matmul/benchmark_matmul_sp.py

Lines changed: 56 additions & 25 deletions
@@ -4,14 +4,21 @@
 import torch
 from triton.testing import do_bench
 
+import tilelang
 import tilelang.language as T
 from tilelang.autotuner import autotune
 from tilelang import jit
+from tilelang.contrib import nvcc
 from tilelang.layout import make_metadata_layout
+
 # Configure logger
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
+arch = nvcc.get_target_compute_version()
+
+ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
+
 
 def ref_program(A, B):
     """
@@ -79,11 +86,11 @@ def get_configs(M, N, K):
     return configs
 
 
-def matmul_sp(M, N, K):
+def matmul_sp(M, N, K, accum_dtype):
     """
     Create an autotuned matrix multiplication kernel for matrices of shape:
       - A: (M, K)
-      - B: (N, K)
+      - B: (K, N)
       - C: (M, N)
 
     Parameters
@@ -155,14 +162,14 @@ def kernel(
         # Use half-precision for input data to reduce memory bandwidth,
         # accumulate in float for better numerical accuracy
         dtype = "float16"
-        accum_dtype = "float"
+        e_factor, e_dtype = ARCH_INFO[arch]
 
         @T.prim_func
         def main(
                 A_sparse: T.Tensor((M, K // 2), dtype),
-                E: T.Tensor((M, K // 8), 'uint8'),
-                B: T.Tensor((N, K), dtype),
-                C: T.Tensor((M, N), dtype),
+                E: T.Tensor((M, K // e_factor), e_dtype),
+                B: T.Tensor((K, N), dtype),
+                C: T.Tensor((M, N), accum_dtype),
         ):
             """
             The compiled TVM function for block-level matrix multiplication.
@@ -182,13 +189,13 @@ def main(
                # Allocate shared memory for A sub-block of shape (block_M, block_K)
                A_shared = T.alloc_shared((block_M, block_K // 2), dtype)
                # Allocate shared memory for B sub-block of shape (block_N, block_K)
-               B_shared = T.alloc_shared((block_N, block_K), dtype)
+               B_shared = T.alloc_shared((block_K, block_N), dtype)
                # Allocate shared memory for E sub-block of shape (block_M, block_K // E_factor)
-               E_shared = T.alloc_shared((block_M, block_K // 8), 'uint8')
+               E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
                # Allocate a local fragment for intermediate accumulation
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
                # Allocate a shared memory for C sub-block of shape (block_M, block_N)
-               C_shared = T.alloc_shared((block_M, block_N), dtype)
+               C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
 
                # Clear out the accumulation buffer
                T.clear(C_local)
@@ -198,32 +205,27 @@ def main(
                T.annotate_layout({
                    E:
                        make_metadata_layout(
-                           E, mma_dtype="float16", arch="sm90", backend="cutlass",
-                           block_k=block_K),
+                           E, mma_dtype="float16", backend="cutlass", block_k=block_K),
                    E_shared:
                        make_metadata_layout(
-                           E_shared,
-                           mma_dtype="float16",
-                           arch="sm90",
-                           backend="cutlass",
-                           block_k=block_K),
+                           E_shared, mma_dtype="float16", backend="cutlass", block_k=block_K),
                })
                # Loop over sub-blocks in K dimension, pipelined by num_stages
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    # Load a sub-block of A from global memory into A_shared
-                   T.copy(A_sparse[by * block_M, k * block_K], A_shared)
+                   T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                    # Load a sub-block of E from global memory into E_shared
-                   T.copy(E[by * block_M, k * block_K // 8], E_shared)
+                   T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
                    # Load a sub-block of B from global memory into B_shared
-                   T.copy(B[bx * block_N, k * block_K], B_shared)
+                   T.copy(B[k * block_K, bx * block_N], B_shared)
                    # Perform a partial matrix multiplication:
-                   # C_local += A_shared @ B_shared^T
+                   # C_local += A_shared @ B_shared
                    T.gemm_sp(
                        A_shared,
                        E_shared,
                        B_shared,
                        C_local,
-                       transpose_B=True,
+                       transpose_B=False,
                        policy=policy,
                    )
                    # Write back the results from C_local to the global memory C
@@ -241,24 +243,53 @@ def main(
     parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
     parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
     parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
+    parser.add_argument("--disable_cache", action="store_true")
+    parser.add_argument(
+        "--accum_dtype",
+        type=str,
+        default="float",
+        choices=["float", "float16"],
+        help="Accumulation datatype")
+    parser.add_argument(
+        "--bench_torch_sparse",
+        type=str,
+        choices=['cutlass', 'cusparselt'],
+        default=None,
+        help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported"
+    )
     args = parser.parse_args()
 
+    if args.disable_cache:
+        tilelang.disable_cache()
+
     M, N, K = args.m, args.n, args.k
 
     # Compute total floating-point operations to measure throughput
     total_flops = 2 * M * N * K
 
     # matmul(...) returns (best_latency, best_config, ref_latency)
-    best_result = matmul_sp(M, N, K)
+    best_result = matmul_sp(M, N, K, args.accum_dtype)
     best_latency = best_result.latency
     best_config = best_result.config
     A = torch.randn(M, K, dtype=torch.float16, device="cuda")
-    B = torch.randn(N, K, dtype=torch.float16, device="cuda")
-    ref_latency = do_bench(lambda: A @ B.T)
+    B = torch.randn(K, N, dtype=torch.float16, device="cuda")
+    ref_latency = do_bench(lambda: A @ B)
+
+    if args.bench_torch_sparse is not None:
+        from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
+        if args.bench_torch_sparse == 'cutlass':
+            SparseSemiStructuredTensor._FORCE_CUTLASS = True
+        A_sp = to_sparse_semi_structured(A, transposed=False)
+        torch_sparse_latency = do_bench(lambda: A_sp @ B)
 
     # Print out the benchmark results
     print(f"Best latency (s): {best_latency}")
     print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
     print(f"Best config: {best_config}")
 
-    print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}")
+    if args.bench_torch_sparse is not None:
+        print(
+            f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}"
+        )
+
+    print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}")
New file: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import argparse

import tilelang
import tilelang.language as T

from tilelang.layout import make_metadata_layout
from tilelang.utils.sparse import compress
from tilelang.contrib import nvcc
from triton.testing import do_bench

import torch

arch = nvcc.get_target_compute_version()

ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}

default_config = { # take best config from autotune script
    "4090": {
        'float': {
            'block_M': 128,
            'block_N': 64,
            'block_K': 64,
            'num_stages': 1,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        },
        'float16': {
            'block_M': 256,
            'block_N': 128,
            'block_K': 64,
            'num_stages': 2,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        }
    },
    "h20": {
        'float': {
            'block_M': 128,
            'block_N': 64,
            'block_K': 128,
            'num_stages': 3,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        },
        'float16': {
            'block_M': 128,
            'block_N': 64,
            'block_K': 128,
            'num_stages': 3,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        }
    }
}


def generate_sparse_tensor(M: int, K: int, dtype=torch.float16, device='cuda'):
    elem, group = 2, 4
    full_tensor = torch.randn((M, K), dtype=dtype, device=device).view(M, -1, group)
    indice = full_tensor.topk(elem, dim=-1).indices
    full_tensor.scatter_(-1, indice, 0)
    return full_tensor.view(M, K)


@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy,
                   enable_rasterization):
    e_factor, e_dtype = ARCH_INFO[arch]

    @T.prim_func
    def gemm_sp_fp16(
            A_sparse: T.Tensor((M, K // 2), 'float16'),
            E: T.Tensor((M, K // e_factor), e_dtype),
            B: T.Tensor((K, N), 'float16'),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), 'float16')
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            B_shared = T.alloc_shared((block_K, block_N), 'float16')
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            T.disable_warp_group_reg_alloc()
            T.use_swizzle(panel_size=10, enable=enable_rasterization)
            T.annotate_layout({
                E:
                    make_metadata_layout(
                        E, mma_dtype="float16", backend="cutlass", block_k=block_K, arch=arch),
                E_shared:
                    make_metadata_layout(
                        E_shared,
                        mma_dtype="float16",
                        backend="cutlass",
                        block_k=block_K,
                        arch=arch),
            })
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)

            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return gemm_sp_fp16


def main():
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument(
        "--accum_dtype",
        type=str,
        default="float",
        choices=["float", "float16"],
        help="Accumulation datatype")
    parser.add_argument("--cfg", type=str, choices=["4090", "h20"], required=True)
    args = parser.parse_args()
    kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype,
                            **default_config[args.cfg][args.accum_dtype])

    a = generate_sparse_tensor(args.m, args.k, device='cuda', dtype=torch.half)
    b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)

    a_sparse, e = compress(
        a,
        transposed=False,
        block_k=default_config[args.cfg][args.accum_dtype]['block_K'],
        arch=arch)
    c = kernel(a_sparse, e, b)

    ref_c = a @ b

    assert not c.isnan().any(), "Reference result contains NaNs, please report an issue"
    torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2)
    print(f"Precision check passed. diff: {(c - ref_c).abs().mean()}")

    latency = do_bench(lambda: kernel(a_sparse, e, b))
    ref_latency = do_bench(lambda: a @ b)

    total_flops = 2 * args.m * args.n * args.k
    tflops = total_flops / latency / 1e9
    ref_tflops = total_flops / ref_latency / 1e9
    print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s")
    print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s")


if __name__ == "__main__":
    main()
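As a usage note: the example picks a pre-tuned configuration via --cfg (currently "4090" or "h20") and an accumulator dtype via --accum_dtype, builds the kernel with @tilelang.jit(out_idx=[-1]) so that C is allocated and returned by the kernel call, compresses the 2:4-sparse input into (a_sparse, e) with compress(..., arch=arch), and validates against a dense a @ b at rtol=atol=1e-2 before timing both paths with do_bench. The file's path is not shown in this view; assuming it lives under examples/, a typical run would look like python <example>.py --cfg 4090 --accum_dtype float16, with smaller --m/--n/--k values for a quick check.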

examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py

Lines changed: 2 additions & 2 deletions
@@ -41,12 +41,12 @@ def main(
         T.annotate_layout({
             E:
                 make_metadata_layout(
-                    E, mma_dtype="float16", arch="sm90", backend="cutlass", block_k=block_K),
+                    E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K),
             E_shared:
                 make_metadata_layout(
                     E_shared,
                     mma_dtype="float16",
-                    arch="sm90",
+                    arch="9.0",
                     backend="cutlass",
                     block_k=block_K),
         })
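Note that the arch argument of make_metadata_layout now takes a compute-capability string ("8.0", "8.9", "9.0") rather than the previous "sm90"-style name; the other call sites touched by this commit either pass the value detected at runtime or omit the argument. A minimal sketch of that detection, assuming a CUDA-capable TileLang installation:

from tilelang.contrib import nvcc

# Compute capability used for compilation, e.g. "8.0" (A100), "8.9" (Ada, RTX 40xx)
# or "9.0" (H100/H20) -- the same keys ARCH_INFO uses above.
arch = nvcc.get_target_compute_version()
print(arch)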
