Skip to content

Commit

Permalink
[TF frontend] add support for StridedSlice to input a single constant (
Browse files Browse the repository at this point in the history
…apache#6949)

* [TF frontend] add support for StridedSlice to input a single constant

* add test for strideslice with a single number input

* fix bug
  • Loading branch information
alter-xp authored and Tushar Dey committed Jan 20, 2021
1 parent f221a55 commit db0970b
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 31 deletions.
3 changes: 3 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1616,6 +1616,9 @@ def _impl(inputs, attr, params, mod):
data_shape = get_const_tuple(in_type.checked_type.shape)
data_dim = len(data_shape)
stride_dim = len(stride)
if data_dim == 0 and isinstance(inputs[0], _expr.Constant):
new_data = inputs[0].data.asnumpy().reshape(1)
return _expr.const(new_data, inputs[0].data.dtype)

# This is a special routine to handle strided_slice after shape_of.
# We need this since in some cases we want to do strided_slice on
Expand Down
71 changes: 40 additions & 31 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ def run_tvm_graph(
m = graph_runtime.create(graph, lib, ctx)
# set inputs
for e, i in zip(input_node, input_data):
m.set_input(e, tvm.nd.array(i))
if e != "":
m.set_input(e, tvm.nd.array(i))

m.set_input(**params)
# execute
Expand All @@ -192,8 +193,10 @@ def run_tf_graph(sess, input_data, input_node, output_node):
tensor = [sess.graph.get_tensor_by_name(output_name) for output_name in output_node]

input_dict = {e: input_data[i] for i, e in enumerate(input_node)}

output_data = sess.run(tensor, input_dict)
if len(input_node) == 1 and input_node[0] == "":
output_data = sess.run(tensor)
else:
output_data = sess.run(tensor, input_dict)
return output_data


Expand Down Expand Up @@ -1843,8 +1846,12 @@ def _test_stridedslice(
""" One iteration of a Stridedslice """

tf.reset_default_graph()
np_data = np.random.uniform(size=ip_shape).astype(dtype)
with tf.Graph().as_default():
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
if len(ip_shape) == 0:
in_data = tf.constant(np_data, dtype)
else:
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
tf.strided_slice(
in_data,
begin,
Expand All @@ -1857,56 +1864,58 @@ def _test_stridedslice(
ellipsis_mask=ellipsis_mask,
name="strided_slice",
)
np_data = np.random.uniform(size=ip_shape).astype(dtype)

compare_tf_with_tvm(np_data, "in_data:0", "strided_slice:0")
if len(ip_shape) == 0:
compare_tf_with_tvm(None, "", "strided_slice:0")
else:
compare_tf_with_tvm(np_data, "in_data:0", "strided_slice:0")


def test_forward_stridedslice():
"""test StridedSlice"""

_test_stridedslice((2), [1], [1], [1], "float32", shrink_axis_mask=1)
_test_stridedslice((2, 1), [0], [1], [1], "float32", shrink_axis_mask=1)
_test_stridedslice((2, 3, 4), [0], [1], [1], "float32", shrink_axis_mask=8)
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], "float32")
_test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], "float32", ellipsis_mask=8)
_test_stridedslice((3, 4, 3), [1, 0], [4, 2], [2, 1], "float32", ellipsis_mask=2)
_test_stridedslice((3, 4, 5, 3), [1, 0], [4, 2], [2, 1], "float32", ellipsis_mask=2)
_test_stridedslice((3, 4, 5, 3), [1, 0, 1], [4, 2, 2], [2, 1, 1], "float32", ellipsis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [2, 1, 1], "float32", new_axis_mask=5)
_test_stridedslice([], [0], [0], [1], "float32", new_axis_mask=1)
_test_stridedslice([2], [1], [1], [1], "float32", shrink_axis_mask=1)
_test_stridedslice([2, 1], [0], [1], [1], "float32", shrink_axis_mask=1)
_test_stridedslice([2, 3, 4], [0], [1], [1], "float32", shrink_axis_mask=8)
_test_stridedslice([3, 4, 3], [1, -1, 0], [4, -5, 3], [2, -1, 1], "float32")
_test_stridedslice([3, 4, 3], [1, 0], [4, 3], [2, 1], "float32", ellipsis_mask=8)
_test_stridedslice([3, 4, 3], [1, 0], [4, 2], [2, 1], "float32", ellipsis_mask=2)
_test_stridedslice([3, 4, 5, 3], [1, 0], [4, 2], [2, 1], "float32", ellipsis_mask=2)
_test_stridedslice([3, 4, 5, 3], [1, 0, 1], [4, 2, 2], [2, 1, 1], "float32", ellipsis_mask=2)
_test_stridedslice([3, 4, 3], [1, 1, 0], [4, 4, 2], [2, 1, 1], "float32", new_axis_mask=5)
_test_stridedslice(
(3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=4
[3, 4, 3], [1, 1, 1], [4, 4, 1], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=4
)
_test_stridedslice(
(6, 4, 5), [1, 1, 1], [6, 3, 4], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=5
[6, 4, 5], [1, 1, 1], [6, 3, 4], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=5
)
_test_stridedslice(
(3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], "float32", ellipsis_mask=4, new_axis_mask=2
[3, 4, 3], [1, 1, 2], [4, 4, 3], [2, 1, 1], "float32", ellipsis_mask=4, new_axis_mask=2
)
_test_stridedslice(
(3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=3
[3, 4, 3], [1, 1, 2], [4, 4, 3], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=3
)
_test_stridedslice(
(3, 4, 3), [1, 1, 0], [4, 4, 1], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=3
[3, 4, 3], [1, 1, 0], [4, 4, 1], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=3
)
_test_stridedslice(
(3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=2
[3, 4, 3], [1, 1, 2], [4, 4, 3], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=2
)
_test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], "float32", shrink_axis_mask=2)
_test_stridedslice(
(3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], "float32", shrink_axis_mask=2, new_axis_mask=2
[3, 4, 3], [1, 1, 0], [4, 4, 3], [2, 1, 1], "float32", shrink_axis_mask=2, new_axis_mask=2
)
_test_stridedslice(
(3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], "float32", shrink_axis_mask=1, new_axis_mask=2
[3, 4, 3], [1, 1, 0], [4, 4, 3], [2, 1, 1], "float32", shrink_axis_mask=1, new_axis_mask=2
)
_test_stridedslice(
(3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], "float32", shrink_axis_mask=2, new_axis_mask=1
[3, 4, 3], [1, 1, 0], [4, 4, 3], [2, 1, 1], "float32", shrink_axis_mask=2, new_axis_mask=1
)
_test_stridedslice(
(3, 4, 5, 4, 5, 6), [0, 0], [2, 3], [1, 1], "float32", shrink_axis_mask=5, new_axis_mask=1
[3, 4, 5, 4, 5, 6], [0, 0], [2, 3], [1, 1], "float32", shrink_axis_mask=5, new_axis_mask=1
)
_test_stridedslice(
(3, 4, 5, 4, 5, 6),
[3, 4, 5, 4, 5, 6],
[0, 0, 1, 2, 1],
[2, 3, 4, 5, 3],
[1, 1, 2, 2, 1],
Expand All @@ -1918,7 +1927,7 @@ def test_forward_stridedslice():
end_mask=8,
)
_test_stridedslice(
(3, 4, 5, 4, 5, 6),
[3, 4, 5, 4, 5, 6],
[0, 0, 1, 2, 1],
[2, 3, 4, 5, 3],
[1, 1, 2, 2, 1],
Expand All @@ -1930,7 +1939,7 @@ def test_forward_stridedslice():
end_mask=5,
)
_test_stridedslice(
(3, 4, 5, 4, 5, 6),
[3, 4, 5, 4, 5, 6],
[0, 0, 1, 2, 1],
[2, 3, 4, 5, 3],
[1, 1, 2, 2, 1],
Expand All @@ -1942,7 +1951,7 @@ def test_forward_stridedslice():
end_mask=5,
)
_test_stridedslice(
(3, 4, 5, 4, 5, 6),
[3, 4, 5, 4, 5, 6],
[1, 2, 0, -3],
[4, 5, 3, 3],
[2, 2, 1, 1],
Expand All @@ -1954,7 +1963,7 @@ def test_forward_stridedslice():
end_mask=8,
)
_test_stridedslice(
(1, 13, 13, 3, 2),
[1, 13, 13, 3, 2],
[0, 0],
[1, 1],
[1, -1],
Expand Down

0 comments on commit db0970b

Please sign in to comment.