Commit 078060f ("wip"), 1 parent 576f841

4 files changed: +73, -85 lines

tests/python/unittest/test_mma_16x8x16_4k_tune.py

Lines changed: 0 additions & 2 deletions

@@ -364,8 +364,6 @@ def fetch_to_shared(block, idx, ndim):
 A_sh = fetch_to_shared(block_outer, 0, 2)
 B_sh = fetch_to_shared(block_outer, 1, 2)

-loop = sch.get_loops(block_outer)[-1]
-
 A_warp = sch.cache_read(block_outer, 0, "warp")
 B_warp = sch.cache_read(block_outer, 1, "warp")

tests/python/unittest/test_mma_16x8x16_4k_tune_trans.py

Lines changed: 0 additions & 2 deletions

@@ -381,8 +381,6 @@ def fetch_to_shared(block, idx, ndim):
 A_sh = fetch_to_shared(block_outer, 0, 2)
 B_sh = fetch_to_shared(block_outer, 1, 2)

-loop = sch.get_loops(block_outer)[-1]
-
 A_warp = sch.cache_read(block_outer, 0, "warp")
 B_warp = sch.cache_read(block_outer, 1, "warp")

tests/python/unittest/test_mma_16x8x16_fp16_4k_tune.py

Lines changed: 55 additions & 72 deletions

@@ -8,6 +8,17 @@
 import numpy as np


+def shared_16x16_to_ldmatrix_32x8_layout(i, j):
+    thread_id = 4 * (i % 8) + (j % 8) // 2
+    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
+
+
+@tvm._ffi.register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
+def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j):
+    thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
+    return tvm.runtime.convert([thread_id, local_id])
+
+
 @T.prim_func
 def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
     A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
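As a side note, this mapping sends each element (i, j) of a 16x16 fp16 tile to a distinct (thread_id, local_id) slot of the 32-thread by 8-element warp-level buffer consumed by ldmatrix/mma. A minimal standalone sketch in plain Python (no TVM required; the check below is illustrative only and not part of the test):

    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
        # Same formula as above: the lane within the warp, then the slot
        # inside that lane's 8-element fragment.
        thread_id = 4 * (i % 8) + (j % 8) // 2
        return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)

    # All 256 tile elements land on distinct (thread_id, local_id) slots,
    # so the 16x16 tile exactly fills the 32x8 warp buffer.
    slots = {shared_16x16_to_ldmatrix_32x8_layout(i, j) for i in range(16) for j in range(16)}
    assert len(slots) == 256
    assert all(0 <= t < 32 and 0 <= l < 8 for t, l in slots)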
@@ -21,10 +32,10 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
             with T.block("A_shared_warp"):
                 v0, v1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(A_shared[v0, v1])
-                T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-                A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
-                    v0, v1
-                ]
+
+                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
+                T.writes(A_warp[thread_id, local_id])
+                A_warp[thread_id, local_id] = A_shared[v0, v1]


 @T.prim_func
@@ -74,10 +85,9 @@ def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
             with T.block("B_shared_warp"):
                 v0, v1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(B_shared[v0, v1])
-                T.writes(B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-                B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = B_shared[
-                    v0, v1
-                ]
+                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
+                T.writes(B_warp[thread_id, local_id])
+                B_warp[thread_id, local_id] = B_shared[v0, v1]


 @T.prim_func
@@ -126,15 +136,19 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
         for i, j, k in T.grid(16, 16, 16):
             with T.block("C"):
                 i, j, k = T.axis.remap("SSR", [i, j, k])
+                thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
+                thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
+                thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j)
+
                 T.reads(
-                    C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2],
-                    A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 2],
-                    B[k % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + k % 16 // 8 * 2 + j % 2],
+                    C[thread_id_C, local_id_C],
+                    A[thread_id_A, local_id_A],
+                    B[thread_id_B, local_id_B],
+                )
+                T.writes(C[thread_id_C, local_id_C])
+                C[thread_id_C, local_id_C] += (
+                    A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B]
                 )
-                T.writes(C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2])
-                C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2] = C[
-                    i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2
-                ] + A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 8 % 2] * B[k % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + k % 16 // 8 * 2 + j % 8 % 2]


 @T.prim_func
@@ -196,14 +210,13 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:
     with T.block("root"):
         T.reads(C_warp[0:32, 0:8])
         T.writes(C[0:16, 0:16])
-        for ax1_0, i0, i1 in T.grid(2, 32, 4):
+        for i0, i1 in T.grid(16, 16):
             with T.block("C_warp"):
-                v0 = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4)
-                v1 = T.axis.spatial(16, ax1_0 * 8 + i0 % 4 * 2 + i1 % 2)
-
-                T.reads(C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
+                v0, v1 = T.axis.remap("SS", [i0, i1])
+                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
+                T.reads(C_warp[thread_id, local_id])
                 T.writes(C[v0, v1])
-                C[v0, v1] = C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2]
+                C[v0, v1] = C_warp[thread_id, local_id]


 @T.prim_func
@@ -236,21 +249,13 @@ def mma_fill_desc(a: T.handle) -> None:
     with T.block("root"):
         T.reads()
         T.writes(C_warp[0:32, 0:8])
-        for i0, i1 in T.grid(32, 8):
+        for i0, i1 in T.grid(16, 16):
            with T.block("C_warp"):
-                i_init = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
-                j_init = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
+                i_init, j_init = T.axis.remap("SS", [i0, i1])
+                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
                 T.reads()
-                T.writes(
-                    C_warp[
-                        i_init % 8 * 4 + j_init % 8 // 2,
-                        j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 2,
-                    ]
-                )
-                C_warp[
-                    i_init % 8 * 4 + j_init % 8 // 2,
-                    j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 8 % 2,
-                ] = T.float16(0)
+                T.writes(C_warp[thread_id, local_id])
+                C_warp[thread_id, local_id] = T.float16(0)


 @T.prim_func
@@ -276,6 +281,7 @@ def mma_fill_impl(a: T.handle) -> None:
 M = 4096
 K = 4096

+
 def matmul_fp16(n, m, k):
     a = te.placeholder((n, k), name="A", dtype="float16")
     b = te.placeholder((k, m), name="B", dtype="float16")
@@ -373,8 +379,6 @@ def fetch_to_shared(block, idx, ndim):
 A_sh = fetch_to_shared(block_outer, 0, 2)
 B_sh = fetch_to_shared(block_outer, 1, 2)

-loop = sch.get_loops(block_outer)[-1]
-
 A_warp = sch.cache_read(block_outer, 0, "warp")
 B_warp = sch.cache_read(block_outer, 1, "warp")

@@ -389,7 +393,8 @@ def fetch_to_shared(block, idx, ndim):
 jo, ji = sch.split(jj, factors=[None, 16])
 sch.reorder(io, jo, ii, ji)

-block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
+sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
+block_init_c = sch.get_block("C_init")

 def tile_wmma_fragment(block_read, height):
     i, j = sch.get_loops(block_read)[-2:]
@@ -398,47 +403,25 @@ def tile_wmma_fragment(block_read, height):
     sch.reorder(i0, j0, i1, j1)
     return i1

-def shared_16x16_to_ldmatrix_32x8_layout(i, j):
-    i_0 = i // 16
-    j_0 = j // 16
-
-    i = i % 16
-    j = j % 16
-
-    thread_id = 4 * (i % 8) + (j % 8) // 2
-    return i_0, j_0, thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2
-
 loop_a = tile_wmma_fragment(A_warp, 16)
 loop_b = tile_wmma_fragment(B_warp, 16)

-sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
-sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
-sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+def index_map(i, j):
+    return (
+        i // 16,
+        j // 16,
+        *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+    )
+
+sch.transform_layout(A_warp, 0, "write", index_map)
+sch.transform_layout(B_warp, 0, "write", index_map)
+sch.transform_layout(C_warp, 0, "read", index_map)

 sch.tensorize(loop_a, "mma.ldmatrix_a")
 sch.tensorize(loop_b, "mma.ldmatrix_b")
-
-mma_loop = sch.get_loops(block_inner)[-3]
-sch.tensorize(mma_loop, "mma_sync")
-
-block_init_c = sch.get_block("C_init")
-init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
-f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
-f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
-sch.reorder(f_1, f_2, f_0, f_3)
-fused_1 = sch.fuse(f_1, f_2)
-fused_2 = sch.fuse(f_0, f_3)
-sch.tensorize(fused_1, "mma_fill")
-
-warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
-f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
-outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
-sch.reorder(outer, f_1, f_2, f_0, f_3)
-fused_1 = sch.fuse(f_1, f_2)
-fused_2 = sch.fuse(f_0, f_3)
-sch.tensorize(outer, "mma_store")
-# print(sch.mod.script())
-# return
+sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync")
+sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
+sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")


 ir_module = tvm.IRModule({"main": workload})
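For readability, the local index_map composes the tile coordinates with the layout helper defined at the top of the file: the leading i // 16, j // 16 select the 16x16 tile within the warp fragment, and the helper places the element inside that tile. A small plain-Python sketch (the name old_index_map and the 32x32 range are for illustration only) showing that this composition reproduces the inline mapping used before:

    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
        thread_id = 4 * (i % 8) + (j % 8) // 2
        return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)

    def index_map(i, j):
        # Tile coordinates first, then the position inside the 16x16 tile.
        return (i // 16, j // 16, *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16))

    def old_index_map(i, j):
        # The inline variant that previously lived in the schedule.
        i_0, j_0 = i // 16, j // 16
        i, j = i % 16, j % 16
        thread_id = 4 * (i % 8) + (j % 8) // 2
        return i_0, j_0, thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2

    # The two maps agree on a 32x32 fragment (two 16x16 tiles per axis).
    assert all(index_map(i, j) == old_index_map(i, j) for i in range(32) for j in range(32))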

tests/python/unittest/test_mma_16x8x16_fp16_4k_tune_trans.py

Lines changed: 18 additions & 9 deletions

@@ -8,6 +8,17 @@
 import numpy as np


+def shared_16x16_to_ldmatrix_32x8_layout(i, j):
+    thread_id = 4 * (i % 8) + (j % 8) // 2
+    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
+
+
+@tvm._ffi.register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
+def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j):
+    thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
+    return tvm.runtime.convert([thread_id, local_id])
+
+
 @T.prim_func
 def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
     A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
@@ -21,10 +32,10 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
             with T.block("A_shared_warp"):
                 v0, v1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(A_shared[v0, v1])
-                T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-                A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
-                    v0, v1
-                ]
+
+                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
+                T.writes(A_warp[thread_id, local_id])
+                A_warp[thread_id, local_id] = A_shared[v0, v1]


 @T.prim_func
@@ -60,7 +71,6 @@ def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
         )
     )

-
 @T.prim_func
 def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
     B_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
@@ -74,10 +84,9 @@ def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
             with T.block("B_shared_warp"):
                 v0, v1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(B_shared[v0, v1])
-                T.writes(B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-                B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = B_shared[
-                    v0, v1
-                ]
+                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
+                T.writes(B_warp[thread_id, local_id])
+                B_warp[thread_id, local_id] = B_shared[v0, v1]


 @T.prim_func
