@@ -22,7 +22,9 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
                 v0, v1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(A_shared[v0, v1])
                 T.writes(A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4])
-                A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4] = A_shared[v0, v1]
+                A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4] = A_shared[
+                    v0, v1
+                ]
 
 
 @T.prim_func
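
The only change in the hunk above is a black-style reflow of the store; the index map itself is untouched. For reference, that map scatters a 16x32 int8 A tile onto 32 threads holding 16 bytes each, and it can be sanity-checked in plain Python (a minimal sketch, no TVM needed; the bounds come from the 16x32 A fragment):

seen = set()
for v0 in range(16):          # rows of the A tile
    for v1 in range(32):      # int8 columns of the A tile
        thread_id = v0 % 8 * 4 + v1 % 16 // 4
        local_id = v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4
        seen.add((thread_id, local_id))
assert len(seen) == 32 * 16   # every (thread, byte) slot is hit exactly once
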
@@ -122,9 +124,19 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
         for i, j, k in T.grid(16, 16, 32):
             with T.block("C"):
                 i, j, k = T.axis.remap("SSR", [i, j, k])
-                T.reads(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2], A[i % 8 * 4 + k % 16 // 4, k // 16 * 8 + i // 8 * 4 + k % 4], B[j % 8 * 4 + k % 4, j // 8 * 8 + k // 16 * 4 + k % 4])
-                T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
-                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] + T.cast(A[i % 8 * 4 + k % 16 // 4, k // 16 * 8 + i // 8 * 4 + k % 4], "int32") * T.cast(B[j % 8 * 4 + k % 4, j // 8 * 8 + k // 16 * 4 + k % 4], "int32")
+                T.reads(
+                    C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2],
+                    A[i % 8 * 4 + k % 16 // 4, k % 32 // 16 * 8 + i % 16 // 8 * 4 + k % 4],
+                    B[j % 8 * 4 + k % 4, j % 16 // 8 * 8 + k % 32 // 16 * 4 + k % 4],
+                )
+                T.writes(C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2])
+                C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2] = C[
+                    i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2
+                ] + T.cast(
+                    A[i % 8 * 4 + k % 16 // 4, k % 32 // 16 * 8 + i % 16 // 8 * 4 + k % 4], "int32"
+                ) * T.cast(
+                    B[j % 8 * 4 + k % 4, j % 16 // 8 * 8 + k % 32 // 16 * 4 + k % 4], "int32"
+                )
 
 
 @T.prim_func
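
The rewritten accesses above add "% 16" / "% 32" wrappers (plus black-style line breaks); on this block's iteration domain, T.grid(16, 16, 32), the wrapped and unwrapped forms address the same elements. A quick plain-Python check of the equivalences involved (illustrative only):

for i in range(16):
    assert i // 8 == i % 16 // 8
for j in range(16):
    assert j // 8 == j % 16 // 8 and j % 2 == j % 8 % 2
for k in range(32):
    assert k // 16 == k % 32 // 16
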
@@ -266,6 +278,7 @@ def mma_fill_impl(a: T.handle) -> None:
 M = 4096
 K = 4096
 
+
 def matmul_int8(n, m, k):
     a = te.placeholder((n, k), name="A", dtype="int8")
     b = te.placeholder((k, m), name="B", dtype="int8")
@@ -289,8 +302,8 @@ def schedule(sch: tir.Schedule):
     block = sch.get_block("C")
     i, j, k = sch.get_loops(block)
     i, i_tc = sch.split(i, factors=[None, 16])
-    j, j_tc = sch.split(j, factors=[None, 32])
-    k, k_tc = sch.split(k, factors=[None, 16])
+    j, j_tc = sch.split(j, factors=[None, 16])
+    k, k_tc = sch.split(k, factors=[None, 32])
 
     sch.reorder(
         i,
@@ -311,8 +324,8 @@ def schedule(sch: tir.Schedule):
         num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
     else:
         i_factors = [4, 8, 2, 4, 1]
-        j_factors = [1, 32, 2, 1, 2]
-        k_factors = [128, 2, 1]
+        j_factors = [1, 64, 2, 1, 2]
+        k_factors = [64, 2, 1]
 
         num_ty = i_factors[2] * j_factors[2]
 
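
Together with the earlier split change (j_tc = 16, k_tc = 32), the factors above rebalance the outer tiling for the 16x16x32 fragment shape used by the int8 MMA intrinsics while still covering the full problem. A small arithmetic check, assuming N = M = K = 4096 as set earlier in the script:

from math import prod

i_factors = [4, 8, 2, 4, 1]   # 256 tiles of 16 rows    -> 4096
j_factors = [1, 64, 2, 1, 2]  # 256 tiles of 16 columns -> 4096
k_factors = [64, 2, 1]        # 128 tiles of 32 in K    -> 4096
assert prod(i_factors) * 16 == 4096
assert prod(j_factors) * 16 == 4096
assert prod(k_factors) * 32 == 4096
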
@@ -381,13 +394,10 @@ def fetch_to_shared(block, idx, ndim):
 
     block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
 
-    def tile_wmma_fragment(block_read, height, is_b=False):
+    def tile_wmma_fragment(block_read, height, width):
         i, j = sch.get_loops(block_read)[-2:]
         i0, i1 = sch.split(i, factors=[None, height])
-        if is_b:
-            j0, j1 = sch.split(j, factors=[32, None])
-        else:
-            j0, j1 = sch.split(j, factors=[None, 32])
+        j0, j1 = sch.split(j, factors=[None, width])
         sch.reorder(i0, j0, i1, j1)
         return i1
 
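
With an explicit width parameter the helper no longer needs the is_b special case: both warp-fragment copies get the same split-and-reorder tiling, only with different tile shapes (the call sites further down in this diff pass 16 x 32 for A and 32 x 16 for B). A pure-Python sketch of the loop structure this split + reorder produces (names are illustrative):

def tiled_indices(rows, cols, height, width):
    # (i, j) -> (i0, j0, i1, j1): visit the buffer tile by tile, then within each tile
    for i0 in range(rows // height):
        for j0 in range(cols // width):
            for i1 in range(height):
                for j1 in range(width):
                    yield i0 * height + i1, j0 * width + j1

# Covers every element of a 64x64 buffer exactly once, in 16x32-tile order.
assert sorted(tiled_indices(64, 64, 16, 32)) == [(i, j) for i in range(64) for j in range(64)]
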
@@ -411,7 +421,6 @@ def shared_16x32_to_ldmatrix_32x16_layout(i, j):
         thread_id = 4 * (i % 8) + (j % 16) // 4
         return i_0, j_0, thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4
 
-
     def shared_32x16_to_ldmatrix_32x16_layout(i, j):
         i_0 = i // 32
         j_0 = j // 16
@@ -422,8 +431,8 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
         thread_id = (i % 4) + 4 * (j % 8)
         return i_0, j_0, thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
 
-    loop_a = tile_wmma_fragment(A_warp, 16)
-    loop_b = tile_wmma_fragment(B_warp, 16, True)
+    loop_a = tile_wmma_fragment(A_warp, 16, 32)
+    loop_b = tile_wmma_fragment(B_warp, 32, 16)
 
     sch.transform_layout(A_warp, 0, "write", index_map=shared_16x32_to_ldmatrix_32x16_layout)
     sch.transform_layout(B_warp, 0, "write", index_map=shared_32x16_to_ldmatrix_32x16_layout)
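
After the two transform_layout calls, the warp-scoped A and B buffers are physically laid out as the (thread_id, local_id) pairs returned by the index maps, i.e. 32 threads each holding 16 int8 values, which is the form the ldmatrix/mma.sync intrinsics consume. As an illustration of what the A-side map means per thread (plain Python, not the TVM API):

def elements_of_thread(t):
    # (row, column) elements of one 16x32 A tile that land in thread t's registers
    return [
        (i, j)
        for i in range(16)
        for j in range(32)
        if 4 * (i % 8) + (j % 16) // 4 == t
    ]

print(elements_of_thread(0))   # 16 elements: rows {0, 8}, columns {0-3, 16-19}
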
@@ -460,44 +469,44 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
 schedule(sch)
 print(sch.mod.script())
 
-# if tune:
-#     with tempfile.TemporaryDirectory() as work_dir:
-#         sch = ms.tune_tir(
-#             mod=workload,
-#             target=tvm.target.Target("nvidia/geforce-rtx-3070"),
-#             config=ms.TuneConfig(
-#                 strategy="evolutionary",
-#                 num_trials_per_iter=32,
-#                 max_trials_per_task=128,
-#                 max_trials_global=128,
-#             ),
-#             work_dir=work_dir,
-#             space=ms.space_generator.ScheduleFn(schedule),
-#         )
-#         if sch is None:
-#             print("No valid schedule found!")
-#         else:
-#             print(sch.mod.script())
-#             print(sch.trace)
-# else:
-#     target = "cuda"
-#     f = tvm.build(sch.mod["main"], target=target, name="dense")
-
-#     dev = tvm.device("cuda", 0)
-#     a_np = np.random.uniform(size=(N, K)).astype("int8")
-#     b_np = np.random.uniform(size=(K, M)).astype("int8")
-#     c_np = np.dot(a_np.astype("int32"), b_np.astype("int32"))
-#     a = tvm.nd.array(a_np, dev)
-#     b = tvm.nd.array(b_np, dev)
-#     c = tvm.nd.array(np.zeros((M, N), dtype="int32"), dev)
-#     f = tvm.build(sch.mod["main"], target="cuda", name="dense")
-
-#     print(f.imported_modules[0].get_source())
-#     f(a, b, c)
-#     tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
-#     print("ok")
-
-#     evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
-#     gflops = (N * M * K) * 2 / 1e9
-#     time_ms = evaluator(a, b, c).mean * 1e3
-#     print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
+if tune:
+    with tempfile.TemporaryDirectory() as work_dir:
+        sch = ms.tune_tir(
+            mod=workload,
+            target=tvm.target.Target("nvidia/geforce-rtx-3070"),
+            config=ms.TuneConfig(
+                strategy="evolutionary",
+                num_trials_per_iter=32,
+                max_trials_per_task=128,
+                max_trials_global=128,
+            ),
+            work_dir=work_dir,
+            space=ms.space_generator.ScheduleFn(schedule),
+        )
+        if sch is None:
+            print("No valid schedule found!")
+        else:
+            print(sch.mod.script())
+            print(sch.trace)
+else:
+    target = "cuda"
+    f = tvm.build(sch.mod["main"], target=target, name="dense")
+
+    dev = tvm.device("cuda", 0)
+    a_np = np.random.uniform(size=(N, K)).astype("int8")
+    b_np = np.random.uniform(size=(K, M)).astype("int8")
+    c_np = np.dot(a_np.astype("int32"), b_np.astype("int32"))
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(b_np, dev)
+    c = tvm.nd.array(np.zeros((M, N), dtype="int32"), dev)
+    f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+
+    print(f.imported_modules[0].get_source())
+    f(a, b, c)
+    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+    print("ok")
+
+    evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
+    gflops = (N * M * K) * 2 / 1e9
+    time_ms = evaluator(a, b, c).mean * 1e3
+    print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
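
The last hunk enables the previously commented-out tuning / build-and-benchmark path. One note on the final print: with N = M = K = 4096 the op count 2 * N * M * K is about 137.4 billion, and since the kernel is int8 the reported "GFLOPS" figure is really integer giga-ops per second. A quick check of the arithmetic (the 1 ms runtime below is hypothetical, for illustration only):

ops_giga = 2 * 4096 ** 3 / 1e9     # ~137.44 giga-ops (multiply-add counted as two ops)
time_ms = 1.0                      # hypothetical runtime
print(ops_giga / (time_ms / 1e3))  # -> ~137439 "GFLOPS" (integer giga-ops per second)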