[Frontend,TOPI] Improve dynamism for BatchMatmul and Dense (apache#7496)
* [TOPI] Dense cuda schedule support dynamic dimension

* [TOPI] batch_matmul cublas te computation support dynamism

* [Frontend] tensorflow frontend: dynamic support for BatchMatmul

* [TOPI] nn batch_matmul te computation support dynamism

* fix CI

* Update python/tvm/topi/nn/batch_matmul.py

Co-authored-by: Cody Yu <[email protected]>

* Update python/tvm/topi/cuda/batch_matmul.py

Co-authored-by: Cody Yu <[email protected]>

* remove concat_dynamic_shape function

* update topi dense op integer checking

* fix ci

* Update python/tvm/relay/frontend/tensorflow.py

Co-authored-by: Cody Yu <[email protected]>

* Update batch_matmul.py

* [Frontend] add test for batch_matmul in dynamic shaped case

Co-authored-by: Cody Yu <[email protected]>
2 people authored and trevor-m committed May 11, 2021
1 parent f37b7ab commit 8fd9c44
Showing 4 changed files with 104 additions and 20 deletions.
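
As context for the file-level diffs below, here is a minimal sketch (an illustration, not part of this commit) of the kind of workload the change enables: nn.batch_matmul with a batch dimension that is unknown until runtime, compiled through the Relay VM and offloaded to cuBLAS. It assumes a CUDA build of TVM with cuBLAS enabled and the mid-2021 Relay API (tvm.gpu, ctx=); shapes and names are illustrative.

import numpy as np
import tvm
from tvm import relay

# relay.Any() marks the dynamic batch dimension; y follows batch_matmul's
# (batch, N, K) convention, i.e. it is already laid out transposed.
x = relay.var("x", shape=(relay.Any(), 5, 4), dtype="float32")
y = relay.var("y", shape=(relay.Any(), 5, 4), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.nn.batch_matmul(x, y)))

a = np.random.uniform(size=(3, 5, 4)).astype("float32")
b = np.random.uniform(size=(3, 5, 4)).astype("float32")
ex = relay.create_executor("vm", mod=mod, ctx=tvm.gpu(0), target="cuda -libs=cublas")
print(ex.evaluate()(a, b).shape)  # (3, 5, 5); batch=3 is supplied only at run time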
54 changes: 45 additions & 9 deletions python/tvm/relay/frontend/tensorflow.py
@@ -44,6 +44,17 @@
 __all__ = ["from_tensorflow"]


+def check_symbolic_shape(shape):
+    return not all([isinstance(dim, (int, tvm.tir.IntImm)) for dim in shape])
+
+
+def list_shape_of(tensor, ndim):
+    shape_tensor = _op.shape_of(tensor)
+    return [
+        _op.strided_slice(shape_tensor, begin=[i], end=[i + 1], strides=[1]) for i in range(ndim)
+    ]
+
+
 def _get_pad_pair(input1d, kernel1d, stride1d):
     if input1d % stride1d == 0:
         pad = max(kernel1d - stride1d, 0)
@@ -1022,13 +1033,31 @@ def _impl(inputs, attr, params, mod):
         input_y = inputs[1]
         orig_shape_x = _infer_shape(input_x, mod)
         orig_shape_y = _infer_shape(input_y, mod)
+        ndim = len(orig_shape_x)
+
+        is_static = not check_symbolic_shape(orig_shape_x)
+
+        if ndim > 3 and not is_static:
+            shape_of_x = list_shape_of(inputs[0], ndim)
+            shape_of_y = list_shape_of(inputs[1], ndim)

         # reshape n-dimensional batch matmul into 3d
-        if len(orig_shape_x) > 3:
+        if ndim > 3:
             outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)]
-            num_outer_elts = np.prod(outer_dims)
-            new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1])
-            new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1])
+            if is_static:
+                num_outer_elts = np.prod(outer_dims)
+                new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1])
+                new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1])
+            else:  # handle dynamic shape (dyn.reshape op)
+                # new shape = [prod(shape[:-2]), -2, -1]
+                new_shape_x = [_op.const(1), shape_of_x[-2], shape_of_x[-1]]
+                new_shape_y = [_op.const(1), shape_of_y[-2], shape_of_y[-1]]
+                for i in range(ndim - 2):
+                    new_shape_x[0] *= shape_of_x[i]
+                    new_shape_y[0] *= shape_of_y[i]
+                new_shape_x = _op.concatenate(_op.Tuple(new_shape_x), axis=0)
+                new_shape_y = _op.concatenate(_op.Tuple(new_shape_y), axis=0)
+
             input_x = _op.reshape(input_x, newshape=new_shape_x)
             input_y = _op.reshape(input_y, newshape=new_shape_y)

@@ -1039,12 +1068,19 @@ def _impl(inputs, attr, params, mod):
         ret = get_relay_op("batch_matmul")(input_x, input_y)

         # reshape result back to n-dimensional
-        if len(orig_shape_x) > 3:
-            final_shape = list(orig_shape_x)
-            final_shape[-2] = orig_shape_x[-1] if adj_x else orig_shape_x[-2]
-            final_shape[-1] = orig_shape_y[-2] if adj_y else orig_shape_y[-1]
-            ret = _op.reshape(ret, newshape=final_shape)
+        if ndim > 3:
+            if is_static:
+                final_shape = list(orig_shape_x)
+                final_shape[-2] = orig_shape_x[-1] if adj_x else orig_shape_x[-2]
+                final_shape[-1] = orig_shape_y[-2] if adj_y else orig_shape_y[-1]
+            else:
+                # calculate the resulting shape = [shape[:-2], 0, 0]
+                final_shape = list(shape_of_x)
+                final_shape[-2] = shape_of_x[-1] if adj_x else shape_of_x[-2]
+                final_shape[-1] = shape_of_y[-2] if adj_y else shape_of_y[-1]
+                final_shape = _op.concatenate(_op.Tuple(final_shape), axis=0)
+
+            ret = _op.reshape(ret, newshape=final_shape)
         return ret

     return _impl
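
For reference, a standalone sketch (not taken from this commit) of the shape arithmetic used in the dynamic branch above: each dimension is read back with shape_of/strided_slice, the leading dimensions are multiplied into one batch value, and the pieces are concatenated into the 1-D shape tensor that the dynamic reshape consumes. The 4-D shape below is assumed purely for illustration.

import tvm
from tvm import relay

# A 4-D tensor whose two leading (batch) dimensions are only known at runtime.
x = relay.var("x", shape=(relay.Any(), relay.Any(), 5, 6), dtype="float32")

shape_of_x = relay.shape_of(x)  # 1-D integer tensor holding the runtime shape
dims = [
    relay.strided_slice(shape_of_x, begin=[i], end=[i + 1], strides=[1]) for i in range(4)
]
batch = dims[0] * dims[1]  # prod(shape[:-2]) computed as a runtime value
new_shape = relay.concatenate([batch, dims[2], dims[3]], axis=0)

# With a relay.Expr newshape this lowers to the dynamic reshape (relay.dyn.reshape).
reshaped = relay.reshape(x, newshape=new_shape)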
7 changes: 4 additions & 3 deletions python/tvm/topi/cuda/batch_matmul.py
@@ -159,9 +159,10 @@ def batch_matmul_cublas(cfg, x, y, out_shape=None):
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
-    b, m, k = x.shape
-    b, n, k = y.shape
-    cfg.add_flop(b * m * k * n * 2)
+    b, m, k = get_const_tuple(x.shape)
+    b, n, k = get_const_tuple(y.shape)
+    if all([isinstance(s, int) for s in [b, m, n, k]]):
+        cfg.add_flop(b * m * k * n * 2)
     return cublas.batch_matmul(x, y, False, True)

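
A short sketch of the compute declaration this relaxation is meant to accept (illustrative, not from the diff): a te placeholder whose batch dimension is a symbolic te.var, handed straight to the cuBLAS extern op; the concrete 128/96/64 sizes are assumptions for the example.

import tvm
from tvm import te
from tvm.contrib import cublas

batch = te.var("batch")  # symbolic batch dimension, resolved at runtime
x = te.placeholder((batch, 128, 64), name="x", dtype="float32")
y = te.placeholder((batch, 96, 64), name="y", dtype="float32")

# transa=False, transb=True matches the TOPI convention:
# (batch, M, K) x (batch, N, K) -> (batch, M, N)
out = cublas.batch_matmul(x, y, False, True)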
11 changes: 4 additions & 7 deletions python/tvm/topi/cuda/dense.py
@@ -17,7 +17,7 @@
 # pylint: disable=invalid-name, unused-argument
 """Schedule for dense operator"""
 import logging
-from tvm import te, tir
+from tvm import te
 import tvm.autotvm as autotvm
 from tvm.autotvm.task.space import SplitEntity
 from tvm.contrib import cublas
@@ -39,14 +39,11 @@ def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
     if out_dtype is None:
         out_dtype = data.dtype
     assert out_dtype == data.dtype, "Mixed precision not supported."
-    batch, in_dim = data.shape
-    out_dim, _ = weight.shape
+    batch, in_dim = get_const_tuple(data.shape)
+    out_dim, _ = get_const_tuple(weight.shape)
     matmul = cublas.matmul(data, weight, False, True)
-    if isinstance(batch, int):
+    if all(isinstance(d, int) for d in [batch, in_dim, out_dim]):
         cfg.add_flop(batch * in_dim * out_dim * 2)
-    elif isinstance(batch, tir.IntImm):
-        cfg.add_flop(batch.value * in_dim * out_dim * 2)
+    # if we get a te.Var, we cannot add flop counts
     if bias is not None:
         matmul = te.compute(
             (batch, out_dim), lambda i, j: matmul[i, j] + bias[j], tag=tag.BROADCAST
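
Both TOPI changes above follow the same pattern; a minimal standalone sketch (names and sizes are illustrative, not from the commit): get_const_tuple converts IntImm dimensions to Python ints but passes symbolic dimensions (te.Var / Any) through untouched, so FLOP accounting only happens when every dimension is concrete.

import tvm
from tvm import te
from tvm.topi.utils import get_const_tuple

batch = te.var("batch")  # dynamic dimension
data = te.placeholder((batch, 768), name="data", dtype="float32")
weight = te.placeholder((3072, 768), name="weight", dtype="float32")

b, in_dim = get_const_tuple(data.shape)     # b stays a te.Var here
out_dim, _ = get_const_tuple(weight.shape)  # 3072 and 768 become Python ints

if all(isinstance(d, int) for d in [b, in_dim, out_dim]):
    flops = 2 * b * in_dim * out_dim  # only counted for fully static shapes
else:
    flops = None  # a symbolic dimension is present: skip flop accounting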
52 changes: 51 additions & 1 deletion tests/python/frontend/tensorflow/test_forward.py
@@ -210,6 +210,7 @@ def compare_tf_with_tvm(
     mode="graph_runtime",
     cuda_layout="NCHW",
     add_shapes_to_graph_def=True,
+    targets=None,
 ):
     """Generic function to generate and compare tensorflow and TVM output"""

@@ -233,13 +234,18 @@ def name_without_num(name):

     tf_output = run_tf_graph(sess, in_data, in_name, out_name)

-    for device in ["llvm", "cuda"]:
+    devices = targets if targets else ["llvm", "cuda"]
+
+    for device in devices:
         ctx = tvm.context(device, 0)
         if not tvm.testing.device_enabled(device):
             print("Skip because %s is not enabled" % device)
             continue
         if no_gpu and device == "cuda":
             continue
+        if "cublas" in device and not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
+            print("Skip because cublas is not enabled: %s" % device)
+            continue

         tvm_output = run_tvm_graph(
             final_graph_def,
@@ -1796,6 +1802,23 @@ def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False
         compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name)


+def _test_batch_matmul_dynamic(
+    A_shape, B_shape, A_np_shape, B_np_shape, dtype, adjoint_a=False, adjoint_b=False
+):
+    with tf.Graph().as_default():
+        A = tf.placeholder(shape=A_shape, dtype=dtype, name="A")
+        B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")
+        result = tf.matmul(A, B, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name="batchmatmul")
+
+        A_np = np.random.uniform(high=5.0, size=A_np_shape).astype(dtype)
+        B_np = np.random.uniform(high=5.0, size=B_np_shape).astype(dtype)
+        # for now, in TOPI, only cublas's implementation support dynamic shape
+        # TODO add more backends support in TOPI
+        compare_tf_with_tvm(
+            [A_np, B_np], [A.name, B.name], result.name, mode="vm", targets=["cuda -libs=cublas"]
+        )
+
+
 def test_forward_batch_matmul():
     """ TF op BatchMatMul, BatchMatMulV2 test"""
     _test_batch_matmul((3, 5, 4), (3, 4, 5), "int32")
@@ -1808,6 +1831,33 @@ def test_forward_batch_matmul():
     _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True)


+@tvm.testing.requires_cuda
+def test_forward_batch_matmul_dynamic():
+    _test_batch_matmul_dynamic((None, 5, 4), (None, 4, 5), (3, 5, 4), (3, 4, 5), "int32")
+    _test_batch_matmul_dynamic(
+        (None, 5, 4), (None, 4, 5), (3, 5, 4), (3, 4, 5), "float32", True, True
+    )
+    _test_batch_matmul_dynamic(
+        (None, 5, 4), (None, 5, 4), (3, 5, 4), (3, 5, 4), "int32", True, False
+    )
+    _test_batch_matmul_dynamic(
+        (None, 5, 4), (None, 5, 4), (3, 5, 4), (3, 5, 4), "float32", False, True
+    )
+    _test_batch_matmul_dynamic(
+        (None, 4, 5, 6), (None, 4, 6, 5), (3, 4, 5, 6), (3, 4, 6, 5), "float32"
+    )
+    _test_batch_matmul_dynamic(
+        (None, None, 5, 6), (None, None, 6, 5), (3, 4, 5, 6), (3, 4, 6, 5), "float32"
+    )
+    _test_batch_matmul_dynamic(
+        (None, None, None, 5, 6),
+        (None, None, None, 6, 5),
+        (2, 3, 4, 5, 6),
+        (2, 3, 4, 6, 5),
+        "float32",
+    )
+
+
 #######################################################################
 # SparseTensorDenseMatMul
 # ----------------------------------
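
The new test can be run in isolation once the branch is checked out, assuming a CUDA build of TVM with cuBLAS enabled; a minimal invocation via pytest's Python API (the command form of this invocation works equally well):

import pytest

pytest.main(
    ["tests/python/frontend/tensorflow/test_forward.py::test_forward_batch_matmul_dynamic"]
)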
