From 8fd9c445ea7c7debbe635613dafce5940aba65dd Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 18 Mar 2021 02:45:49 +0800 Subject: [PATCH] [Frontend,TOPI] Improve dynamism for BatchMatmul and Dense (#7496) * [TOPI] Dense cuda schedule support dynamic dimension * [TOPI] batch_matmul cublas te computation support dynamism * [Frontend] tensorflow frontend: dynamic support for BatchMatmul * [TOPI] nn batch_matmul te computation support dynamism * fix CI * Update python/tvm/topi/nn/batch_matmul.py Co-authored-by: Cody Yu * Update python/tvm/topi/cuda/batch_matmul.py Co-authored-by: Cody Yu * remove concat_dynamic_shape function * update topi dense op integer checking * fix ci * Update python/tvm/relay/frontend/tensorflow.py Co-authored-by: Cody Yu * Update batch_matmul.py * [Frontend] add test for batch_matmul in dynamic shaped case Co-authored-by: Cody Yu --- python/tvm/relay/frontend/tensorflow.py | 54 +++++++++++++++---- python/tvm/topi/cuda/batch_matmul.py | 7 +-- python/tvm/topi/cuda/dense.py | 11 ++-- .../frontend/tensorflow/test_forward.py | 52 +++++++++++++++++- 4 files changed, 104 insertions(+), 20 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index f56d187b6a63..1946223a50a4 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -44,6 +44,17 @@ __all__ = ["from_tensorflow"] +def check_symbolic_shape(shape): + return not all([isinstance(dim, (int, tvm.tir.IntImm)) for dim in shape]) + + +def list_shape_of(tensor, ndim): + shape_tensor = _op.shape_of(tensor) + return [ + _op.strided_slice(shape_tensor, begin=[i], end=[i + 1], strides=[1]) for i in range(ndim) + ] + + def _get_pad_pair(input1d, kernel1d, stride1d): if input1d % stride1d == 0: pad = max(kernel1d - stride1d, 0) @@ -1022,13 +1033,31 @@ def _impl(inputs, attr, params, mod): input_y = inputs[1] orig_shape_x = _infer_shape(input_x, mod) orig_shape_y = _infer_shape(input_y, mod) + ndim = len(orig_shape_x) + + is_static = not check_symbolic_shape(orig_shape_x) + + if ndim > 3 and not is_static: + shape_of_x = list_shape_of(inputs[0], ndim) + shape_of_y = list_shape_of(inputs[1], ndim) # reshape n-dimensional batch matmul into 3d - if len(orig_shape_x) > 3: + if ndim > 3: outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)] - num_outer_elts = np.prod(outer_dims) - new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1]) - new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1]) + if is_static: + num_outer_elts = np.prod(outer_dims) + new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1]) + new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1]) + else: # handle dynamic shape (dyn.reshape op) + # new shape = [prod(shape[:-2]), -2, -1] + new_shape_x = [_op.const(1), shape_of_x[-2], shape_of_x[-1]] + new_shape_y = [_op.const(1), shape_of_y[-2], shape_of_y[-1]] + for i in range(ndim - 2): + new_shape_x[0] *= shape_of_x[i] + new_shape_y[0] *= shape_of_y[i] + new_shape_x = _op.concatenate(_op.Tuple(new_shape_x), axis=0) + new_shape_y = _op.concatenate(_op.Tuple(new_shape_y), axis=0) + input_x = _op.reshape(input_x, newshape=new_shape_x) input_y = _op.reshape(input_y, newshape=new_shape_y) @@ -1039,12 +1068,19 @@ def _impl(inputs, attr, params, mod): ret = get_relay_op("batch_matmul")(input_x, input_y) # reshape result back to n-dimensional - if len(orig_shape_x) > 3: - final_shape = list(orig_shape_x) - final_shape[-2] = orig_shape_x[-1] if adj_x else 
orig_shape_x[-2] - final_shape[-1] = orig_shape_y[-2] if adj_y else orig_shape_y[-1] - ret = _op.reshape(ret, newshape=final_shape) + if ndim > 3: + if is_static: + final_shape = list(orig_shape_x) + final_shape[-2] = orig_shape_x[-1] if adj_x else orig_shape_x[-2] + final_shape[-1] = orig_shape_y[-2] if adj_y else orig_shape_y[-1] + else: + # calculate the resulting shape = [shape[:-2], 0, 0] + final_shape = list(shape_of_x) + final_shape[-2] = shape_of_x[-1] if adj_x else shape_of_x[-2] + final_shape[-1] = shape_of_y[-2] if adj_y else shape_of_y[-1] + final_shape = _op.concatenate(_op.Tuple(final_shape), axis=0) + ret = _op.reshape(ret, newshape=final_shape) return ret return _impl diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index 006b866d6bad..04e484f526d2 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -159,9 +159,10 @@ def batch_matmul_cublas(cfg, x, y, out_shape=None): output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - b, m, k = x.shape - b, n, k = y.shape - cfg.add_flop(b * m * k * n * 2) + b, m, k = get_const_tuple(x.shape) + b, n, k = get_const_tuple(y.shape) + if all([isinstance(s, int) for s in [b, m, n, k]]): + cfg.add_flop(b * m * k * n * 2) return cublas.batch_matmul(x, y, False, True) diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py index ad4882ab09f2..8adc38b84b1b 100644 --- a/python/tvm/topi/cuda/dense.py +++ b/python/tvm/topi/cuda/dense.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name, unused-argument """Schedule for dense operator""" import logging -from tvm import te, tir +from tvm import te import tvm.autotvm as autotvm from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cublas @@ -39,14 +39,11 @@ def dense_cublas(cfg, data, weight, bias=None, out_dtype=None): if out_dtype is None: out_dtype = data.dtype assert out_dtype == data.dtype, "Mixed precision not supported." 
- batch, in_dim = data.shape - out_dim, _ = weight.shape + batch, in_dim = get_const_tuple(data.shape) + out_dim, _ = get_const_tuple(weight.shape) matmul = cublas.matmul(data, weight, False, True) - if isinstance(batch, int): + if all(isinstance(d, int) for d in [batch, in_dim, out_dim]): cfg.add_flop(batch * in_dim * out_dim * 2) - elif isinstance(batch, tir.IntImm): - cfg.add_flop(batch.value * in_dim * out_dim * 2) - # if we get a te.Var, we cannot add flop counts if bias is not None: matmul = te.compute( (batch, out_dim), lambda i, j: matmul[i, j] + bias[j], tag=tag.BROADCAST diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 43eb3f803092..f22bd29d3f8f 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -210,6 +210,7 @@ def compare_tf_with_tvm( mode="graph_runtime", cuda_layout="NCHW", add_shapes_to_graph_def=True, + targets=None, ): """Generic function to generate and compare tensorflow and TVM output""" @@ -233,13 +234,18 @@ def name_without_num(name): tf_output = run_tf_graph(sess, in_data, in_name, out_name) - for device in ["llvm", "cuda"]: + devices = targets if targets else ["llvm", "cuda"] + + for device in devices: ctx = tvm.context(device, 0) if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) continue if no_gpu and device == "cuda": continue + if "cublas" in device and not tvm.get_global_func("tvm.contrib.cublas.matmul", True): + print("Skip because cublas is not enabled: %s" % device) + continue tvm_output = run_tvm_graph( final_graph_def, @@ -1796,6 +1802,23 @@ def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name) +def _test_batch_matmul_dynamic( + A_shape, B_shape, A_np_shape, B_np_shape, dtype, adjoint_a=False, adjoint_b=False +): + with tf.Graph().as_default(): + A = tf.placeholder(shape=A_shape, dtype=dtype, name="A") + B = tf.placeholder(shape=B_shape, dtype=dtype, name="B") + result = tf.matmul(A, B, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name="batchmatmul") + + A_np = np.random.uniform(high=5.0, size=A_np_shape).astype(dtype) + B_np = np.random.uniform(high=5.0, size=B_np_shape).astype(dtype) + # for now, in TOPI, only cublas's implementation support dynamic shape + # TODO add more backends support in TOPI + compare_tf_with_tvm( + [A_np, B_np], [A.name, B.name], result.name, mode="vm", targets=["cuda -libs=cublas"] + ) + + def test_forward_batch_matmul(): """ TF op BatchMatMul, BatchMatMulV2 test""" _test_batch_matmul((3, 5, 4), (3, 4, 5), "int32") @@ -1808,6 +1831,33 @@ def test_forward_batch_matmul(): _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True) +@tvm.testing.requires_cuda +def test_forward_batch_matmul_dynamic(): + _test_batch_matmul_dynamic((None, 5, 4), (None, 4, 5), (3, 5, 4), (3, 4, 5), "int32") + _test_batch_matmul_dynamic( + (None, 5, 4), (None, 4, 5), (3, 5, 4), (3, 4, 5), "float32", True, True + ) + _test_batch_matmul_dynamic( + (None, 5, 4), (None, 5, 4), (3, 5, 4), (3, 5, 4), "int32", True, False + ) + _test_batch_matmul_dynamic( + (None, 5, 4), (None, 5, 4), (3, 5, 4), (3, 5, 4), "float32", False, True + ) + _test_batch_matmul_dynamic( + (None, 4, 5, 6), (None, 4, 6, 5), (3, 4, 5, 6), (3, 4, 6, 5), "float32" + ) + _test_batch_matmul_dynamic( + (None, None, 5, 6), (None, None, 6, 5), (3, 4, 5, 6), (3, 4, 6, 5), "float32" + ) + 
+    _test_batch_matmul_dynamic(
+        (None, None, None, 5, 6),
+        (None, None, None, 6, 5),
+        (2, 3, 4, 5, 6),
+        (2, 3, 4, 6, 5),
+        "float32",
+    )
+
+
 #######################################################################
 # SparseTensorDenseMatMul
 # ----------------------------------
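
Usage sketch (illustrative, not part of the patch): the dynamic-shape batch_matmul path that the new tests exercise through the TensorFlow frontend can also be driven directly from Relay. The snippet below is a minimal sketch assuming a TVM build with CUDA and cuBLAS enabled; the variable names and shapes are made up for illustration. It builds a batch_matmul whose batch dimension is relay.Any(), compiles it with the Relay VM, and dispatches to the cuBLAS kernel via "cuda -libs=cublas", mirroring what compare_tf_with_tvm(..., mode="vm", targets=["cuda -libs=cublas"]) does in the test above.

import numpy as np

import tvm
from tvm import relay
from tvm.runtime.vm import VirtualMachine

# Symbolic batch dimension, like what the TF frontend now emits for
# BatchMatMul inputs whose leading dimensions are unknown.
x = relay.var("x", shape=(relay.Any(), 5, 4), dtype="float32")
y = relay.var("y", shape=(relay.Any(), 6, 4), dtype="float32")
# relay.nn.batch_matmul computes x @ transpose(y): (b, 5, 4) x (b, 6, 4) -> (b, 5, 6)
out = relay.nn.batch_matmul(x, y)
mod = tvm.IRModule.from_expr(relay.Function([x, y], out))

# "-libs=cublas" selects the cuBLAS batch_matmul implementation; with this
# patch, cfg.add_flop is only called when all dimensions are constant ints.
target = "cuda -libs=cublas"
with tvm.transform.PassContext(opt_level=3):
    exe = relay.vm.compile(mod, target=target)

vm = VirtualMachine(exe, tvm.gpu(0))
x_np = np.random.uniform(size=(3, 5, 4)).astype("float32")
y_np = np.random.uniform(size=(3, 6, 4)).astype("float32")
res = vm.run(x_np, y_np)
print(res.asnumpy().shape)  # (3, 5, 6)

Running against "cuda -libs=cublas" requires both CUDA and cuBLAS in the build, which is why the new test is gated with @tvm.testing.requires_cuda and why compare_tf_with_tvm now skips cublas targets when tvm.contrib.cublas.matmul is not registered.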