@@ -277,17 +277,27 @@ def expected():
277277
278278@tvm .testing .uses_gpu
279279def test_legalize_batch_matmul ():
280- def _test_legalize_batch_matmul (data_shape , kernel_shape , pad_shape , dtype , do_pad = True ):
280+ def _test_legalize_batch_matmul (
281+ data_shape , kernel_shape , pad_shape , dtype , do_pad = True , transpose_a = False , transpose_b = True
282+ ):
281283 """test legalize dense to enable tensorcore"""
282- B , M , _ = data_shape
283- _ , N , _ = kernel_shape
284+ if transpose_a :
285+ B , _ , M = data_shape
286+ else :
287+ B , M , _ = data_shape
288+
289+ if transpose_b :
290+ _ , N , _ = kernel_shape
291+ else :
292+ _ , _ , N = kernel_shape
293+
284294 out_shape = (B , M , N )
285295 dm , dk , dn = pad_shape
286296
287297 def before ():
288298 x = relay .var ("x" , shape = data_shape , dtype = dtype )
289299 weight = relay .var ("weight" , shape = kernel_shape , dtype = dtype )
290- y = relay .nn .batch_matmul (x , weight )
300+ y = relay .nn .batch_matmul (x , weight , transpose_a = transpose_a , transpose_b = transpose_b )
291301 y = relay .Function ([x , weight ], y )
292302 return y
293303
@@ -298,19 +308,31 @@ def legalize_batch_matmul(attrs, inputs, types):
298308 def expected ():
299309 if not do_pad :
300310 return before ()
311+
301312 x = relay .var ("x" , shape = data_shape , dtype = dtype )
313+ weight = relay .var ("weight" , shape = (kernel_shape ), dtype = dtype )
314+
302315 if dm or dk :
303- x_pad = relay .nn .pad (x , pad_width = ((0 , 0 ), (0 , dm ), (0 , dk )))
316+ if transpose_a :
317+ x_pad = relay .nn .pad (x , pad_width = ((0 , 0 ), (0 , dk ), (0 , dm )))
318+ else :
319+ x_pad = relay .nn .pad (x , pad_width = ((0 , 0 ), (0 , dm ), (0 , dk )))
304320 else :
305321 x_pad = x
306- weight = relay . var ( "weight" , shape = ( kernel_shape ), dtype = dtype )
322+
307323 if dn or dk :
308- weight_pad = relay .nn .pad (weight , pad_width = ((0 , 0 ), (0 , dn ), (0 , dk )))
324+ if transpose_b :
325+ weight_pad = relay .nn .pad (weight , pad_width = ((0 , 0 ), (0 , dn ), (0 , dk )))
326+ else :
327+ weight_pad = relay .nn .pad (weight , pad_width = ((0 , 0 ), (0 , dk ), (0 , dn )))
309328 else :
310329 weight_pad = weight
330+
311331 y_pad = relay .nn .batch_matmul (
312332 x_pad ,
313333 weight_pad ,
334+ transpose_a = transpose_a ,
335+ transpose_b = transpose_b ,
314336 )
315337 if dm or dn :
316338 y = relay .strided_slice (y_pad , begin = [0 , 0 , 0 ], end = out_shape )
@@ -343,6 +365,13 @@ def expected():
343365 _test_legalize_batch_matmul ((16 , 8 , 16 ), (16 , 32 , 16 ), (0 , 16 , 0 ), "int4" )
344366 _test_legalize_batch_matmul ((16 , 2 , 16 ), (16 , 32 , 16 ), (0 , 0 , 0 ), "int4" , False )
345367
368+ _test_legalize_batch_matmul (
369+ (16 , 8 , 16 ), (16 , 16 , 32 ), (0 , 0 , 0 ), "float16" , False , transpose_b = False
370+ )
371+ _test_legalize_batch_matmul (
372+ (16 , 16 , 8 ), (16 , 32 , 16 ), (0 , 0 , 0 ), "float16" , False , transpose_a = True
373+ )
374+
346375
347376if __name__ == "__main__" :
348377 test_legalize_conv2d_NHWC ()
0 commit comments