88import numpy as np
99
1010
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    """Map a (row, col) index of a 16x16 tile to the (thread_id, local_id)
    position of the 32x8 warp-fragment layout produced by ``ldmatrix``.

    The map is a bijection from the 256 tile elements onto 32 threads x 8
    registers: the low 3 bits of the row and bits 1-2 of the column select
    the thread; the remaining bits select the register slot.
    """
    lane = 4 * (i % 8) + (j % 8) // 2
    slot = 4 * (j // 8) + 2 * (i // 8) + (j % 2)
    return lane, slot
14+
15+
def shared_16x32_to_ldmatrix_32x16_layout(i, j):
    """Map a (row, col) index of a 16x32 int8 A tile to the (thread_id,
    local_id) position of the 32x16 warp-fragment layout for ``ldmatrix``.

    Bijective over the 512 tile elements: low 3 row bits and bits 2-3 of
    the column pick one of 32 threads; the leftover bits pick one of 16
    per-thread register slots.
    """
    lane = 4 * (i % 8) + (j % 16) // 4
    slot = 8 * (j // 16) + 4 * (i // 8) + (j % 4)
    return lane, slot
19+
20+
def shared_32x16_to_ldmatrix_32x16_layout(i, j):
    """Map a (row, col) index of a 32x16 int8 B tile to the (thread_id,
    local_id) position of the 32x16 warp-fragment layout for ``ldmatrix``.

    Fix: the thread index must use ``(i % 16) // 4`` rather than ``i % 4``.
    With ``i % 4`` the row bits 2-3 (``(i // 4) % 4``) are never consumed,
    so e.g. (0, 0) and (4, 0) collapse onto the same (thread, slot) pair —
    the map is non-injective, and ``transform_layout`` requires a bijective
    index map.  Using ``(i % 16) // 4`` consumes exactly those bits and
    makes the 512-element map a bijection onto 32 threads x 16 slots,
    matching the ldmatrix B-fragment layout for m16n8k32 int8 MMA.
    """
    lane = (i % 16) // 4 + 4 * (j % 8)
    slot = 8 * (j // 8) + 4 * (i // 16) + (i % 4)
    return lane, slot
24+
25+
@tvm._ffi.register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j):
    """FFI-registered index map exposing the 16x16 -> 32x8 ldmatrix layout
    to the TIR side, returning the mapped indices as a TVM array."""
    mapped = shared_16x16_to_ldmatrix_32x8_layout(i, j)
    return tvm.runtime.convert(list(mapped))
30+
31+
1132@T .prim_func
1233def ldmatrix_a_desc (a : T .handle , c : T .handle ) -> None :
1334 A_shared = T .match_buffer (a , (16 , 32 ), "int8" , align = 128 , offset_factor = 16 , scope = "shared" )
@@ -21,10 +42,9 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
2142 with T .block ("A_shared_warp" ):
2243 v0 , v1 = T .axis .remap ("SS" , [ax0 , ax1 ])
2344 T .reads (A_shared [v0 , v1 ])
24- T .writes (A_warp [v0 % 8 * 4 + v1 % 16 // 4 , v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4 ])
25- A_warp [v0 % 8 * 4 + v1 % 16 // 4 , v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4 ] = A_shared [
26- v0 , v1
27- ]
45+ thread_id , local_id = shared_16x32_to_ldmatrix_32x16_layout (v0 , v1 )
46+ T .writes (A_warp [thread_id , local_id ])
47+ A_warp [thread_id , local_id ] = A_shared [v0 , v1 ]
2848
2949
3050@T .prim_func
@@ -74,8 +94,9 @@ def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
7494 with T .block ("B_shared_warp" ):
7595 v0 , v1 = T .axis .remap ("SS" , [ax0 , ax1 ])
7696 T .reads (B_shared [v0 , v1 ])
77- T .writes (B_warp [v1 % 8 * 4 + v0 % 4 , v1 // 8 * 8 + v0 // 16 * 4 + v0 % 4 ])
78- B_warp [v1 % 8 * 4 + v0 % 4 , v1 // 8 * 8 + v0 // 16 * 4 + v0 % 4 ] = B_shared [v0 , v1 ]
97+ thread_id , local_id = shared_32x16_to_ldmatrix_32x16_layout (v0 , v1 )
98+ T .writes (B_warp [thread_id , local_id ])
99+ B_warp [thread_id , local_id ] = B_shared [v0 , v1 ]
79100
80101
81102@T .prim_func
@@ -124,18 +145,19 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
124145 for i , j , k in T .grid (16 , 16 , 32 ):
125146 with T .block ("C" ):
126147 i , j , k = T .axis .remap ("SSR" , [i , j , k ])
148+
149+ thread_id_C , local_id_C = shared_16x16_to_ldmatrix_32x8_layout (i , j )
150+ thread_id_A , local_id_A = shared_16x32_to_ldmatrix_32x16_layout (i , k )
151+ thread_id_B , local_id_B = shared_32x16_to_ldmatrix_32x16_layout (k , j )
152+
127153 T .reads (
128- C [i % 8 * 4 + j % 8 // 2 , j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2 ],
129- A [i % 8 * 4 + k % 16 // 4 , k % 32 // 16 * 8 + i % 16 // 8 * 4 + k % 4 ],
130- B [j % 8 * 4 + k % 4 , j % 16 // 8 * 8 + k % 32 // 16 * 4 + k % 4 ],
154+ C [thread_id_C , local_id_C ],
155+ A [thread_id_A , local_id_A ],
156+ B [thread_id_B , local_id_B ],
131157 )
132- T .writes (C [i % 8 * 4 + j % 8 // 2 , j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2 ])
133- C [i % 8 * 4 + j % 8 // 2 , j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2 ] = C [
134- i % 8 * 4 + j % 8 // 2 , j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2
135- ] + T .cast (
136- A [i % 8 * 4 + k % 16 // 4 , k % 32 // 16 * 8 + i % 16 // 8 * 4 + k % 4 ], "int32"
137- ) * T .cast (
138- B [j % 8 * 4 + k % 4 , j % 16 // 8 * 8 + k % 32 // 16 * 4 + k % 4 ], "int32"
158+ T .writes (C [thread_id_C , local_id_C ])
159+ C [thread_id_C , local_id_C ] += T .cast (A [thread_id_A , local_id_A ], "int32" ) * T .cast (
160+ B [thread_id_B , local_id_B ], "int32"
139161 )
140162
141163
@@ -198,14 +220,13 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:
198220 with T .block ("root" ):
199221 T .reads (C_warp [0 :32 , 0 :8 ])
200222 T .writes (C [0 :16 , 0 :16 ])
201- for ax1_0 , i0 , i1 in T .grid (2 , 32 , 4 ):
223+ for i0 , i1 in T .grid (16 , 16 ):
202224 with T .block ("C_warp" ):
203- v0 = T .axis .spatial (16 , i1 // 2 * 8 + i0 // 4 )
204- v1 = T .axis .spatial (16 , ax1_0 * 8 + i0 % 4 * 2 + i1 % 2 )
205-
206- T .reads (C_warp [v0 % 8 * 4 + v1 % 8 // 2 , v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2 ])
225+ v0 , v1 = T .axis .remap ("SS" , [i0 , i1 ])
226+ thread_id , local_id = shared_16x16_to_ldmatrix_32x8_layout (v0 , v1 )
227+ T .reads (C_warp [thread_id , local_id ])
207228 T .writes (C [v0 , v1 ])
208- C [v0 , v1 ] = C_warp [v0 % 8 * 4 + v1 % 8 // 2 , v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2 ]
229+ C [v0 , v1 ] = C_warp [thread_id , local_id ]
209230
210231
211232@T .prim_func
@@ -238,21 +259,13 @@ def mma_fill_desc(a: T.handle) -> None:
238259 with T .block ("root" ):
239260 T .reads ()
240261 T .writes (C_warp [0 :32 , 0 :8 ])
241- for i0 , i1 in T .grid (32 , 8 ):
262+ for i0 , i1 in T .grid (16 , 16 ):
242263 with T .block ("C_warp" ):
243- i_init = T .axis .spatial ( 16 , i1 // 4 * 8 + i0 // 4 )
244- j_init = T . axis . spatial ( 16 , ( i0 % 4 ) * 4 + i1 % 4 )
264+ i_init , j_init = T .axis .remap ( "SS" , [ i0 , i1 ] )
265+ thread_id , local_id = shared_16x16_to_ldmatrix_32x8_layout ( i_init , j_init )
245266 T .reads ()
246- T .writes (
247- C_warp [
248- i_init % 8 * 4 + j_init % 8 // 2 ,
249- j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 2 ,
250- ]
251- )
252- C_warp [
253- i_init % 8 * 4 + j_init % 8 // 2 ,
254- j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 8 % 2 ,
255- ] = T .int32 (0 )
267+ T .writes (C_warp [thread_id , local_id ])
268+ C_warp [thread_id , local_id ] = T .int32 (0 )
256269
257270
258271@T .prim_func
@@ -394,7 +407,8 @@ def fetch_to_shared(block, idx, ndim, vec=False):
394407 jo , ji = sch .split (jj , factors = [None , 16 ])
395408 sch .reorder (io , jo , ii , ji )
396409
397- block_init_c = sch .decompose_reduction (block_outer , sch .get_loops (block_outer )[3 ])
410+ sch .decompose_reduction (block_outer , sch .get_loops (block_outer )[3 ])
411+ block_init_c = sch .get_block ("C_init" )
398412
399413 def tile_wmma_fragment (block_read , height , width ):
400414 i , j = sch .get_loops (block_read )[- 2 :]
@@ -403,67 +417,39 @@ def tile_wmma_fragment(block_read, height, width):
403417 sch .reorder (i0 , j0 , i1 , j1 )
404418 return i1
405419
406- def shared_16x16_to_ldmatrix_32x8_layout (i , j ):
407- i_0 = i // 16
408- j_0 = j // 16
409-
410- i = i % 16
411- j = j % 16
412-
413- thread_id = 4 * (i % 8 ) + (j % 8 ) // 2
414- return i_0 , j_0 , thread_id , 4 * (j // 8 ) + (i // 8 ) * 2 + (j % 8 ) % 2
415-
416- def shared_16x32_to_ldmatrix_32x16_layout (i , j ):
417- i_0 = i // 16
418- j_0 = j // 32
419-
420- i = i % 16
421- j = j % 32
422-
423- thread_id = 4 * (i % 8 ) + (j % 16 ) // 4
424- return i_0 , j_0 , thread_id , 8 * (j // 16 ) + (i // 8 ) * 4 + j % 4
425-
426- def shared_32x16_to_ldmatrix_32x16_layout (i , j ):
427- i_0 = i // 32
428- j_0 = j // 16
420+ loop_a = tile_wmma_fragment (A_warp , 16 , 32 )
421+ loop_b = tile_wmma_fragment (B_warp , 32 , 16 )
429422
430- i = i % 32
431- j = j % 16
423+ def index_map_A (i , j ):
424+ return (
425+ i // 16 ,
426+ j // 32 ,
427+ * shared_16x32_to_ldmatrix_32x16_layout (i % 16 , j % 32 ),
428+ )
432429
433- thread_id = (i % 4 ) + 4 * (j % 8 )
434- return i_0 , j_0 , thread_id , 8 * (j // 8 ) + (i // 16 ) * 4 + i % 4
430+ def index_map_B (i , j ):
431+ return (
432+ i // 32 ,
433+ j // 16 ,
434+ * shared_32x16_to_ldmatrix_32x16_layout (i % 32 , j % 16 ),
435+ )
435436
436- loop_a = tile_wmma_fragment (A_warp , 16 , 32 )
437- loop_b = tile_wmma_fragment (B_warp , 32 , 16 )
437+ def index_map_C (i , j ):
438+ return (
439+ i // 16 ,
440+ j // 16 ,
441+ * shared_16x16_to_ldmatrix_32x8_layout (i % 16 , j % 16 ),
442+ )
438443
439- sch .transform_layout (A_warp , 0 , "write" , index_map = shared_16x32_to_ldmatrix_32x16_layout )
440- sch .transform_layout (B_warp , 0 , "write" , index_map = shared_32x16_to_ldmatrix_32x16_layout )
441- sch .transform_layout (C_warp , 0 , "read" , index_map = shared_16x16_to_ldmatrix_32x8_layout )
444+ sch .transform_layout (A_warp , 0 , "write" , index_map_A )
445+ sch .transform_layout (B_warp , 0 , "write" , index_map_B )
446+ sch .transform_layout (C_warp , 0 , "read" , index_map_C )
442447
443448 sch .tensorize (loop_a , "mma.ldmatrix_a" )
444449 sch .tensorize (loop_b , "mma.ldmatrix_b" )
445-
446- mma_loop = sch .get_loops (block_inner )[- 3 ]
447- sch .tensorize (mma_loop , "mma_sync" )
448-
449- block_init_c = sch .get_block ("C_init" )
450- init_loop1 , init_loop2 = sch .get_loops (block_init_c )[- 2 :]
451- f_0 , f_1 = sch .split (init_loop1 , factors = [None , 8 ])
452- f_2 , f_3 = sch .split (init_loop2 , factors = [None , 4 ])
453- sch .reorder (f_1 , f_2 , f_0 , f_3 )
454- fused_1 = sch .fuse (f_1 , f_2 )
455- fused_2 = sch .fuse (f_0 , f_3 )
456- sch .tensorize (fused_1 , "mma_fill" )
457-
458- warp_loop1 , warp_loop2 = sch .get_loops (C_warp )[- 2 :]
459- f_0 , f_1 = sch .split (warp_loop1 , factors = [None , 8 ])
460- outer , f_2 , f_3 = sch .split (warp_loop2 , factors = [2 , 4 , 2 ])
461- sch .reorder (outer , f_1 , f_2 , f_0 , f_3 )
462- fused_1 = sch .fuse (f_1 , f_2 )
463- fused_2 = sch .fuse (f_0 , f_3 )
464- sch .tensorize (outer , "mma_store" )
465- # print(sch.mod.script())
466- # return
450+ sch .tensorize (sch .get_loops (block_inner )[- 3 ], "mma_sync" )
451+ sch .tensorize (sch .get_loops (block_init_c )[- 2 ], "mma_fill" )
452+ sch .tensorize (sch .get_loops (C_warp )[- 2 ], "mma_store" )
467453
468454
# Package the scheduled workload as the "main" entry of an IRModule.
ir_module = tvm.IRModule(functions={"main": workload})