@@ -53,8 +53,8 @@ def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
             4,
             ".b16",
             A_warp.data,
-            16 * tx,
-            A_shared.data,
+            A_warp.elem_offset + 16 * tx,
+            A_shared.access_ptr("r"),
             s1 * (tx % 16) + 16 * (tx // 16),
             dtype="int8",
         )
@@ -104,8 +104,8 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
             4,
             ".b16",
             B_warp.data,
-            16 * tx,
-            B_shared.data,
+            B_warp.elem_offset + 16 * tx,
+            B_shared.access_ptr("r"),
             s1,
             dtype="int8",
         )
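
The two hunks above change the `ptx_ldmatrix` calls in the load intrinsics: the destination offset now includes the warp buffer's own `elem_offset`, and the shared-memory source is passed as `access_ptr("r")` instead of the raw `data` handle. As a point of reference, the patched A-side call plausibly reads as below; the `T.evaluate` wrapper, the transpose flag, the `tx` thread index, and the buffer declarations are assumed from the unchanged parts of the file and are not shown in this diff.

```python
# Sketch of the patched A-side load. Surrounding context (tx binding,
# A_warp/A_shared buffer matches, stride s1) is assumed, not part of the diff.
T.evaluate(
    T.ptx_ldmatrix(
        False,                              # assumed: no transpose of the 8x8 fragments
        4,                                  # load four .b16 matrices per call
        ".b16",
        A_warp.data,                        # destination: warp-level register buffer
        A_warp.elem_offset + 16 * tx,       # now accounts for the buffer's elem_offset
        A_shared.access_ptr("r"),           # read-access pointer into shared memory
        s1 * (tx % 16) + 16 * (tx // 16),   # per-thread source offset (row stride s1)
        dtype="int8",
    )
)
```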
@@ -359,7 +359,7 @@ def schedule(sch: tir.Schedule):
     sch.bind(block_idy, "blockIdx.y")
     sch.bind(thread_idy, "threadIdx.y")
 
-    def fetch_to_shared(block, idx, ndim):
+    def fetch_to_shared(block, idx, ndim, vec=False):
         block_read = sch.cache_read(block, idx, "shared")
         sch.compute_at(block_read, k0)
         vector_size = 16
@@ -368,13 +368,15 @@ def fetch_to_shared(block, idx, ndim):
         f_0, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size])
         sch.bind(f_2, "threadIdx.x")
         sch.bind(f_1, "threadIdx.y")
-        sch.vectorize(f_3)
-        sch.storage_align(block_read, 0, axis=-2, factor=32, offset=16)
+
+        if vec:
+            sch.vectorize(f_3)
+            sch.storage_align(block_read, 0, axis=-2, factor=32, offset=16)
 
         return block_read
 
-    A_sh = fetch_to_shared(block_outer, 0, 2)
-    B_sh = fetch_to_shared(block_outer, 1, 2)
+    A_sh = fetch_to_shared(block_outer, 0, 2, True)
+    B_sh = fetch_to_shared(block_outer, 1, 2, True)
 
     loop = sch.get_loops(block_outer)[-1]
 
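Assembled from the two scheduling hunks above, the patched helper reads roughly as follows. The `sch.fuse(...)` line is an assumption (only the resulting `fused` variable appears in the diff context); `num_ty`, `warp_size`, and `k0` come from unchanged parts of `schedule()`. Passing `vec=True` for both operands keeps the previous behavior, while the new flag makes it possible to skip vectorization and the shared-memory padding for other reads.

```python
# Rough reconstruction of the patched helper, for readability only.
def fetch_to_shared(block, idx, ndim, vec=False):
    block_read = sch.cache_read(block, idx, "shared")        # stage operand `idx` into shared memory
    sch.compute_at(block_read, k0)                           # perform the copy inside the k0 loop
    vector_size = 16
    fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])     # assumed: fuse the innermost ndim loops
    f_0, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size])
    sch.bind(f_2, "threadIdx.x")
    sch.bind(f_1, "threadIdx.y")

    if vec:
        sch.vectorize(f_3)                                   # 16-wide vectorized copy
        sch.storage_align(block_read, 0, axis=-2, factor=32, offset=16)  # pad rows in shared memory

    return block_read

A_sh = fetch_to_shared(block_outer, 0, 2, True)
B_sh = fetch_to_shared(block_outer, 1, 2, True)
```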
@@ -488,14 +490,11 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
     else:
         print(sch.mod.script())
         print(sch.trace)
-else:
-    target = "cuda"
-    f = tvm.build(sch.mod["main"], target=target, name="dense")
 
 dev = tvm.device("cuda", 0)
-a_np = np.random.uniform(size=(N, K)).astype("int8")
-b_np = np.random.uniform(size=(K, M)).astype("int8")
-c_np = np.dot(a_np.astype("int32"), b_np.astype("int32"))
+a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
+b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
+c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32")
 a = tvm.nd.array(a_np, dev)
 b = tvm.nd.array(b_np, dev)
 c = tvm.nd.array(np.zeros((M, N), dtype="int32"), dev)
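
A note on the data-generation change above (an aside, not part of the commit): `np.random.uniform` draws floats in [0, 1), so casting them to `"int8"` truncates nearly every element to 0 and the correctness check becomes vacuous, whereas `np.random.randint(-128, 128, ...)` covers the full int8 range. The shapes are also corrected to A: (M, K) and B: (K, N) to match the (M, N) output, and the NumPy reference now accumulates in a wider type before casting to int32.

```python
# Quick demonstration of why uniform-then-cast was a problem.
import numpy as np

bad = np.random.uniform(size=(4, 4)).astype("int8")          # floats in [0, 1) truncate to 0
good = np.random.randint(-128, 128, (4, 4)).astype("int8")   # genuine int8 values

print(bad)                      # all zeros
print(good.min(), good.max())   # spans the int8 range
```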
@@ -506,7 +505,7 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
 tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
 print("ok")
 
-evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
+evaluator = f.time_evaluator(f.entry_name, dev, number=500)
 gflops = (N * M * K) * 2 / 1e9
 time_ms = evaluator(a, b, c).mean * 1e3
 print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
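
For context on the last hunk: `number=500` only changes how many runs `time_evaluator` averages over, and the reported figure counts 2·M·N·K operations per matmul (one multiply and one add per inner-product term); since the kernel is int8/int32, the printed "GFLOPS" is really integer ops per second. A small worked example with hypothetical sizes (M = N = K = 4096 is an assumption; the real values are defined earlier in the script and are not part of this diff):

```python
# Hypothetical numbers, only to show how the printed figure is derived.
M = N = K = 4096                      # assumed problem size, not from this diff
gflops = (N * M * K) * 2 / 1e9        # ~137.4 G ops per matmul
time_ms = 1.0                         # pretend the evaluator reported 1 ms
print(gflops / (time_ms / 1e3))       # ~137439 "GFLOPS" (int8 ops/s here)
```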