Skip to content

Commit 6c8bd21

Browse files
committed
[Relay] Fix a bug in tensor_array_scatter
tensor_array_scatter constructs helper functions according to dtype and shape of element. When there are multiple scatter operations with same dtype and element shape but different indicies_shape, there will be name conflict in prelude.
1 parent 0469a77 commit 6c8bd21

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

python/tvm/relay/prelude.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,12 @@ def define_tensor_array_scatter(self, indices_shape=None, force_update=False):
383383
return
384384

385385
tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper")
386+
387+
if indices_shape:
388+
# Put indices_shape into variable name
389+
tensor_array_scatter_name += "_" + str(indices_shape)
390+
tensor_array_scatter_helper_name += "_" + str(indices_shape)
391+
386392
tensor_array_scatter_helper_var = self._create_global_var(tensor_array_scatter_helper_name)
387393
ta = Var("ta", self.list(self.tensor_type_var()))
388394
current = Var("current", scalar_type("int32"))

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,19 +1388,29 @@ def run(dtype_str, infer_shape):
13881388
element_shape = tf.TensorShape([tf.Dimension(None)])
13891389
else:
13901390
element_shape = None
1391-
t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype)
1392-
indices = tf.constant([2, 1, 0])
1393-
ta1 = tf.TensorArray(
1394-
dtype=dtype, size=3, infer_shape=infer_shape, element_shape=element_shape
1395-
)
1396-
ta2 = ta1.scatter(indices, t)
1397-
out0 = ta2.read(0)
1398-
out1 = ta2.read(1)
1399-
out2 = ta2.read(2)
1391+
ta0 = _construct_scatter(dtype, dtype_str, element_shape, infer_shape, 3)
1392+
out0 = ta0.read(0)
1393+
out1 = ta0.read(1)
1394+
out2 = ta0.read(2)
1395+
ta1 = _construct_scatter(dtype, dtype_str, element_shape, infer_shape, 4)
1396+
out4 = ta1.read(0)
14001397
g = tf.get_default_graph()
14011398
compare_tf_with_tvm([], [], ["TensorArrayReadV3:0"], mode="vm")
14021399
compare_tf_with_tvm([], [], ["TensorArrayReadV3_1:0"], mode="vm")
14031400
compare_tf_with_tvm([], [], ["TensorArrayReadV3_2:0"], mode="vm")
1401+
compare_tf_with_tvm([], [], ["TensorArrayReadV3_2:0", out4.name], mode="vm")
1402+
1403+
def _construct_scatter(dtype, dtype_str, element_shape, infer_shape, size):
1404+
arr = [[float(i)] for i in range(size)]
1405+
indices_arr = [i for i in range(size - 1, -1, -1)]
1406+
1407+
t = tf.constant(np.array(arr).astype(dtype_str), dtype=dtype)
1408+
indices = tf.constant(indices_arr)
1409+
ta1 = tf.TensorArray(
1410+
dtype=dtype, size=size, infer_shape=infer_shape, element_shape=element_shape
1411+
)
1412+
ta2 = ta1.scatter(indices, t)
1413+
return ta2
14041414

14051415
for dtype in ["float32", "int8"]:
14061416
run(dtype, False)

0 commit comments

Comments
 (0)