
Commit 86ee6da

int8 4k tensorize works

1 parent: 39f9e32

1 file changed: tests/python/unittest/test_mma_16x8x32_4k_tune.py (+66, -57)
@@ -22,7 +22,9 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A_shared[v0, v1])
                T.writes(A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4])
-                A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4] = A_shared[v0, v1]
+                A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4] = A_shared[
+                    v0, v1
+                ]


@T.prim_func
@@ -122,9 +124,19 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        for i, j, k in T.grid(16, 16, 32):
            with T.block("C"):
                i, j, k = T.axis.remap("SSR", [i, j, k])
-                T.reads(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2], A[i % 8 * 4 + k % 16 // 4, k // 16 * 8 + i // 8 * 4 + k % 4], B[j % 8 * 4 + k % 4, j // 8 * 8 + k // 16 * 4 + k % 4])
-                T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
-                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] + T.cast(A[i % 8 * 4 + k % 16 // 4, k // 16 * 8 + i // 8 * 4 + k % 4], "int32") * T.cast(B[j % 8 * 4 + k % 4, j // 8 * 8 + k // 16 * 4 + k % 4], "int32")
+                T.reads(
+                    C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2],
+                    A[i % 8 * 4 + k % 16 // 4, k % 32 // 16 * 8 + i % 16 // 8 * 4 + k % 4],
+                    B[j % 8 * 4 + k % 4, j % 16 // 8 * 8 + k % 32 // 16 * 4 + k % 4],
+                )
+                T.writes(C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2])
+                C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2] = C[
+                    i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2
+                ] + T.cast(
+                    A[i % 8 * 4 + k % 16 // 4, k % 32 // 16 * 8 + i % 16 // 8 * 4 + k % 4], "int32"
+                ) * T.cast(
+                    B[j % 8 * 4 + k % 4, j % 16 // 8 * 8 + k % 32 // 16 * 4 + k % 4], "int32"
+                )


@T.prim_func
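For reference while reading the index arithmetic above: mma_sync_desc describes a single 16x16x32 int8 tile multiply-accumulate, with A, B and C already stored as 32-lane warp fragments. Stripped of the fragment index maps, the computation it encodes is a plain int8-to-int32 tile matmul. A minimal NumPy sketch of that reference semantics (the function name here is illustrative, not code from the test file):

import numpy as np

def mma_tile_reference(A_tile: np.ndarray, B_tile: np.ndarray) -> np.ndarray:
    """Plain reference for one 16x16x32 int8 tile: C[i, j] += int32(A[i, k]) * int32(B[k, j])."""
    assert A_tile.shape == (16, 32) and A_tile.dtype == np.int8
    assert B_tile.shape == (32, 16) and B_tile.dtype == np.int8
    C_tile = np.zeros((16, 16), dtype=np.int32)
    for i in range(16):
        for j in range(16):
            for k in range(32):
                C_tile[i, j] += np.int32(A_tile[i, k]) * np.int32(B_tile[k, j])
    return C_tile

# Sanity check against NumPy's own int32 accumulation.
rng = np.random.default_rng(0)
A_tile = rng.integers(-128, 128, size=(16, 32), dtype=np.int8)
B_tile = rng.integers(-128, 128, size=(32, 16), dtype=np.int8)
np.testing.assert_array_equal(
    mma_tile_reference(A_tile, B_tile),
    A_tile.astype(np.int32) @ B_tile.astype(np.int32),
)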
@@ -266,6 +278,7 @@ def mma_fill_impl(a: T.handle) -> None:
M = 4096
K = 4096

+
def matmul_int8(n, m, k):
    a = te.placeholder((n, k), name="A", dtype="int8")
    b = te.placeholder((k, m), name="B", dtype="int8")
@@ -289,8 +302,8 @@ def schedule(sch: tir.Schedule):
    block = sch.get_block("C")
    i, j, k = sch.get_loops(block)
    i, i_tc = sch.split(i, factors=[None, 16])
-    j, j_tc = sch.split(j, factors=[None, 32])
-    k, k_tc = sch.split(k, factors=[None, 16])
+    j, j_tc = sch.split(j, factors=[None, 16])
+    k, k_tc = sch.split(k, factors=[None, 32])

    sch.reorder(
        i,
@@ -311,8 +324,8 @@ def schedule(sch: tir.Schedule):
        num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
    else:
        i_factors = [4, 8, 2, 4, 1]
-        j_factors = [1, 32, 2, 1, 2]
-        k_factors = [128, 2, 1]
+        j_factors = [1, 64, 2, 1, 2]
+        k_factors = [64, 2, 1]

    num_ty = i_factors[2] * j_factors[2]

@@ -381,13 +394,10 @@ def fetch_to_shared(block, idx, ndim):

    block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])

-    def tile_wmma_fragment(block_read, height, is_b=False):
+    def tile_wmma_fragment(block_read, height, width):
        i, j = sch.get_loops(block_read)[-2:]
        i0, i1 = sch.split(i, factors=[None, height])
-        if is_b:
-            j0, j1 = sch.split(j, factors=[32, None])
-        else:
-            j0, j1 = sch.split(j, factors=[None, 32])
+        j0, j1 = sch.split(j, factors=[None, width])
        sch.reorder(i0, j0, i1, j1)
        return i1

@@ -411,7 +421,6 @@ def shared_16x32_to_ldmatrix_32x16_layout(i, j):
        thread_id = 4 * (i % 8) + (j % 16) // 4
        return i_0, j_0, thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4

-
    def shared_32x16_to_ldmatrix_32x16_layout(i, j):
        i_0 = i // 32
        j_0 = j // 16
@@ -422,8 +431,8 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
        thread_id = (i % 4) + 4 * (j % 8)
        return i_0, j_0, thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4

-    loop_a = tile_wmma_fragment(A_warp, 16)
-    loop_b = tile_wmma_fragment(B_warp, 16, True)
+    loop_a = tile_wmma_fragment(A_warp, 16, 32)
+    loop_b = tile_wmma_fragment(B_warp, 32, 16)

    sch.transform_layout(A_warp, 0, "write", index_map=shared_16x32_to_ldmatrix_32x16_layout)
    sch.transform_layout(B_warp, 0, "write", index_map=shared_32x16_to_ldmatrix_32x16_layout)
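The index-map functions used by these transform_layout calls pack a 16x32 (A) or 32x16 (B) int8 tile into a 32-thread by 16-element warp fragment for ldmatrix. As a standalone sanity check (illustrative only, covering just the within-tile part of the A-side mapping shown in the diff), the map should place the 512 tile elements into 512 distinct (thread, lane) slots:

# Side check, not part of the commit: the within-tile part of
# shared_16x32_to_ldmatrix_32x16_layout maps a 16x32 int8 tile onto
# 32 threads x 16 lanes without collisions.
def a_tile_to_fragment(i, j):
    thread_id = 4 * (i % 8) + (j % 16) // 4
    local_id = 8 * (j // 16) + (i // 8) * 4 + j % 4
    return thread_id, local_id

slots = {a_tile_to_fragment(i, j) for i in range(16) for j in range(32)}
assert len(slots) == 32 * 16                               # 512 distinct (thread, lane) slots
assert all(0 <= t < 32 and 0 <= l < 16 for t, l in slots)  # and all stay in range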
@@ -460,44 +469,44 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
schedule(sch)
print(sch.mod.script())

-# if tune:
-#     with tempfile.TemporaryDirectory() as work_dir:
-#         sch = ms.tune_tir(
-#             mod=workload,
-#             target=tvm.target.Target("nvidia/geforce-rtx-3070"),
-#             config=ms.TuneConfig(
-#                 strategy="evolutionary",
-#                 num_trials_per_iter=32,
-#                 max_trials_per_task=128,
-#                 max_trials_global=128,
-#             ),
-#             work_dir=work_dir,
-#             space=ms.space_generator.ScheduleFn(schedule),
-#         )
-#         if sch is None:
-#             print("No valid schedule found!")
-#         else:
-#             print(sch.mod.script())
-#             print(sch.trace)
-# else:
-#     target = "cuda"
-#     f = tvm.build(sch.mod["main"], target=target, name="dense")
-
-# dev = tvm.device("cuda", 0)
-# a_np = np.random.uniform(size=(N, K)).astype("int8")
-# b_np = np.random.uniform(size=(K, M)).astype("int8")
-# c_np = np.dot(a_np.astype("int32"), b_np.astype("int32"))
-# a = tvm.nd.array(a_np, dev)
-# b = tvm.nd.array(b_np, dev)
-# c = tvm.nd.array(np.zeros((M, N), dtype="int32"), dev)
-# f = tvm.build(sch.mod["main"], target="cuda", name="dense")
-
-# print(f.imported_modules[0].get_source())
-# f(a, b, c)
-# tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
-# print("ok")
-
-# evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
-# gflops = (N * M * K) * 2 / 1e9
-# time_ms = evaluator(a, b, c).mean * 1e3
-# print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
+if tune:
+    with tempfile.TemporaryDirectory() as work_dir:
+        sch = ms.tune_tir(
+            mod=workload,
+            target=tvm.target.Target("nvidia/geforce-rtx-3070"),
+            config=ms.TuneConfig(
+                strategy="evolutionary",
+                num_trials_per_iter=32,
+                max_trials_per_task=128,
+                max_trials_global=128,
+            ),
+            work_dir=work_dir,
+            space=ms.space_generator.ScheduleFn(schedule),
+        )
+        if sch is None:
+            print("No valid schedule found!")
+        else:
+            print(sch.mod.script())
+            print(sch.trace)
+else:
+    target = "cuda"
+    f = tvm.build(sch.mod["main"], target=target, name="dense")
+
+    dev = tvm.device("cuda", 0)
+    a_np = np.random.uniform(size=(N, K)).astype("int8")
+    b_np = np.random.uniform(size=(K, M)).astype("int8")
+    c_np = np.dot(a_np.astype("int32"), b_np.astype("int32"))
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(b_np, dev)
+    c = tvm.nd.array(np.zeros((M, N), dtype="int32"), dev)
+    f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+
+    print(f.imported_modules[0].get_source())
+    f(a, b, c)
+    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+    print("ok")
+
+    evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
+    gflops = (N * M * K) * 2 / 1e9
+    time_ms = evaluator(a, b, c).mean * 1e3
+    print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
