@@ -183,14 +183,14 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
 tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
 tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl)
 
-
 N = 4096
 M = 4096
 K = 4096
 
 workload = te.create_prim_func(te_workload.matmul_fp16(n=N, m=M, k=K))
 
 tune = False
+use_ldmatrix = True
 
 
 def schedule(sch: tir.Schedule):
@@ -199,6 +199,7 @@ def schedule(sch: tir.Schedule):
     i, i_tc = sch.split(i, factors=[None, 16])
     j, j_tc = sch.split(j, factors=[None, 8])
     k, k_tc = sch.split(k, factors=[None, 8])
+
     sch.reorder(
         i, j, k,
         i_tc, j_tc, k_tc,
@@ -211,10 +212,12 @@ def schedule(sch: tir.Schedule):
         i_factors = sch.sample_perfect_tile(i, n=5)
         j_factors = sch.sample_perfect_tile(j, n=5)
         k_factors = sch.sample_perfect_tile(k, n=3)
+        num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
     else:
         i_factors = [1, 16, 4, 2, 2]
         j_factors = [1, 64, 1, 8, 1]
         k_factors = [128, 4, 1]
+        num_ty = i_factors[2] * j_factors[2]
 
     i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
     j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors)
@@ -241,11 +244,6 @@ def schedule(sch: tir.Schedule):
     sch.bind(block_idy, "blockIdx.y")
     sch.bind(thread_idy, "threadIdx.y")
 
-    if isinstance(i_factors[2], int):
-        num_ty = i_factors[2] * j_factors[2]
-    else:
-        num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
-
     def fetch_to_shared(block, idx, ndim):
         block_read = sch.cache_read(block, idx, "shared")
         sch.compute_at(block_read, k0)
@@ -327,8 +325,6 @@ def lambda_b(i, j):
         index_map=lambda_a,
     )
 
-    use_ldmatrix = True
-
     if use_ldmatrix:
         sch.tensorize(loop_a, "mma.ldmatrix_a")
         sch.tensorize(loop_b, "mma.ldmatrix_b")
@@ -347,8 +343,8 @@ def lambda_b(i, j):
     fused_1 = sch.fuse(warp_loop2, f_0)
     sch.bind(fused_1, "threadIdx.x")
 
-    loop = sch.get_loops(block_inner)[-3]
-    sch.tensorize(loop, "mma_sync")
+    mma_loop = sch.get_loops(block_inner)[-3]
+    sch.tensorize(mma_loop, "mma_sync")
 
     block_init_c = sch.get_block("C_init")
     init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
@@ -378,7 +374,6 @@ def lambda_b(i, j):
     sch = ms.tune_tir(
         mod=workload,
         target=tvm.target.Target("nvidia/geforce-rtx-3070"),
-        # use replay or evolutionary search
         config=ms.TuneConfig(
             strategy="evolutionary",
             num_trials_per_iter=32,
0 commit comments