
Commit 5b028dc

masahi and Mikael Sevenier (authored and committed)
[TOPI] Fix batch_matmul tensorcore legalize for transpose_b = False case (apache#13618)
* fixed tensor core batch_matmul legalize for transpose_b = False case
* add test
* clean up
1 parent f53dc45 commit 5b028dc
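
Background for the fix: relay.nn.batch_matmul defaults to transpose_b = True, i.e. the first operand is laid out as (B, M, K) and the second as (B, N, K). The legalize pass changed below used to hard-code that layout when unpacking shapes and building pad widths, which broke the transpose_b = False (and transpose_a = True) layouts. The snippet below is a minimal NumPy reference of the layout convention, added here only for illustration; batch_matmul_ref is a made-up helper name, not part of TVM.

import numpy as np

def batch_matmul_ref(x, y, transpose_a=False, transpose_b=True):
    """NumPy reference for relay.nn.batch_matmul's layout convention.

    x is (B, K, M) when transpose_a else (B, M, K);
    y is (B, N, K) when transpose_b else (B, K, N);
    the output is always (B, M, N).
    """
    a = np.swapaxes(x, 1, 2) if transpose_a else x  # normalize to (B, M, K)
    b = np.swapaxes(y, 1, 2) if transpose_b else y  # normalize to (B, K, N)
    return a @ b  # (B, M, N)

# Shapes matching the two new test cases added in this commit.
x = np.ones((16, 8, 16), "float16")
y = np.ones((16, 16, 32), "float16")
assert batch_matmul_ref(x, y, transpose_b=False).shape == (16, 8, 32)

xt = np.ones((16, 16, 8), "float16")
yt = np.ones((16, 32, 16), "float16")
assert batch_matmul_ref(xt, yt, transpose_a=True).shape == (16, 8, 32)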

File tree: 2 files changed, +63 -12 lines changed


python/tvm/topi/cuda/tensorcore_alter_op.py

Lines changed: 27 additions & 5 deletions
@@ -48,14 +48,22 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):
     x_tensor, y_tensor = arg_types[0], arg_types[1]
     dtype = x_tensor.dtype
 
+    if attrs.transpose_a:
+        B, K, M = x_tensor.shape
+    else:
+        B, M, K = x_tensor.shape
+
+    if attrs.transpose_b:
+        B, N, K = y_tensor.shape
+    else:
+        B, K, N = y_tensor.shape
+
     # Collect the output tensor.
     output_tensor = arg_types[2]
 
     # Collect the input exprs.
     x, y = inputs
 
-    B, M, K = x_tensor.shape
-    B, N, K = y_tensor.shape
     if (
         isinstance(B, tir.expr.Any)
         or isinstance(M, tir.expr.Any)
@@ -96,9 +104,23 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):
         return None
 
     logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops)
-    x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) if dm or dk else x
-    y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) if dn or dk else y
-    out_ = relay.nn.batch_matmul(x_, y_, attrs.out_dtype)
+
+    if attrs.transpose_a:
+        pad_width = ((0, 0), (0, dk), (0, dm))
+    else:
+        pad_width = ((0, 0), (0, dm), (0, dk))
+
+    x_ = relay.nn.pad(x, pad_width=pad_width) if dm or dk else x
+
+    if attrs.transpose_b:
+        pad_width = ((0, 0), (0, dn), (0, dk))
+    else:
+        pad_width = ((0, 0), (0, dk), (0, dn))
+
+    y_ = relay.nn.pad(y, pad_width=pad_width) if dn or dk else y
+
+    out_ = relay.nn.batch_matmul(x_, y_, **attrs)
+
     out = (
         relay.strided_slice(out_, begin=[0, 0, 0], end=[x.value for x in output_tensor.shape])
         if dm or dn
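
This hunk makes the padding layout-aware: when transpose_a is set, x is (B, K, M), so the (dm, dk) padding swaps positions, and when transpose_b is False, y is (B, K, N), so it is padded as (dk, dn). The legalized call also forwards **attrs, so the transpose flags and out_dtype are preserved on the padded batch_matmul instead of only out_dtype being passed. The sketch below mirrors that pad-then-slice pattern for the transpose_b = False case; the concrete shapes and pad amounts (dm, dk, dn) are made-up illustration values, not ones computed by the pass.

from tvm import relay

B, M, K, N = 16, 10, 20, 30   # hypothetical unpadded dimensions
dm, dk, dn = 6, 12, 2         # hypothetical padding amounts

x = relay.var("x", shape=(B, M, K), dtype="float16")  # transpose_a=False -> (B, M, K)
y = relay.var("y", shape=(B, K, N), dtype="float16")  # transpose_b=False -> (B, K, N)

# Pad (M, K) on x and (K, N) on y; the axis order follows the layouts above.
x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dk), (0, dn)))

# Keep the transpose attributes on the padded op.
out_ = relay.nn.batch_matmul(x_, y_, transpose_a=False, transpose_b=False)

# Slice the padded result back to the original output shape (B, M, N).
out = relay.strided_slice(out_, begin=[0, 0, 0], end=[B, M, N])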

tests/python/relay/test_pass_legalize_tensorcore.py

Lines changed: 36 additions & 7 deletions
@@ -277,17 +277,27 @@ def expected():
 
 @tvm.testing.uses_gpu
 def test_legalize_batch_matmul():
-    def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, dtype, do_pad=True):
+    def _test_legalize_batch_matmul(
+        data_shape, kernel_shape, pad_shape, dtype, do_pad=True, transpose_a=False, transpose_b=True
+    ):
         """test legalize dense to enable tensorcore"""
-        B, M, _ = data_shape
-        _, N, _ = kernel_shape
+        if transpose_a:
+            B, _, M = data_shape
+        else:
+            B, M, _ = data_shape
+
+        if transpose_b:
+            _, N, _ = kernel_shape
+        else:
+            _, _, N = kernel_shape
+
         out_shape = (B, M, N)
         dm, dk, dn = pad_shape
 
         def before():
             x = relay.var("x", shape=data_shape, dtype=dtype)
             weight = relay.var("weight", shape=kernel_shape, dtype=dtype)
-            y = relay.nn.batch_matmul(x, weight)
+            y = relay.nn.batch_matmul(x, weight, transpose_a=transpose_a, transpose_b=transpose_b)
             y = relay.Function([x, weight], y)
             return y
 
@@ -298,19 +308,31 @@ def legalize_batch_matmul(attrs, inputs, types):
         def expected():
             if not do_pad:
                 return before()
+
             x = relay.var("x", shape=data_shape, dtype=dtype)
+            weight = relay.var("weight", shape=(kernel_shape), dtype=dtype)
+
             if dm or dk:
-                x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
+                if transpose_a:
+                    x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dk), (0, dm)))
+                else:
+                    x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
             else:
                 x_pad = x
-            weight = relay.var("weight", shape=(kernel_shape), dtype=dtype)
+
             if dn or dk:
-                weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk)))
+                if transpose_b:
+                    weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk)))
+                else:
+                    weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dk), (0, dn)))
             else:
                 weight_pad = weight
+
             y_pad = relay.nn.batch_matmul(
                 x_pad,
                 weight_pad,
+                transpose_a=transpose_a,
+                transpose_b=transpose_b,
             )
             if dm or dn:
                 y = relay.strided_slice(y_pad, begin=[0, 0, 0], end=out_shape)
@@ -343,6 +365,13 @@ def expected():
     _test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 16, 0), "int4")
     _test_legalize_batch_matmul((16, 2, 16), (16, 32, 16), (0, 0, 0), "int4", False)
 
+    _test_legalize_batch_matmul(
+        (16, 8, 16), (16, 16, 32), (0, 0, 0), "float16", False, transpose_b=False
+    )
+    _test_legalize_batch_matmul(
+        (16, 16, 8), (16, 32, 16), (0, 0, 0), "float16", False, transpose_a=True
+    )
+
 
 if __name__ == "__main__":
     test_legalize_conv2d_NHWC()
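
The two added cases use layouts where no padding is required (pad_shape (0, 0, 0) with do_pad=False): (16, 8, 16) x (16, 16, 32) with transpose_b=False and (16, 16, 8) x (16, 32, 16) with transpose_a=True, both producing a (16, 8, 32) output. Assuming a local TVM build with pytest installed, something like the following should run just this test (a suggested invocation, not the project's CI entry point):

import pytest

# Run only the batch_matmul legalization test from this file.
pytest.main(["tests/python/relay/test_pass_legalize_tensorcore.py", "-k", "test_legalize_batch_matmul"])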
