@@ -127,17 +127,19 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
             with T.block("C"):
                 i, j, k = T.axis.remap("SSR", [i, j, k])
                 T.reads(
-                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
-                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
-                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
+                    C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2],
+                    A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 2],
+                    B[k % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + k % 16 // 8 * 2 + j % 2],
                 )
-                T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
-                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[
-                    i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2
+                T.writes(C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2])
+                C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2] = C[
+                    i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2
                 ] + T.cast(
-                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2], "float32"
+                    A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 8 % 2],
+                    "float32",
                 ) * T.cast(
-                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2], "float32"
+                    B[k % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + k % 16 // 8 * 2 + j % 8 % 2],
+                    "float32",
                 )
 
 
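Note: the index rewrites in this hunk do not change the values for the 16-wide iterators used here; for 0 <= j < 16, `j // 8` equals `j % 16 // 8` and `j % 2` equals `j % 8 % 2`, so the new form presumably just spells out the 16-periodicity of the layout. A quick plain-Python check of that equivalence (nothing TVM-specific is assumed):

# Confirm the rewritten ldmatrix indices agree with the old ones on a 16-wide axis.
for j in range(16):
    assert j // 8 == j % 16 // 8   # block-of-8 selector unchanged
    assert j % 2 == j % 8 % 2      # low bit unchanged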
@@ -242,11 +244,19 @@ def mma_fill_desc(a: T.handle) -> None:
         T.writes(C_warp[0:32, 0:8])
         for i0, i1 in T.grid(32, 8):
             with T.block("C_warp"):
-                i = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
-                j = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
+                i_init = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
+                j_init = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
                 T.reads()
-                T.writes(C_warp[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
-                C_warp[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = T.float32(0)
+                T.writes(
+                    C_warp[
+                        i_init % 8 * 4 + j_init % 8 // 2,
+                        j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 2,
+                    ]
+                )
+                C_warp[
+                    i_init % 8 * 4 + j_init % 8 // 2,
+                    j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 8 % 2,
+                ] = T.float32(0)
 
 
 @T.prim_func
@@ -304,8 +314,8 @@ def schedule(sch: tir.Schedule):
         num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
     else:
         i_factors = [1, 16, 4, 2, 2]
-        j_factors = [1, 64, 1, 8, 1]
-        k_factors = [128, 4, 1]
+        j_factors = [1, 32, 1, 8, 1]
+        k_factors = [64, 4, 1]
         num_ty = i_factors[2] * j_factors[2]
 
     i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
@@ -368,7 +378,7 @@ def fetch_to_shared(block, idx, ndim):
 
     ii, jj = sch.get_loops(C_warp)[-2:]
     io, ii = sch.split(ii, factors=[None, 16])
-    jo, ji = sch.split(jj, factors=[None, 8])
+    jo, ji = sch.split(jj, factors=[None, 16])
     sch.reorder(io, jo, ii, ji)
 
     block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
@@ -394,18 +404,10 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
     loop_b = tile_wmma_fragment(B_warp, 16)
 
     sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
-    sch.transform_layout(
-        B_warp,
-        0,
-        "write",
-        index_map=shared_16x16_to_ldmatrix_32x8_layout
-    )
-    sch.transform_layout(
-        C_warp,
-        0,
-        "read",
-        index_map=shared_16x16_to_ldmatrix_32x8_layout
-    )
+    sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+    sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+
+    # return
 
     if use_ldmatrix:
         sch.tensorize(loop_a, "mma.ldmatrix_a")
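For context, the hunk headers reference shared_16x16_to_ldmatrix_32x8_layout but its body is outside this diff. The following is a sketch of an index map consistent with the buffer indices in the intrinsics above (an assumption inferred from those expressions, not necessarily the exact definition in the file): it maps a (row, col) position of a 16x16 fragment to a (thread_id, local_id) slot of the 32x8 warp-scoped buffer that ldmatrix/mma.sync operate on.

def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    # Lane within the warp that owns element (i, j) of the 16x16 tile.
    thread_id = i % 8 * 4 + j % 8 // 2
    # Position of the element among the 8 values held by that lane.
    local_id = j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2
    return thread_id, local_id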
@@ -425,69 +427,65 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
     fused_1 = sch.fuse(warp_loop2, f_0)
     sch.bind(fused_1, "threadIdx.x")
 
-    # mma_loop = sch.get_loops(block_inner)[-3]
-    # sch.tensorize(mma_loop, "mma_sync")
-
-    # block_init_c = sch.get_block("C_init")
-    # init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
-    # f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
-    # f_2, f_3 = sch.split(init_loop2, factors=[None, 2])
-    # sch.reorder(f_1, f_2, f_0, f_3)
-    # fused_1 = sch.fuse(f_1, f_2)
-    # fused_2 = sch.fuse(f_0, f_3)
-    # # sch.bind(fused_1, "threadIdx.x")
-    # sch.tensorize(fused_1, "mma_fill")
-
-    # warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
-    # f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
-    # f_2, f_3 = sch.split(warp_loop2, factors=[None, 2])
-    # sch.reorder(f_1, f_2, f_0, f_3)
-    # fused_1 = sch.fuse(f_1, f_2)
-    # fused_2 = sch.fuse(f_0, f_3)
-
-    # # print(sch.mod.script())
-    # # return
-
-    # sch.tensorize(fused_1, "mma_store")
+    mma_loop = sch.get_loops(block_inner)[-3]
+    sch.tensorize(mma_loop, "mma_sync")
+
+    block_init_c = sch.get_block("C_init")
+    init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
+    f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
+    f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
+    sch.reorder(f_1, f_2, f_0, f_3)
+    fused_1 = sch.fuse(f_1, f_2)
+    fused_2 = sch.fuse(f_0, f_3)
+    sch.tensorize(fused_1, "mma_fill")
+
+    warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
+    f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
+    outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
+    sch.reorder(outer, f_1, f_2, f_0, f_3)
+    fused_1 = sch.fuse(f_1, f_2)
+    fused_2 = sch.fuse(f_0, f_3)
+    sch.tensorize(outer, "mma_store")
+    # print(sch.mod.script())
+    # return
 
 
 ir_module = tvm.IRModule({"main": workload})
 sch = tvm.tir.Schedule(ir_module)
 schedule(sch)
 print(sch.mod.script())
 
-# if tune:
-#     with tempfile.TemporaryDirectory() as work_dir:
-#         sch = ms.tune_tir(
-#             mod=workload,
-#             target=tvm.target.Target("nvidia/geforce-rtx-3070"),
-#             config=ms.TuneConfig(
-#                 strategy="evolutionary",
-#                 num_trials_per_iter=32,
-#                 max_trials_per_task=128,
-#                 max_trials_global=128,
-#             ),
-#             work_dir=work_dir,
-#             space=ms.space_generator.ScheduleFn(schedule),
-#         )
-#     if sch is None:
-#         print("No valid schedule found!")
-#     else:
-#         print(sch.mod.script())
-#         print(sch.trace)
-# else:
-#     print(sch.mod.script())
-#     target = "cuda"
-#     f = tvm.build(sch.mod["main"], target=target, name="dense")
-
-#     dev = tvm.device("cuda", 0)
-#     a_np = np.random.uniform(size=(N, K)).astype("float16")
-#     b_np = np.random.uniform(size=(K, M)).astype("float16")
-#     c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
-#     a = tvm.nd.array(a_np, dev)
-#     b = tvm.nd.array(b_np, dev)
-#     c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
-#     f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+if tune:
+    with tempfile.TemporaryDirectory() as work_dir:
+        sch = ms.tune_tir(
+            mod=workload,
+            target=tvm.target.Target("nvidia/geforce-rtx-3070"),
+            config=ms.TuneConfig(
+                strategy="evolutionary",
+                num_trials_per_iter=32,
+                max_trials_per_task=128,
+                max_trials_global=128,
+            ),
+            work_dir=work_dir,
+            space=ms.space_generator.ScheduleFn(schedule),
+        )
+    if sch is None:
+        print("No valid schedule found!")
+    else:
+        print(sch.mod.script())
+        print(sch.trace)
+else:
+    target = "cuda"
+    f = tvm.build(sch.mod["main"], target=target, name="dense")
+
+    dev = tvm.device("cuda", 0)
+    a_np = np.random.uniform(size=(N, K)).astype("float16")
+    b_np = np.random.uniform(size=(K, M)).astype("float16")
+    c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(b_np, dev)
+    c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
+    f = tvm.build(sch.mod["main"], target="cuda", name="dense")
 
 # print(f.imported_modules[0].get_source())
 # f(a, b, c)
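The commented-out tail hints at how the built kernel would be exercised. A minimal sketch of that step, assuming the non-tuning branch above ran so that f, a, b, c, c_np, and dev exist (the timing call is the standard tvm.runtime.Module.time_evaluator idiom, not something this commit adds):

f(a, b, c)  # run the generated CUDA kernel
np.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)  # check against the numpy reference
evaluator = f.time_evaluator(f.entry_name, dev, number=100)
print("kernel time: %f ms" % (evaluator(a, b, c).mean * 1e3))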