Skip to content

Commit 54c6864

Browse files
committed
wip
1 parent 078060f commit 54c6864

File tree

1 file changed

+35
-65
lines changed

1 file changed

+35
-65
lines changed

tests/python/unittest/test_mma_16x8x16_fp16_4k_tune_trans.py

Lines changed: 35 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -135,16 +135,18 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
135135
for i, j, k in T.grid(16, 16, 16):
136136
with T.block("C"):
137137
i, j, k = T.axis.remap("SSR", [i, j, k])
138+
thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
139+
thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
140+
thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(j, k)
141+
138142
T.reads(
139-
C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2],
140-
A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 2],
141-
B[j % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + j % 16 // 8 * 2 + k % 2],
143+
C[thread_id_C, local_id_C],
144+
A[thread_id_A, local_id_A],
145+
B[thread_id_B, local_id_B],
142146
)
143-
T.writes(C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2])
144-
C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2] = (
145-
C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2]
146-
+ A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 8 % 2]
147-
* B[j % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + j % 16 // 8 * 2 + k % 8 % 2]
147+
T.writes(C[thread_id_C, local_id_C])
148+
C[thread_id_C, local_id_C] += (
149+
A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B]
148150
)
149151

150152

@@ -207,14 +209,13 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:
207209
with T.block("root"):
208210
T.reads(C_warp[0:32, 0:8])
209211
T.writes(C[0:16, 0:16])
210-
for ax1_0, i0, i1 in T.grid(2, 32, 4):
212+
for i0, i1 in T.grid(16, 16):
211213
with T.block("C_warp"):
212-
v0 = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4)
213-
v1 = T.axis.spatial(16, ax1_0 * 8 + i0 % 4 * 2 + i1 % 2)
214-
215-
T.reads(C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
214+
v0, v1 = T.axis.remap("SS", [i0, i1])
215+
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
216+
T.reads(C_warp[thread_id, local_id])
216217
T.writes(C[v0, v1])
217-
C[v0, v1] = C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2]
218+
C[v0, v1] = C_warp[thread_id, local_id]
218219

219220

220221
@T.prim_func
@@ -247,21 +248,13 @@ def mma_fill_desc(a: T.handle) -> None:
247248
with T.block("root"):
248249
T.reads()
249250
T.writes(C_warp[0:32, 0:8])
250-
for i0, i1 in T.grid(32, 8):
251+
for i0, i1 in T.grid(16, 16):
251252
with T.block("C_warp"):
252-
i_init = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
253-
j_init = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
253+
i_init, j_init = T.axis.remap("SS", [i0, i1])
254+
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
254255
T.reads()
255-
T.writes(
256-
C_warp[
257-
i_init % 8 * 4 + j_init % 8 // 2,
258-
j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 2,
259-
]
260-
)
261-
C_warp[
262-
i_init % 8 * 4 + j_init % 8 // 2,
263-
j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 8 % 2,
264-
] = T.float16(0)
256+
T.writes(C_warp[thread_id, local_id])
257+
C_warp[thread_id, local_id] = T.float16(0)
265258

266259

267260
@T.prim_func
@@ -387,8 +380,6 @@ def fetch_to_shared(block, idx, ndim):
387380
A_sh = fetch_to_shared(block_outer, 0, 2)
388381
B_sh = fetch_to_shared(block_outer, 1, 2)
389382

390-
loop = sch.get_loops(block_outer)[-1]
391-
392383
A_warp = sch.cache_read(block_outer, 0, "warp")
393384
B_warp = sch.cache_read(block_outer, 1, "warp")
394385

@@ -403,7 +394,8 @@ def fetch_to_shared(block, idx, ndim):
403394
jo, ji = sch.split(jj, factors=[None, 16])
404395
sch.reorder(io, jo, ii, ji)
405396

406-
block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
397+
sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
398+
block_init_c = sch.get_block("C_init")
407399

408400
def tile_wmma_fragment(block_read, height):
409401
i, j = sch.get_loops(block_read)[-2:]
@@ -412,47 +404,25 @@ def tile_wmma_fragment(block_read, height):
412404
sch.reorder(i0, j0, i1, j1)
413405
return i1
414406

415-
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
416-
i_0 = i // 16
417-
j_0 = j // 16
418-
419-
i = i % 16
420-
j = j % 16
421-
422-
thread_id = 4 * (i % 8) + (j % 8) // 2
423-
return i_0, j_0, thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2
424-
425407
loop_a = tile_wmma_fragment(A_warp, 16)
426408
loop_b = tile_wmma_fragment(B_warp, 16)
427409

428-
sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
429-
sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
430-
sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)
410+
def index_map(i, j):
411+
return (
412+
i // 16,
413+
j // 16,
414+
*shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
415+
)
416+
417+
sch.transform_layout(A_warp, 0, "write", index_map)
418+
sch.transform_layout(B_warp, 0, "write", index_map)
419+
sch.transform_layout(C_warp, 0, "read", index_map)
431420

432421
sch.tensorize(loop_a, "mma.ldmatrix_a")
433422
sch.tensorize(loop_b, "mma.ldmatrix_b")
434-
435-
mma_loop = sch.get_loops(block_inner)[-3]
436-
sch.tensorize(mma_loop, "mma_sync")
437-
438-
block_init_c = sch.get_block("C_init")
439-
init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
440-
f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
441-
f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
442-
sch.reorder(f_1, f_2, f_0, f_3)
443-
fused_1 = sch.fuse(f_1, f_2)
444-
fused_2 = sch.fuse(f_0, f_3)
445-
sch.tensorize(fused_1, "mma_fill")
446-
447-
warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
448-
f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
449-
outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
450-
sch.reorder(outer, f_1, f_2, f_0, f_3)
451-
fused_1 = sch.fuse(f_1, f_2)
452-
fused_2 = sch.fuse(f_0, f_3)
453-
sch.tensorize(outer, "mma_store")
454-
# print(sch.mod.script())
455-
# return
423+
sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync")
424+
sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
425+
sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
456426

457427

458428
ir_module = tvm.IRModule({"main": workload})

0 commit comments

Comments
 (0)