@@ -135,16 +135,18 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
135135 for i , j , k in T .grid (16 , 16 , 16 ):
136136 with T .block ("C" ):
137137 i , j , k = T .axis .remap ("SSR" , [i , j , k ])
138+ thread_id_C , local_id_C = shared_16x16_to_ldmatrix_32x8_layout (i , j )
139+ thread_id_A , local_id_A = shared_16x16_to_ldmatrix_32x8_layout (i , k )
140+ thread_id_B , local_id_B = shared_16x16_to_ldmatrix_32x8_layout (j , k )
141+
138142 T .reads (
139- C [i % 8 * 4 + j % 8 // 2 , j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2 ],
140- A [i % 8 * 4 + k % 8 // 2 , k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 2 ],
141- B [j % 8 * 4 + k % 8 // 2 , k % 16 // 8 * 4 + j % 16 // 8 * 2 + k % 2 ],
143+ C [thread_id_C , local_id_C ],
144+ A [thread_id_A , local_id_A ],
145+ B [thread_id_B , local_id_B ],
142146 )
143- T .writes (C [i % 8 * 4 + j % 8 // 2 , j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2 ])
144- C [i % 8 * 4 + j % 8 // 2 , j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2 ] = (
145- C [i % 8 * 4 + j % 8 // 2 , j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2 ]
146- + A [i % 8 * 4 + k % 8 // 2 , k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 8 % 2 ]
147- * B [j % 8 * 4 + k % 8 // 2 , k % 16 // 8 * 4 + j % 16 // 8 * 2 + k % 8 % 2 ]
147+ T .writes (C [thread_id_C , local_id_C ])
148+ C [thread_id_C , local_id_C ] += (
149+ A [thread_id_A , local_id_A ] * B [thread_id_B , local_id_B ]
148150 )
149151
150152
@@ -207,14 +209,13 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:
207209 with T .block ("root" ):
208210 T .reads (C_warp [0 :32 , 0 :8 ])
209211 T .writes (C [0 :16 , 0 :16 ])
210- for ax1_0 , i0 , i1 in T .grid (2 , 32 , 4 ):
212+ for i0 , i1 in T .grid (16 , 16 ):
211213 with T .block ("C_warp" ):
212- v0 = T .axis .spatial (16 , i1 // 2 * 8 + i0 // 4 )
213- v1 = T .axis .spatial (16 , ax1_0 * 8 + i0 % 4 * 2 + i1 % 2 )
214-
215- T .reads (C_warp [v0 % 8 * 4 + v1 % 8 // 2 , v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2 ])
214+ v0 , v1 = T .axis .remap ("SS" , [i0 , i1 ])
215+ thread_id , local_id = shared_16x16_to_ldmatrix_32x8_layout (v0 , v1 )
216+ T .reads (C_warp [thread_id , local_id ])
216217 T .writes (C [v0 , v1 ])
217- C [v0 , v1 ] = C_warp [v0 % 8 * 4 + v1 % 8 // 2 , v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2 ]
218+ C [v0 , v1 ] = C_warp [thread_id , local_id ]
218219
219220
220221@T .prim_func
@@ -247,21 +248,13 @@ def mma_fill_desc(a: T.handle) -> None:
247248 with T .block ("root" ):
248249 T .reads ()
249250 T .writes (C_warp [0 :32 , 0 :8 ])
250- for i0 , i1 in T .grid (32 , 8 ):
251+ for i0 , i1 in T .grid (16 , 16 ):
251252 with T .block ("C_warp" ):
252- i_init = T .axis .spatial ( 16 , i1 // 4 * 8 + i0 // 4 )
253- j_init = T . axis . spatial ( 16 , ( i0 % 4 ) * 4 + i1 % 4 )
253+ i_init , j_init = T .axis .remap ( "SS" , [ i0 , i1 ] )
254+ thread_id , local_id = shared_16x16_to_ldmatrix_32x8_layout ( i_init , j_init )
254255 T .reads ()
255- T .writes (
256- C_warp [
257- i_init % 8 * 4 + j_init % 8 // 2 ,
258- j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 2 ,
259- ]
260- )
261- C_warp [
262- i_init % 8 * 4 + j_init % 8 // 2 ,
263- j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 8 % 2 ,
264- ] = T .float16 (0 )
256+ T .writes (C_warp [thread_id , local_id ])
257+ C_warp [thread_id , local_id ] = T .float16 (0 )
265258
266259
267260@T .prim_func
@@ -387,8 +380,6 @@ def fetch_to_shared(block, idx, ndim):
387380 A_sh = fetch_to_shared (block_outer , 0 , 2 )
388381 B_sh = fetch_to_shared (block_outer , 1 , 2 )
389382
390- loop = sch .get_loops (block_outer )[- 1 ]
391-
392383 A_warp = sch .cache_read (block_outer , 0 , "warp" )
393384 B_warp = sch .cache_read (block_outer , 1 , "warp" )
394385
@@ -403,7 +394,8 @@ def fetch_to_shared(block, idx, ndim):
403394 jo , ji = sch .split (jj , factors = [None , 16 ])
404395 sch .reorder (io , jo , ii , ji )
405396
406- block_init_c = sch .decompose_reduction (block_outer , sch .get_loops (block_outer )[3 ])
397+ sch .decompose_reduction (block_outer , sch .get_loops (block_outer )[3 ])
398+ block_init_c = sch .get_block ("C_init" )
407399
408400 def tile_wmma_fragment (block_read , height ):
409401 i , j = sch .get_loops (block_read )[- 2 :]
@@ -412,47 +404,25 @@ def tile_wmma_fragment(block_read, height):
412404 sch .reorder (i0 , j0 , i1 , j1 )
413405 return i1
414406
415- def shared_16x16_to_ldmatrix_32x8_layout (i , j ):
416- i_0 = i // 16
417- j_0 = j // 16
418-
419- i = i % 16
420- j = j % 16
421-
422- thread_id = 4 * (i % 8 ) + (j % 8 ) // 2
423- return i_0 , j_0 , thread_id , 4 * (j // 8 ) + (i // 8 ) * 2 + (j % 8 ) % 2
424-
425407 loop_a = tile_wmma_fragment (A_warp , 16 )
426408 loop_b = tile_wmma_fragment (B_warp , 16 )
427409
428- sch .transform_layout (A_warp , 0 , "write" , index_map = shared_16x16_to_ldmatrix_32x8_layout )
429- sch .transform_layout (B_warp , 0 , "write" , index_map = shared_16x16_to_ldmatrix_32x8_layout )
430- sch .transform_layout (C_warp , 0 , "read" , index_map = shared_16x16_to_ldmatrix_32x8_layout )
410+ def index_map (i , j ):
411+ return (
412+ i // 16 ,
413+ j // 16 ,
414+ * shared_16x16_to_ldmatrix_32x8_layout (i % 16 , j % 16 ),
415+ )
416+
417+ sch .transform_layout (A_warp , 0 , "write" , index_map )
418+ sch .transform_layout (B_warp , 0 , "write" , index_map )
419+ sch .transform_layout (C_warp , 0 , "read" , index_map )
431420
432421 sch .tensorize (loop_a , "mma.ldmatrix_a" )
433422 sch .tensorize (loop_b , "mma.ldmatrix_b" )
434-
435- mma_loop = sch .get_loops (block_inner )[- 3 ]
436- sch .tensorize (mma_loop , "mma_sync" )
437-
438- block_init_c = sch .get_block ("C_init" )
439- init_loop1 , init_loop2 = sch .get_loops (block_init_c )[- 2 :]
440- f_0 , f_1 = sch .split (init_loop1 , factors = [None , 8 ])
441- f_2 , f_3 = sch .split (init_loop2 , factors = [None , 4 ])
442- sch .reorder (f_1 , f_2 , f_0 , f_3 )
443- fused_1 = sch .fuse (f_1 , f_2 )
444- fused_2 = sch .fuse (f_0 , f_3 )
445- sch .tensorize (fused_1 , "mma_fill" )
446-
447- warp_loop1 , warp_loop2 = sch .get_loops (C_warp )[- 2 :]
448- f_0 , f_1 = sch .split (warp_loop1 , factors = [None , 8 ])
449- outer , f_2 , f_3 = sch .split (warp_loop2 , factors = [2 , 4 , 2 ])
450- sch .reorder (outer , f_1 , f_2 , f_0 , f_3 )
451- fused_1 = sch .fuse (f_1 , f_2 )
452- fused_2 = sch .fuse (f_0 , f_3 )
453- sch .tensorize (outer , "mma_store" )
454- # print(sch.mod.script())
455- # return
423+ sch .tensorize (sch .get_loops (block_inner )[- 3 ], "mma_sync" )
424+ sch .tensorize (sch .get_loops (block_init_c )[- 2 ], "mma_fill" )
425+ sch .tensorize (sch .get_loops (C_warp )[- 2 ], "mma_store" )
456426
457427
458428ir_module = tvm .IRModule ({"main" : workload })
0 commit comments