@@ -68,7 +68,7 @@ def maybe_swap(i, j):
         return i, j
 
     c = te.compute(
-        (n, m),
+        (m, n),
         lambda i, j: te.sum(maybe_cast(a[i, k]) * maybe_cast(b[maybe_swap(k, j)]), axis=[k]),
         name="C",
     )
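The shape fix above makes the compute output (m, n), matching the np.dot reference used later in the test; maybe_swap flips B's indices when B is stored as (N, K). Below is a minimal NumPy-only sketch of the same indexing logic, written here purely for illustration (the reference_matmul helper is ours, not part of the commit):

import numpy as np

def reference_matmul(a, b, b_transposed):
    # C[i, j] = sum_k A[i, k] * B[swap(k, j)], where swap(k, j) -> (j, k)
    # when B is stored as (N, K); the output shape is (m, n), as in the fixed compute.
    m = a.shape[0]
    n = b.shape[0] if b_transposed else b.shape[1]
    c = np.zeros((m, n), dtype="float32")
    for i in range(m):
        for j in range(n):
            b_col = b[j, :] if b_transposed else b[:, j]
            c[i, j] = np.dot(a[i, :].astype("float32"), b_col.astype("float32"))
    return c

a = np.random.uniform(size=(4, 8)).astype("float16")
b = np.random.uniform(size=(16, 8)).astype("float16")  # stored as (N, K)
expected = np.dot(a.astype("float32"), b.astype("float32").transpose())
assert np.allclose(reference_matmul(a, b, True), expected, rtol=1e-3)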
@@ -132,7 +132,8 @@ def fetch_to_shared(block, idx, ndim):
         sch.bind(f_2, "threadIdx.x")
         sch.bind(f_1, "threadIdx.y")
         sch.vectorize(f_3)
-        sch.storage_align(block_read, 0, axis=-2, factor=32, offset=8)
+        offset = 8 if in_dtype == "float16" else 16
+        sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset)
 
         return block_read
 
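The storage_align offset is counted in elements, so making it dtype-dependent keeps the shared-memory row padding the same number of bytes: 8 fp16 elements and 16 int8 elements are both 16 bytes. A quick sanity check of that arithmetic, assuming the usual stride % factor == offset semantics of storage_align (this snippet is illustrative and not from the commit):

# storage_align(block, 0, axis=-2, factor=32, offset=o) pads the second-to-last
# axis so its stride (in elements) satisfies stride % 32 == o, which skews rows
# in shared memory and helps avoid bank conflicts when warps load MMA fragments.
for dtype, bytes_per_elem in [("float16", 2), ("int8", 1)]:
    offset = 8 if dtype == "float16" else 16
    print(dtype, "->", offset, "elements =", offset * bytes_per_elem, "bytes of padding")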
@@ -180,36 +181,42 @@ def tile_wmma_fragment(block_read, height, width):
     sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin)
     sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin)
 
-    # print(sch.mod.script())
-
     f = tvm.build(sch.mod["main"], target="cuda", name="dense")
     dev = tvm.device("cuda", 0)
 
     if in_dtype == "float16":
        a_np = np.random.uniform(size=(M, K)).astype("float16")
 
        if b_transposed:
-            b_np = np.random.uniform(size=(N, K)).astype("float16").transpose()
+            b_np = np.random.uniform(size=(N, K)).astype("float16")
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype(
+                out_dtype
+            )
        else:
            b_np = np.random.uniform(size=(K, N)).astype("float16")
-
-        c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype)
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype)
     else:
        a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
 
        if b_transposed:
-            b_np = np.random.randint(-128, 128, (N, K)).astype("int8").transpose()
+            b_np = np.random.randint(-128, 128, (N, K)).astype("int8")
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype(
+                "int32"
+            )
        else:
            b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
-
-        c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32")
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32")
 
     a = tvm.nd.array(a_np, dev)
     b = tvm.nd.array(b_np, dev)
     c = tvm.nd.array(np.zeros((M, N), dtype=out_dtype), dev)
 
     f(a, b, c)
-    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+
+    if out_dtype != "float16":
+        # The numpy reference is computed at fp32 precision (otherwise it is too slow),
+        # so there is a non-trivial accuracy difference when the TVM result uses fp16 accumulation.
+        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
 
     return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c)
 
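The new out_dtype != "float16" guard exists because the NumPy reference accumulates in fp32, while an fp16-accumulating tensor-core kernel does not, and over a long reduction axis the two can diverge well beyond rtol=1e-3. A standalone sketch of that effect (the sizes and the fp16-accumulation loop here are ours, chosen only to illustrate the gap):

import numpy as np

M = N = 16
K = 4096
a = np.random.uniform(size=(M, K)).astype("float16")
b = np.random.uniform(size=(K, N)).astype("float16")

# fp32-accumulated reference, analogous to c_np above.
c_ref = np.dot(a.astype("float32"), b.astype("float32"))

# Simulate fp16 accumulation: the running sum stays in fp16 at every step.
c_fp16 = np.zeros((M, N), dtype="float16")
for k in range(K):
    c_fp16 = c_fp16 + np.outer(a[:, k], b[k, :])  # fp16 + fp16 stays fp16

rel_err = np.max(np.abs(c_fp16.astype("float32") - c_ref) / np.abs(c_ref))
print(rel_err)  # typically far above 1e-3, hence the skipped assert for fp16 output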
@@ -372,7 +379,7 @@ def index_map_C(i, j):
     )
 
     if measure_perf:
-        print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))
+        print("i8i8i32_m16n16k32: %f GOPS" % (gflops / (timer().mean)))
 
     timer = run_test(
         k_inner,
@@ -393,7 +400,7 @@ def index_map_C(i, j):
     )
 
     if measure_perf:
-        print("f16f16f32_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean)))
+        print("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean)))
 
 
 if __name__ == "__main__":
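The last two label fixes are cosmetic: the second and third configurations run int8 GEMMs, so the printed throughput counts integer multiply-accumulate operations (GOPS) rather than floating-point ones. The number behind either label is the standard 2*M*N*K operation count divided by the measured time; a small sketch of that calculation (the gemm_gops helper and the example figures are ours, not from the file):

def gemm_gops(M, N, K, seconds):
    # A GEMM performs M*N*K multiply-accumulates, i.e. 2*M*N*K scalar ops.
    # Scaled by the runtime and 1e9, this is GFLOPS for float inputs and
    # GOPS for the int8 paths, hence the renamed print labels.
    return 2 * M * N * K / seconds / 1e9

print(gemm_gops(4096, 4096, 4096, 1.5e-3))  # ~9.2e4 GOPS, about 92 TOPS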