
Commit f70ccd0

int8 tensorize working
1 parent 20321fa commit f70ccd0


tests/python/unittest/test_mma_16x8x32_int8.py

Lines changed: 50 additions & 53 deletions
@@ -23,10 +23,8 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
             with T.block("A_shared_warp"):
                 v0, v1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(A_shared[v0, v1])
-                T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-                A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
-                    v0, v1
-                ]
+                T.writes(A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4])
+                A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4] = A_shared[v0, v1]


 @T.prim_func
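
The A tile is now 16x32 int8, and the rewritten indices pack it into the 32x16 per-thread warp buffer that ldmatrix expects. A standalone sanity check of that (row, col) -> (thread, local) mapping (plain Python, not part of the test file; the helper name is illustrative):

    def a_index_map(v0, v1):
        # same expressions as in ldmatrix_a_desc above
        return v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4

    hits = [[0] * 16 for _ in range(32)]
    for v0 in range(16):
        for v1 in range(32):
            t, l = a_index_map(v0, v1)
            hits[t][l] += 1
    # all 16*32 shared elements land on distinct (thread, local) slots of the 32x16 warp buffer
    assert all(h == 1 for row in hits for h in row)
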
@@ -65,21 +63,19 @@ def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:

 @T.prim_func
 def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
-    B_shared = T.match_buffer(a, (16, 32), "int8", align=128, offset_factor=16, scope="shared")
+    B_shared = T.match_buffer(a, (32, 16), "int8", align=128, offset_factor=16, scope="shared")
     B_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")

     with T.block("root"):
-        T.reads(B_shared[0:16, 0:32])
+        T.reads(B_shared[0:32, 0:16])
         T.writes(B_warp[0:32, 0:16])

-        for ax0, ax1 in T.grid(16, 32):
+        for ax0, ax1 in T.grid(32, 16):
             with T.block("B_shared_warp"):
                 v0, v1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(B_shared[v0, v1])
-                T.writes(B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-                B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = B_shared[
-                    v0, v1
-                ]
+                T.writes(B_warp[v1 % 8 * 4 + v0 % 4, v1 // 8 * 8 + v0 // 16 * 4 + v0 % 4])
+                B_warp[v1 % 8 * 4 + v0 % 4, v1 // 8 * 8 + v0 // 16 * 4 + v0 % 4] = B_shared[v0, v1]


 @T.prim_func
@@ -88,7 +84,7 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
     s0 = T.var("int32")
     B_shared = T.match_buffer(
         a,
-        (16, 32),
+        (32, 16),
         "int8",
         align=128,
         offset_factor=16,
@@ -97,7 +93,7 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
     )
     B_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
     with T.block("root"):
-        T.reads(B_shared[0:16, 0:32])
+        T.reads(B_shared[0:32, 0:16])
         T.writes(B_warp[0:32, 0:16])
         tx = T.env_thread("threadIdx.x")
         T.launch_thread(tx, 32)
@@ -110,7 +106,7 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
                 B_warp.data,
                 16 * tx,
                 B_shared.data,
-                32 * (tx % 16) + 16 * (tx // 16),
+                16 * tx,
                 dtype="int8",
             )
         )
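
With B_shared now a row-major (32, 16) int8 tile, each of the 32 threads reads one contiguous 16-byte row, so the per-thread source offset simplifies to 16 * tx. A quick arithmetic check of that claim (plain Python, illustrative only):

    offsets = [16 * tx for tx in range(32)]
    # 32 slices of 16 int8 elements tile the whole 32*16 = 512-element buffer
    assert offsets == list(range(0, 32 * 16, 16))
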
@@ -125,22 +121,12 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
     with T.block("root"):
         T.reads(C[0:32, 0:8], A[0:32, 0:16], B[0:32, 0:16])
         T.writes(C[0:32, 0:8])
-        for i, j, k in T.grid(32, 8, 16):
+        for i, j, k in T.grid(16, 16, 32):
             with T.block("C"):
                 i, j, k = T.axis.remap("SSR", [i, j, k])
-                T.reads(
-                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
-                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
-                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
-                )
+                T.reads(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2], A[i % 8 * 4 + k % 16 // 4, k // 16 * 8 + i // 8 * 4 + k % 4], B[j % 8 * 4 + k % 4, j // 8 * 8 + k // 16 * 4 + k % 4])
                 T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
-                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[
-                    i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2
-                ] + T.cast(
-                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2], "int32"
-                ) * T.cast(
-                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2], "int32"
-                )
+                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] + T.cast(A[i % 8 * 4 + k % 16 // 4, k // 16 * 8 + i // 8 * 4 + k % 4], "int32") * T.cast(B[j % 8 * 4 + k % 4, j // 8 * 8 + k // 16 * 4 + k % 4], "int32")


 @T.prim_func
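
The descriptor now walks a (16, 16, 32) grid, i.e. a 16x16x32 int8 matmul accumulated in int32. A NumPy reference for that tile computation (a sketch; array names are illustrative, not from the test):

    import numpy as np

    a = np.random.randint(-128, 128, (16, 32)).astype("int8")  # A tile: (M, K)
    b = np.random.randint(-128, 128, (32, 16)).astype("int8")  # B tile: (K, N)
    c = a.astype("int32") @ b.astype("int32")  # int32 accumulation, mirroring the T.cast calls
    assert c.shape == (16, 16) and c.dtype == np.int32
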
@@ -271,7 +257,9 @@ def mma_fill_impl(a: T.handle) -> None:
 tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)


-M = N = K = 16
+M = 16
+N = 16
+K = 32

 def matmul_int8(n, m, k):
     a = te.placeholder((n, k), name="A", dtype="int8")
@@ -300,13 +288,12 @@ def f_compute(i, j):

 def fetch_to_shared(block, idx):
     block_read = sch.cache_read(block, idx, "shared")
-    if use_gpu:
-        sch.compute_at(block_read, i1, True)
-        warp_size = 32
-        loops = sch.get_loops(block_read)
-        fused = sch.fuse(*loops[-2:])
-        f_0, f_1 = sch.split(fused, factors=[None, warp_size])
-        sch.bind(f_1, "threadIdx.x")
+    sch.compute_at(block_read, i1, True)
+    warp_size = 32
+    loops = sch.get_loops(block_read)
+    fused = sch.fuse(*loops[-2:])
+    f_0, f_1 = sch.split(fused, factors=[None, warp_size])
+    sch.bind(f_1, "threadIdx.x")

     return block_read
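
fetch_to_shared now always performs the cooperative fetch: cache the operand into shared memory, anchor the copy at the outer loop, fuse the copy loops, split off a warp-sized inner loop, and bind it to threadIdx.x. A self-contained sketch of that fuse/split/bind pattern on a toy copy (hypothetical shapes, not the test's schedule):

    import tvm
    from tvm import te

    A = te.placeholder((16, 32), dtype="int8", name="A")
    B = te.compute((16, 32), lambda i, j: A[i, j], name="copy")
    sch = tvm.tir.Schedule(te.create_prim_func([A, B]))

    blk = sch.get_block("copy")
    fused = sch.fuse(*sch.get_loops(blk))         # one flat loop over all copied elements
    _, tx = sch.split(fused, factors=[None, 32])  # peel off a warp-sized inner loop
    sch.bind(tx, "threadIdx.x")                   # each of the 32 lanes copies its share
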

@@ -320,18 +307,28 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
     return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2


+def shared_16x32_to_ldmatrix_32x16_layout(i, j):
+    thread_id = 4 * (i % 8) + (j % 16) // 4
+    return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4
+
+
+def shared_32x16_to_ldmatrix_32x16_layout(i, j):
+    thread_id = (i % 4) + 4 * (j % 8)
+    return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
+
+
 block = sch.get_block("C")

 A_warp = sch.cache_read(block, 0, "warp")

-# sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+sch.transform_layout(A_warp, 0, "write", index_map=shared_16x32_to_ldmatrix_32x16_layout)

 B_warp = sch.cache_read(block, 1, "warp")

-# sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+sch.transform_layout(B_warp, 0, "write", index_map=shared_32x16_to_ldmatrix_32x16_layout)

-# sch.tensorize(sch.get_loops(A_warp)[1], "mma.ldmatrix_a")
-# sch.tensorize(sch.get_loops(B_warp)[1], "mma.ldmatrix_b")
+sch.tensorize(sch.get_loops(A_warp)[1], "mma.ldmatrix_a")
+sch.tensorize(sch.get_loops(B_warp)[1], "mma.ldmatrix_b")

 C_warp = sch.cache_write(block, 0, "warp")
 sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
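
The two new layout functions are the same index maps that ldmatrix_a_desc and ldmatrix_b_desc spell out element by element, so transform_layout and the tensorized intrinsics agree on where every value lives. A plain-Python consistency check (illustrative, not part of the test):

    for i in range(16):
        for j in range(32):
            # must match the A_warp indices written in ldmatrix_a_desc
            assert shared_16x32_to_ldmatrix_32x16_layout(i, j) == (
                i % 8 * 4 + j % 16 // 4,
                j // 16 * 8 + i // 8 * 4 + j % 4,
            )

    for i in range(32):
        for j in range(16):
            # must match the B_warp indices written in ldmatrix_b_desc
            assert shared_32x16_to_ldmatrix_32x16_layout(i, j) == (
                j % 8 * 4 + i % 4,
                j // 8 * 8 + i // 16 * 4 + i % 4,
            )
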
@@ -344,7 +341,7 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
 fused_1 = sch.fuse(f_1, f_2)
 fused_2 = sch.fuse(f_0, f_3)

-# sch.tensorize(outer, "mma_store")
+sch.tensorize(outer, "mma_store")

 block_init_c = sch.decompose_reduction(block, sch.get_loops(block)[1])

@@ -356,25 +353,25 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
 fused_2 = sch.fuse(f_0, f_3)
 sch.tensorize(fused_1, "mma_fill")

-# sch.tensorize(sch.get_loops(block)[1], "mma.mma_sync")
+sch.tensorize(sch.get_loops(block)[1], "mma.mma_sync")

 print(sch.mod.script())

 # lowered = tvm.lower(sch.mod["main"])

-# target = "cuda"
+target = "cuda"

-# f = tvm.build(sch.mod["main"], target=target, name="dense")
-# dev = tvm.device(target, 0)
+f = tvm.build(sch.mod["main"], target=target, name="dense")
+dev = tvm.device(target, 0)

-# a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
-# b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
-# c_np = np.dot(a_np.astype("int3232"), b_np.astype("in32"))
+a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
+b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
+c_np = np.dot(a_np.astype("int32"), b_np.astype("int32"))

-# a = tvm.nd.array(a_np, dev)
-# b = tvm.nd.array(b_np, dev)
-# c = tvm.nd.array(np.zeros((16, K), dtype="float32"), dev)
+a = tvm.nd.array(a_np, dev)
+b = tvm.nd.array(b_np, dev)
+c = tvm.nd.array(np.zeros((M, N), dtype="int32"), dev)

-# # print(f.imported_modules[0].get_source())
-# f(a, b, c)
-# np.testing.assert_equal(c.numpy(), c_np)
+# print(f.imported_modules[0].get_source())
+f(a, b, c)
+np.testing.assert_equal(c.numpy(), c_np)
