import numpy as np


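+# Maps element (i, j) of a 16x16 tile to the (thread_id, local_id) fragment layout
+# used by ldmatrix / mma.sync: the tile is distributed over the 32 threads of a
+# warp, with 8 elements held per thread.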
+def shared_16x16_to_ldmatrix_32x8_layout(i, j):
+    thread_id = 4 * (i % 8) + (j % 8) // 2
+    return thread_id, 4 * (j // 8) + (i // 8) * 2 + j % 2
+
+
+@tvm._ffi.register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
+def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j):
+    thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
+    return tvm.runtime.convert([thread_id, local_id])
+
+
@T.prim_func
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
    A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
@@ -21,10 +32,10 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
            with T.block("A_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A_shared[v0, v1])
-                T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-                A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
-                    v0, v1
-                ]
+
+                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
+                T.writes(A_warp[thread_id, local_id])
+                A_warp[thread_id, local_id] = A_shared[v0, v1]


@T.prim_func
@@ -74,10 +85,9 @@ def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
            with T.block("B_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(B_shared[v0, v1])
-                T.writes(B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-                B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = B_shared[
-                    v0, v1
-                ]
+                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
+                T.writes(B_warp[thread_id, local_id])
+                B_warp[thread_id, local_id] = B_shared[v0, v1]


@T.prim_func
@@ -126,15 +136,19 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        for i, j, k in T.grid(16, 16, 16):
            with T.block("C"):
                i, j, k = T.axis.remap("SSR", [i, j, k])
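+                # A is indexed by (i, k), B by (k, j), C by (i, j); each index pair
+                # maps to its own (thread_id, local_id) in the warp-level fragment.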
+                thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
+                thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
+                thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j)
+
                T.reads(
-                    C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2],
-                    A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 2],
-                    B[k % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + k % 16 // 8 * 2 + j % 2],
+                    C[thread_id_C, local_id_C],
+                    A[thread_id_A, local_id_A],
+                    B[thread_id_B, local_id_B],
+                )
+                T.writes(C[thread_id_C, local_id_C])
+                C[thread_id_C, local_id_C] += (
+                    A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B]
                )
-                T.writes(C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2])
-                C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2] = C[
-                    i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2
-                ] + A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 8 % 2] * B[k % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + k % 16 // 8 * 2 + j % 8 % 2]


@T.prim_func
@@ -196,14 +210,13 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:
    with T.block("root"):
        T.reads(C_warp[0:32, 0:8])
        T.writes(C[0:16, 0:16])
-        for ax1_0, i0, i1 in T.grid(2, 32, 4):
+        for i0, i1 in T.grid(16, 16):
            with T.block("C_warp"):
-                v0 = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4)
-                v1 = T.axis.spatial(16, ax1_0 * 8 + i0 % 4 * 2 + i1 % 2)
-
-                T.reads(C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
+                v0, v1 = T.axis.remap("SS", [i0, i1])
+                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
+                T.reads(C_warp[thread_id, local_id])
                T.writes(C[v0, v1])
-                C[v0, v1] = C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2]
+                C[v0, v1] = C_warp[thread_id, local_id]


@T.prim_func
@@ -236,21 +249,13 @@ def mma_fill_desc(a: T.handle) -> None:
    with T.block("root"):
        T.reads()
        T.writes(C_warp[0:32, 0:8])
-        for i0, i1 in T.grid(32, 8):
+        for i0, i1 in T.grid(16, 16):
            with T.block("C_warp"):
-                i_init = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
-                j_init = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
+                i_init, j_init = T.axis.remap("SS", [i0, i1])
+                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
                T.reads()
-                T.writes(
-                    C_warp[
-                        i_init % 8 * 4 + j_init % 8 // 2,
-                        j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 2,
-                    ]
-                )
-                C_warp[
-                    i_init % 8 * 4 + j_init % 8 // 2,
-                    j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 8 % 2,
-                ] = T.float16(0)
+                T.writes(C_warp[thread_id, local_id])
+                C_warp[thread_id, local_id] = T.float16(0)


@T.prim_func
@@ -276,6 +281,7 @@ def mma_fill_impl(a: T.handle) -> None:
M = 4096
K = 4096

+
def matmul_fp16(n, m, k):
    a = te.placeholder((n, k), name="A", dtype="float16")
    b = te.placeholder((k, m), name="B", dtype="float16")
@@ -373,8 +379,6 @@ def fetch_to_shared(block, idx, ndim):
    A_sh = fetch_to_shared(block_outer, 0, 2)
    B_sh = fetch_to_shared(block_outer, 1, 2)

-    loop = sch.get_loops(block_outer)[-1]
-
    A_warp = sch.cache_read(block_outer, 0, "warp")
    B_warp = sch.cache_read(block_outer, 1, "warp")

@@ -389,7 +393,8 @@ def fetch_to_shared(block, idx, ndim):
    jo, ji = sch.split(jj, factors=[None, 16])
    sch.reorder(io, jo, ii, ji)

-    block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
+    sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
+    block_init_c = sch.get_block("C_init")

    def tile_wmma_fragment(block_read, height):
        i, j = sch.get_loops(block_read)[-2:]
@@ -398,47 +403,25 @@ def tile_wmma_fragment(block_read, height):
        sch.reorder(i0, j0, i1, j1)
        return i1

-    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
-        i_0 = i // 16
-        j_0 = j // 16
-
-        i = i % 16
-        j = j % 16
-
-        thread_id = 4 * (i % 8) + (j % 8) // 2
-        return i_0, j_0, thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2
-
    loop_a = tile_wmma_fragment(A_warp, 16)
    loop_b = tile_wmma_fragment(B_warp, 16)

-    sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
-    sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
-    sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)
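+    # The warp-scoped caches can span multiple 16x16 tiles, so the index map first
+    # derives the 16x16 tile coordinates and then applies the within-tile ldmatrix
+    # layout defined at the top of the file.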
+    def index_map(i, j):
+        return (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+        )
+
+    sch.transform_layout(A_warp, 0, "write", index_map)
+    sch.transform_layout(B_warp, 0, "write", index_map)
+    sch.transform_layout(C_warp, 0, "read", index_map)

    sch.tensorize(loop_a, "mma.ldmatrix_a")
    sch.tensorize(loop_b, "mma.ldmatrix_b")
-
-    mma_loop = sch.get_loops(block_inner)[-3]
-    sch.tensorize(mma_loop, "mma_sync")
-
-    block_init_c = sch.get_block("C_init")
-    init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
-    f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
-    f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
-    sch.reorder(f_1, f_2, f_0, f_3)
-    fused_1 = sch.fuse(f_1, f_2)
-    fused_2 = sch.fuse(f_0, f_3)
-    sch.tensorize(fused_1, "mma_fill")
-
-    warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
-    f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
-    outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
-    sch.reorder(outer, f_1, f_2, f_0, f_3)
-    fused_1 = sch.fuse(f_1, f_2)
-    fused_2 = sch.fuse(f_0, f_3)
-    sch.tensorize(outer, "mma_store")
-    # print(sch.mod.script())
-    # return
+    sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync")
+    sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
+    sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")


ir_module = tvm.IRModule({"main": workload})