@@ -23,10 +23,8 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
             with T.block("A_shared_warp"):
                 v0, v1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(A_shared[v0, v1])
-                T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-                A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
-                    v0, v1
-                ]
+                T.writes(A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4])
+                A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4] = A_shared[v0, v1]
 
 
 @T.prim_func
@@ -65,21 +63,19 @@ def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
 
 @T.prim_func
 def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
-    B_shared = T.match_buffer(a, (16, 32), "int8", align=128, offset_factor=16, scope="shared")
+    B_shared = T.match_buffer(a, (32, 16), "int8", align=128, offset_factor=16, scope="shared")
     B_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
 
     with T.block("root"):
-        T.reads(B_shared[0:16, 0:32])
+        T.reads(B_shared[0:32, 0:16])
         T.writes(B_warp[0:32, 0:16])
 
-        for ax0, ax1 in T.grid(16, 32):
+        for ax0, ax1 in T.grid(32, 16):
             with T.block("B_shared_warp"):
                 v0, v1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(B_shared[v0, v1])
-                T.writes(B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-                B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = B_shared[
-                    v0, v1
-                ]
+                T.writes(B_warp[v1 % 8 * 4 + v0 % 16 // 4, v1 // 8 * 8 + v0 // 16 * 4 + v0 % 4])
+                B_warp[v1 % 8 * 4 + v0 % 16 // 4, v1 // 8 * 8 + v0 // 16 * 4 + v0 % 4] = B_shared[v0, v1]
 
 
 @T.prim_func
@@ -88,7 +84,7 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
     s0 = T.var("int32")
     B_shared = T.match_buffer(
         a,
-        (16, 32),
+        (32, 16),
         "int8",
         align=128,
         offset_factor=16,
@@ -97,7 +93,7 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
     )
     B_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
     with T.block("root"):
-        T.reads(B_shared[0:16, 0:32])
+        T.reads(B_shared[0:32, 0:16])
         T.writes(B_warp[0:32, 0:16])
         tx = T.env_thread("threadIdx.x")
         T.launch_thread(tx, 32)
@@ -110,7 +106,7 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
                 B_warp.data,
                 16 * tx,
                 B_shared.data,
-                32 * (tx % 16) + 16 * (tx // 16),
+                16 * tx,
                 dtype="int8",
             )
         )
@@ -125,22 +121,12 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
     with T.block("root"):
         T.reads(C[0:32, 0:8], A[0:32, 0:16], B[0:32, 0:16])
         T.writes(C[0:32, 0:8])
-        for i, j, k in T.grid(32, 8, 16):
+        for i, j, k in T.grid(16, 16, 32):
             with T.block("C"):
                 i, j, k = T.axis.remap("SSR", [i, j, k])
-                T.reads(
-                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
-                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
-                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
-                )
+                T.reads(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2], A[i % 8 * 4 + k % 16 // 4, k // 16 * 8 + i // 8 * 4 + k % 4], B[j % 8 * 4 + k % 16 // 4, j // 8 * 8 + k // 16 * 4 + k % 4])
                 T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
-                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[
-                    i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2
-                ] + T.cast(
-                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2], "int32"
-                ) * T.cast(
-                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2], "int32"
-                )
+                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] + T.cast(A[i % 8 * 4 + k % 16 // 4, k // 16 * 8 + i // 8 * 4 + k % 4], "int32") * T.cast(B[j % 8 * 4 + k % 16 // 4, j // 8 * 8 + k // 16 * 4 + k % 4], "int32")
 
 
 @T.prim_func
@@ -271,7 +257,9 @@ def mma_fill_impl(a: T.handle) -> None:
 tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)
 
 
-M = N = K = 16
+M = 16
+N = 16
+K = 32
 
 def matmul_int8(n, m, k):
     a = te.placeholder((n, k), name="A", dtype="int8")
@@ -300,13 +288,12 @@ def f_compute(i, j):
 
 def fetch_to_shared(block, idx):
     block_read = sch.cache_read(block, idx, "shared")
-    if use_gpu:
-        sch.compute_at(block_read, i1, True)
-        warp_size = 32
-        loops = sch.get_loops(block_read)
-        fused = sch.fuse(*loops[-2:])
-        f_0, f_1 = sch.split(fused, factors=[None, warp_size])
-        sch.bind(f_1, "threadIdx.x")
+    sch.compute_at(block_read, i1, True)
+    warp_size = 32
+    loops = sch.get_loops(block_read)
+    fused = sch.fuse(*loops[-2:])
+    f_0, f_1 = sch.split(fused, factors=[None, warp_size])
+    sch.bind(f_1, "threadIdx.x")
 
     return block_read
 
@@ -320,18 +307,28 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
     return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2
 
 
+def shared_16x32_to_ldmatrix_32x16_layout(i, j):
+    thread_id = 4 * (i % 8) + (j % 16) // 4
+    return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4
+
+
+def shared_32x16_to_ldmatrix_32x16_layout(i, j):
+    thread_id = (i % 16) // 4 + 4 * (j % 8)
+    return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
+
+
 block = sch.get_block("C")
 
 A_warp = sch.cache_read(block, 0, "warp")
 
-# sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+sch.transform_layout(A_warp, 0, "write", index_map=shared_16x32_to_ldmatrix_32x16_layout)
 
 B_warp = sch.cache_read(block, 1, "warp")
 
-# sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+sch.transform_layout(B_warp, 0, "write", index_map=shared_32x16_to_ldmatrix_32x16_layout)
 
-# sch.tensorize(sch.get_loops(A_warp)[1], "mma.ldmatrix_a")
-# sch.tensorize(sch.get_loops(B_warp)[1], "mma.ldmatrix_b")
+sch.tensorize(sch.get_loops(A_warp)[1], "mma.ldmatrix_a")
+sch.tensorize(sch.get_loops(B_warp)[1], "mma.ldmatrix_b")
 
 C_warp = sch.cache_write(block, 0, "warp")
 sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
@@ -344,7 +341,7 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
 fused_1 = sch.fuse(f_1, f_2)
 fused_2 = sch.fuse(f_0, f_3)
 
-# sch.tensorize(outer, "mma_store")
+sch.tensorize(outer, "mma_store")
 
 block_init_c = sch.decompose_reduction(block, sch.get_loops(block)[1])
 
@@ -356,25 +353,25 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
 fused_2 = sch.fuse(f_0, f_3)
 sch.tensorize(fused_1, "mma_fill")
 
-# sch.tensorize(sch.get_loops(block)[1], "mma.mma_sync")
+sch.tensorize(sch.get_loops(block)[1], "mma.mma_sync")
 
 print(sch.mod.script())
 
 # lowered = tvm.lower(sch.mod["main"])
 
-# target = "cuda"
+target = "cuda"
 
-# f = tvm.build(sch.mod["main"], target=target, name="dense")
-# dev = tvm.device(target, 0)
+f = tvm.build(sch.mod["main"], target=target, name="dense")
+dev = tvm.device(target, 0)
 
-# a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
-# b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
-# c_np = np.dot(a_np.astype("int3232"), b_np.astype("in32"))
+a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
+b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
+c_np = np.dot(a_np.astype("int32"), b_np.astype("int32"))
 
-# a = tvm.nd.array(a_np, dev)
-# b = tvm.nd.array(b_np, dev)
-# c = tvm.nd.array(np.zeros((16, K), dtype="float32"), dev)
+a = tvm.nd.array(a_np, dev)
+b = tvm.nd.array(b_np, dev)
+c = tvm.nd.array(np.zeros((M, N), dtype="int32"), dev)
 
-# # print(f.imported_modules[0].get_source())
-# f(a, b, c)
-# np.testing.assert_equal(c.numpy(), c_np)
+# print(f.imported_modules[0].get_source())
+f(a, b, c)
+np.testing.assert_equal(c.numpy(), c_np)
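
A quick way to convince yourself that the two new index maps are usable with sch.transform_layout (and match the ldmatrix descriptors above) is to check that they hit every (thread, local) slot of the 32x16 warp fragment exactly once. The snippet below is a standalone sanity-check sketch, not part of the commit; it simply restates the layout functions from the diff and asserts they are bijections.

# Standalone sanity check (not from the commit above): verify that each index
# map is a bijection from the shared-memory tile onto the 32x16 warp fragment.
def shared_16x32_to_ldmatrix_32x16_layout(i, j):
    thread_id = 4 * (i % 8) + (j % 16) // 4
    return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4


def shared_32x16_to_ldmatrix_32x16_layout(i, j):
    thread_id = (i % 16) // 4 + 4 * (j % 8)
    return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4


a_slots = {shared_16x32_to_ldmatrix_32x16_layout(i, j) for i in range(16) for j in range(32)}
b_slots = {shared_32x16_to_ldmatrix_32x16_layout(i, j) for i in range(32) for j in range(16)}
assert a_slots == b_slots == {(t, l) for t in range(32) for l in range(16)}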