Skip to content

Commit 848de63

Browse files
committed
wip
1 parent b35bff9 commit 848de63

File tree

1 file changed

+76
-90
lines changed

1 file changed

+76
-90
lines changed

tests/python/unittest/test_mma_16x8x32_4k_tune.py

Lines changed: 76 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,27 @@
88
import numpy as np
99

1010

11+
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    """Map an (i, j) element of a 16x16 tile to its (thread_id, local_id)
    position in the 32-thread x 8-element ldmatrix register layout.

    The mapping is a bijection: every (i, j) in [0,16)x[0,16) lands on a
    distinct (thread_id, local_id) in [0,32)x[0,8).
    """
    tid = (i % 8) * 4 + (j % 8) // 2
    lid = (j // 8) * 4 + (i // 8) * 2 + j % 2
    return tid, lid
14+
15+
16+
def shared_16x32_to_ldmatrix_32x16_layout(i, j):
    """Map an (i, j) element of a 16x32 tile (e.g. an int8 A fragment) to its
    (thread_id, local_id) position in the 32-thread x 16-element ldmatrix
    register layout.

    Bijective: (i, j) in [0,16)x[0,32) covers all of [0,32)x[0,16) exactly once.
    """
    tid = (i % 8) * 4 + (j % 16) // 4
    lid = (j // 16) * 8 + (i // 8) * 4 + j % 4
    return tid, lid
19+
20+
21+
def shared_32x16_to_ldmatrix_32x16_layout(i, j):
    """Map an (i, j) element of a 32x16 tile (e.g. an int8 B fragment) to its
    (thread_id, local_id) position in the 32-thread x 16-element register layout.

    NOTE(review): bits 2-3 of `i` (i.e. (i // 4) % 4) appear in neither output
    coordinate, so this map is not injective over the full 32x16 tile — confirm
    this is the intended B-operand layout (commit is marked "wip").
    """
    tid = (j % 8) * 4 + i % 4
    lid = (j // 8) * 8 + (i // 16) * 4 + i % 4
    return tid, lid
24+
25+
26+
@tvm._ffi.register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j):
    """FFI-registered wrapper: expose the 16x16 -> 32x8 ldmatrix layout map to
    the C++ side as a TVM array [thread_id, local_id]."""
    return tvm.runtime.convert(list(shared_16x16_to_ldmatrix_32x8_layout(i, j)))
30+
31+
1132
@T.prim_func
1233
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
1334
A_shared = T.match_buffer(a, (16, 32), "int8", align=128, offset_factor=16, scope="shared")
@@ -21,10 +42,9 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
2142
with T.block("A_shared_warp"):
2243
v0, v1 = T.axis.remap("SS", [ax0, ax1])
2344
T.reads(A_shared[v0, v1])
24-
T.writes(A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4])
25-
A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4] = A_shared[
26-
v0, v1
27-
]
45+
thread_id, local_id = shared_16x32_to_ldmatrix_32x16_layout(v0, v1)
46+
T.writes(A_warp[thread_id, local_id])
47+
A_warp[thread_id, local_id] = A_shared[v0, v1]
2848

2949

