Commit a427efb

[Relay] Fix a bug in tensor_array_scatter (#6890)
* [Relay] Fix a bug in tensor_array_scatter

  tensor_array_scatter constructs helper functions according to the dtype and
  shape of the element. When there are multiple scatter operations with the
  same dtype and element shape but different indices_shape, there is a name
  conflict in the prelude.

* Refine get_name
1 parent 099ebaa commit a427efb
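To make the collision concrete, here is a minimal standalone sketch of the naming scheme. make_name is a hypothetical stand-in for the prelude's _get_name_static, and the exact name strings are illustrative, not taken from TVM's output:

# Hypothetical, self-contained mirror of the name mangling in
# python/tvm/relay/prelude.py (illustration only, not the actual TVM code).
def make_name(canonical, dtype, shape, extra_shapes=None):
    shape_str = "_".join(str(dim) for dim in shape)
    if extra_shapes is not None:
        for name, s in extra_shapes.items():
            shape_str += "_{}_{}".format(name, "_".join(str(dim) for dim in s))
    return "{}_{}_{}".format(canonical, dtype, shape_str)

# Before the fix the indices shape played no part in the name, so a scatter
# over 3 indices and a scatter over 4 indices collided on one helper:
print(make_name("tensor_array_scatter_helper", "float32", (1,)))
# -> tensor_array_scatter_helper_float32_1  (same for both scatters)

# After the fix the indices shape is folded into the name, keeping the two
# registrations distinct:
print(make_name("tensor_array_scatter_helper", "float32", (1,), {"indices": (3,)}))
print(make_name("tensor_array_scatter_helper", "float32", (1,), {"indices": (4,)}))
# -> tensor_array_scatter_helper_float32_1_indices_3
# -> tensor_array_scatter_helper_float32_1_indices_4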

File tree

python/tvm/relay/prelude.py
tests/python/frontend/tensorflow/test_forward.py

2 files changed: +44 -21 lines changed


python/tvm/relay/prelude.py

Lines changed: 25 additions & 12 deletions
@@ -73,7 +73,7 @@ def get_tensor_array_shape(expr, dtype, prelude):
     return None
 
 
-def _get_name_static(canonical, dtype, shape, batch_dim=None):
+def _get_name_static(canonical, dtype, shape, batch_dim=None, extra_shapes=None):
     """Get name for static shape tensor array op
 
     By design, static ADT tensor in TVM has type name in the format
@@ -100,14 +100,12 @@ def _get_name_static(canonical, dtype, shape, batch_dim=None):
     name : String
         The tensor array op name
     """
-    dim_names = []
-    for dim in shape:
-        if isinstance(dim, Any):
-            dim_names.append("any")
-        else:
-            dim_names.append(str(dim))
+    shape_str = _to_str(shape)
 
-    shape_str = "_".join(dim_names)
+    if extra_shapes is not None:
+        for n, s in extra_shapes.items():
+            extra_shape_str = "_{}_{}".format(n, _to_str(s))
+            shape_str += extra_shape_str
 
     if len(shape_str) == 0:
         shape_str = "scalar"
@@ -120,6 +118,16 @@ def _get_name_static(canonical, dtype, shape, batch_dim=None):
     return "{}_{}_batch{}_{}".format(canonical, dtype, str(batch_dim), shape_str)
 
 
+def _to_str(shape):
+    dim_names = []
+    for dim in shape:
+        if isinstance(dim, Any):
+            dim_names.append("any")
+        else:
+            dim_names.append(str(dim))
+    return "_".join(dim_names)
+
+
 class StaticTensorArrayOps(object):
     """Contains tensor array related ops for fixed rank tensor array"""
 
@@ -131,9 +139,9 @@ def __init__(self, prelude, dtype, shape, batch_dim=None):
         self.batch_dim = batch_dim
         self.list, self.cons, self.nil = self.prelude.mod.get_type("List")
 
-    def get_name(self, canonical):
+    def get_name(self, canonical, extra_shapes=None):
         """Get name corresponding to the canonical name"""
-        return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim)
+        return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim, extra_shapes)
 
     def get_global_var(self, canonical):
         """Get global corresponding to the canonical name"""
@@ -408,11 +416,16 @@ def define_tensor_array_scatter(self, indices_shape=None, force_update=False):
         # When this operator has already been registered, only update
         # when force_update is set. This should be used only when we need to
         # redefine this op for static indices shape.
-        tensor_array_scatter_name = self.get_name("tensor_array_scatter")
+
+        extra_shapes = {"indices": indices_shape} if indices_shape is not None else None
+        tensor_array_scatter_name = self.get_name("tensor_array_scatter", extra_shapes)
         if hasattr(self.prelude, tensor_array_scatter_name) and not force_update:
             return
 
-        tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper")
+        tensor_array_scatter_helper_name = self.get_name(
+            "tensor_array_scatter_helper", extra_shapes
+        )
+
         tensor_array_scatter_helper_var = self._create_global_var(tensor_array_scatter_helper_name)
         ta = Var("ta", self.list(self.tensor_type_var()))
         current = Var("current", scalar_type("int32"))
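For reference, a short sketch of how the extended get_name can be exercised. This assumes a TVM checkout that includes this patch; the dtype and shape values below are arbitrary illustrations:

import tvm
from tvm.relay.prelude import Prelude, StaticTensorArrayOps

# Build a prelude and a static tensor array ops helper for float32 elements
# of shape (1,); these values are illustrative.
mod = tvm.IRModule()
prelude = Prelude(mod)
ops = StaticTensorArrayOps(prelude, "float32", (1,))

# Without extra_shapes the name depends only on dtype and element shape;
# with extra_shapes the indices shape is appended, so scatter ops with
# different indices shapes no longer clash in the prelude.
print(ops.get_name("tensor_array_scatter"))
print(ops.get_name("tensor_array_scatter", {"indices": (3,)}))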

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 19 additions & 9 deletions
@@ -1537,19 +1537,29 @@ def run(dtype_str, infer_shape):
             element_shape = tf.TensorShape([tf.Dimension(None)])
         else:
             element_shape = None
-        t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype)
-        indices = tf.constant([2, 1, 0])
-        ta1 = tf.TensorArray(
-            dtype=dtype, size=3, infer_shape=infer_shape, element_shape=element_shape
-        )
-        ta2 = ta1.scatter(indices, t)
-        out0 = ta2.read(0)
-        out1 = ta2.read(1)
-        out2 = ta2.read(2)
+        ta0 = _construct_scatter(dtype, dtype_str, element_shape, infer_shape, 3)
+        out0 = ta0.read(0)
+        out1 = ta0.read(1)
+        out2 = ta0.read(2)
+        ta1 = _construct_scatter(dtype, dtype_str, element_shape, infer_shape, 4)
+        out4 = ta1.read(0)
         g = tf.get_default_graph()
         compare_tf_with_tvm([], [], ["TensorArrayReadV3:0"], mode="vm")
         compare_tf_with_tvm([], [], ["TensorArrayReadV3_1:0"], mode="vm")
         compare_tf_with_tvm([], [], ["TensorArrayReadV3_2:0"], mode="vm")
+        compare_tf_with_tvm([], [], ["TensorArrayReadV3_2:0", out4.name], mode="vm")
+
+    def _construct_scatter(dtype, dtype_str, element_shape, infer_shape, size):
+        arr = [[float(i)] for i in range(size)]
+        indices_arr = [i for i in range(size - 1, -1, -1)]
+
+        t = tf.constant(np.array(arr).astype(dtype_str), dtype=dtype)
+        indices = tf.constant(indices_arr)
+        ta1 = tf.TensorArray(
+            dtype=dtype, size=size, infer_shape=infer_shape, element_shape=element_shape
+        )
+        ta2 = ta1.scatter(indices, t)
+        return ta2
 
     for dtype in ["float32", "int8"]:
         run(dtype, False)
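A note on what this regression test exercises, as read from the commit message and the diff: both TensorArrays use the same dtype and element shape (1,), but sizes 3 and 4 give the two scatter ops different indices shapes, (3,) and (4,). Before this patch both lowered to the same prelude helper name, so the second scatter silently reused the first helper; reading out4 in the final compare_tf_with_tvm call checks that each scatter now gets its own correctly shaped helper.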