3050
@T.prim_func
@@ -74,8 +94,9 @@ def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
7494
with T.block("B_shared_warp"):
7595
v0, v1 = T.axis.remap("SS", [ax0, ax1])
7696
T.reads(B_shared[v0, v1])
77-
T.writes(B_warp[v1 % 8 * 4 + v0 % 4, v1 // 8 * 8 + v0 // 16 * 4 + v0 % 4])
78-
B_warp[v1 % 8 * 4 + v0 % 4, v1 // 8 * 8 + v0 // 16 * 4 + v0 % 4] = B_shared[v0, v1]
97+
thread_id, local_id = shared_32x16_to_ldmatrix_32x16_layout(v0, v1)
98+
T.writes(B_warp[thread_id, local_id])
99+
B_warp[thread_id, local_id] = B_shared[v0, v1]
79100

80101

81102
@T.prim_func
@@ -124,18 +145,19 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
124145
for i, j, k in T.grid(16, 16, 32):
125146
with T.block("C"):
126147
i, j, k = T.axis.remap("SSR", [i, j, k])
148+
149+
thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
150+
thread_id_A, local_id_A = shared_16x32_to_ldmatrix_32x16_layout(i, k)
151+
thread_id_B, local_id_B = shared_32x16_to_ldmatrix_32x16_layout(k, j)
152+
127153
T.reads(
128-
C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2],
129-
A[i % 8 * 4 + k % 16 // 4, k % 32 // 16 * 8 + i % 16 // 8 * 4 + k % 4],
130-
B[j % 8 * 4 + k % 4, j % 16 // 8 * 8 + k % 32 // 16 * 4 + k % 4],
154+
C[thread_id_C, local_id_C],
155+
A[thread_id_A, local_id_A],
156+
B[thread_id_B, local_id_B],
131157
)
132-
T.writes(C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2])
133-
C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2] = C[
134-
i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2
135-
] + T.cast(
136-
A[i % 8 * 4 + k % 16 // 4, k % 32 // 16 * 8 + i % 16 // 8 * 4 + k % 4], "int32"
137-
) * T.cast(
138-
B[j % 8 * 4 + k % 4, j % 16 // 8 * 8 + k % 32 // 16 * 4 + k % 4], "int32"
158+
T.writes(C[thread_id_C, local_id_C])
159+
C[thread_id_C, local_id_C] += T.cast(A[thread_id_A, local_id_A], "int32") * T.cast(
160+
B[thread_id_B, local_id_B], "int32"
139161
)
140162

141163

@@ -198,14 +220,13 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:
198220
with T.block("root"):
199221
T.reads(C_warp[0:32, 0:8])
200222
T.writes(C[0:16, 0:16])
201-
for ax1_0, i0, i1 in T.grid(2, 32, 4):
223+
for i0, i1 in T.grid(16, 16):
202224
with T.block("C_warp"):
203-
v0 = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4)
204-
v1 = T.axis.spatial(16, ax1_0 * 8 + i0 % 4 * 2 + i1 % 2)
205-
206-
T.reads(C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
225+
v0, v1 = T.axis.remap("SS", [i0, i1])
226+
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
227+
T.reads(C_warp[thread_id, local_id])
207228
T.writes(C[v0, v1])
208-
C[v0, v1] = C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2]
229+
C[v0, v1] = C_warp[thread_id, local_id]
209230

210231

211232
@T.prim_func
@@ -238,21 +259,13 @@ def mma_fill_desc(a: T.handle) -> None:
238259
with T.block("root"):
239260
T.reads()
240261
T.writes(C_warp[0:32, 0:8])
241-
for i0, i1 in T.grid(32, 8):
262+
for i0, i1 in T.grid(16, 16):
242263
with T.block("C_warp"):
243-
i_init = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
244-
j_init = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
264+
i_init, j_init = T.axis.remap("SS", [i0, i1])
265+
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
245266
T.reads()
246-
T.writes(
247-
C_warp[
248-
i_init % 8 * 4 + j_init % 8 // 2,
249-
j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 2,
250-
]
251-
)
252-
C_warp[
253-
i_init % 8 * 4 + j_init % 8 // 2,
254-
j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 8 % 2,
255-
] = T.int32(0)
267+
T.writes(C_warp[thread_id, local_id])
268+
C_warp[thread_id, local_id] = T.int32(0)
256269

257270

258271
@T.prim_func
@@ -394,7 +407,8 @@ def fetch_to_shared(block, idx, ndim, vec=False):
394407
jo, ji = sch.split(jj, factors=[None, 16])
395408
sch.reorder(io, jo, ii, ji)
396409

397-
block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
410+
sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
411+
block_init_c = sch.get_block("C_init")
398412

399413
def tile_wmma_fragment(block_read, height, width):
400414
i, j = sch.get_loops(block_read)[-2:]
@@ -403,67 +417,39 @@ def tile_wmma_fragment(block_read, height, width):
403417
sch.reorder(i0, j0, i1, j1)
404418
return i1
405419

406-
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
407-
i_0 = i // 16
408-
j_0 = j // 16
409-
410-
i = i % 16
411-
j = j % 16
412-
413-
thread_id = 4 * (i % 8) + (j % 8) // 2
414-
return i_0, j_0, thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2
415-
416-
def shared_16x32_to_ldmatrix_32x16_layout(i, j):
417-
i_0 = i // 16
418-
j_0 = j // 32
419-
420-
i = i % 16
421-
j = j % 32
422-
423-
thread_id = 4 * (i % 8) + (j % 16) // 4
424-
return i_0, j_0, thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4
425-
426-
def shared_32x16_to_ldmatrix_32x16_layout(i, j):
427-
i_0 = i // 32
428-
j_0 = j // 16
420+
loop_a = tile_wmma_fragment(A_warp, 16, 32)
421+
loop_b = tile_wmma_fragment(B_warp, 32, 16)
429422

430-
i = i % 32
431-
j = j % 16
423+
def index_map_A(i, j):
424+
return (
425+
i // 16,
426+
j // 32,
427+
*shared_16x32_to_ldmatrix_32x16_layout(i % 16, j % 32),
428+
)
432429

433-
thread_id = (i % 4) + 4 * (j % 8)
434-
return i_0, j_0, thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
430+
def index_map_B(i, j):
431+
return (
432+
i // 32,
433+
j // 16,
434+
*shared_32x16_to_ldmatrix_32x16_layout(i % 32, j % 16),
435+
)
435436

436-
loop_a = tile_wmma_fragment(A_warp, 16, 32)
437-
loop_b = tile_wmma_fragment(B_warp, 32, 16)
437+
def index_map_C(i, j):
438+
return (
439+
i // 16,
440+
j // 16,
441+
*shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
442+
)
438443

439-
sch.transform_layout(A_warp, 0, "write", index_map=shared_16x32_to_ldmatrix_32x16_layout)
440-
sch.transform_layout(B_warp, 0, "write", index_map=shared_32x16_to_ldmatrix_32x16_layout)
441-
sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)
444+
sch.transform_layout(A_warp, 0, "write", index_map_A)
445+
sch.transform_layout(B_warp, 0, "write", index_map_B)
446+
sch.transform_layout(C_warp, 0, "read", index_map_C)
442447

443448
sch.tensorize(loop_a, "mma.ldmatrix_a")
444449
sch.tensorize(loop_b, "mma.ldmatrix_b")
445-
446-
mma_loop = sch.get_loops(block_inner)[-3]
447-
sch.tensorize(mma_loop, "mma_sync")
448-
449-
block_init_c = sch.get_block("C_init")
450-
init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
451-
f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
452-
f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
453-
sch.reorder(f_1, f_2, f_0, f_3)
454-
fused_1 = sch.fuse(f_1, f_2)
455-
fused_2 = sch.fuse(f_0, f_3)
456-
sch.tensorize(fused_1, "mma_fill")
457-
458-
warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
459-
f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
460-
outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
461-
sch.reorder(outer, f_1, f_2, f_0, f_3)
462-
fused_1 = sch.fuse(f_1, f_2)
463-
fused_2 = sch.fuse(f_0, f_3)
464-
sch.tensorize(outer, "mma_store")
465-
# print(sch.mod.script())
466-
# return
450+
sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync")
451+
sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
452+
sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
467453

468454

469455
ir_module = tvm.IRModule({"main": workload})

0 commit comments

Comments
 (0)