Skip to content

Commit 6b8c1b3

Browse files
committed
Refine get_name
1 parent 6c8bd21 commit 6b8c1b3

File tree

1 file changed

+24
-17
lines changed

1 file changed

+24
-17
lines changed

python/tvm/relay/prelude.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,15 @@ def get_tensor_array_shape(expr, dtype, prelude):
7373
return None
7474

7575

76-
def _get_name_static(canonical, dtype, shape):
76+
def _get_name_static(canonical, dtype, shape, extra_shapes=None):
7777
"""Get name for static shape tensor array op corresponding
7878
to the canonical name"""
79-
dim_names = []
80-
for dim in shape:
81-
if isinstance(dim, Any):
82-
dim_names.append("any")
83-
else:
84-
dim_names.append(str(dim))
79+
shape_str = _to_str(shape)
8580

86-
shape_str = "_".join(dim_names)
81+
if extra_shapes is not None:
82+
for n, s in extra_shapes.items():
83+
extra_shape_str = "_{}_{}".format(n, _to_str(s))
84+
shape_str += extra_shape_str
8785

8886
if len(shape_str) == 0:
8987
shape_str = "scalar"
@@ -92,6 +90,16 @@ def _get_name_static(canonical, dtype, shape):
9290
return "{}_{}_{}".format(canonical, dtype, shape_str)
9391

9492

93+
def _to_str(shape):
94+
dim_names = []
95+
for dim in shape:
96+
if isinstance(dim, Any):
97+
dim_names.append("any")
98+
else:
99+
dim_names.append(str(dim))
100+
return "_".join(dim_names)
101+
102+
95103
class StaticTensorArrayOps(object):
96104
"""Contains tensor array related ops for fixed rank tensor array"""
97105

@@ -102,9 +110,9 @@ def __init__(self, prelude, dtype, shape):
102110
self.shape = shape
103111
self.list, self.cons, self.nil = self.prelude.mod.get_type("List")
104112

105-
def get_name(self, canonical):
113+
def get_name(self, canonical, extra_shapes=None):
106114
"""Get name corresponding to the canonical name"""
107-
return _get_name_static(canonical, self.dtype, self.shape)
115+
return _get_name_static(canonical, self.dtype, self.shape, extra_shapes)
108116

109117
def get_global_var(self, canonical):
110118
"""Get global corresponding to the canonical name"""
@@ -378,16 +386,15 @@ def define_tensor_array_scatter(self, indices_shape=None, force_update=False):
378386
# When this operator has already been registered, only update
379387
# when force_update is set. This should be used only when we need to
380388
# redefine this op for static indices shape.
381-
tensor_array_scatter_name = self.get_name("tensor_array_scatter")
389+
390+
extra_shapes = {"indices": indices_shape} if indices_shape is not None else None
391+
tensor_array_scatter_name = self.get_name("tensor_array_scatter", extra_shapes)
382392
if hasattr(self.prelude, tensor_array_scatter_name) and not force_update:
383393
return
384394

385-
tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper")
386-
387-
if indices_shape:
388-
# Put indices_shape into variable name
389-
tensor_array_scatter_name += "_" + str(indices_shape)
390-
tensor_array_scatter_helper_name += "_" + str(indices_shape)
395+
tensor_array_scatter_helper_name = self.get_name(
396+
"tensor_array_scatter_helper", extra_shapes
397+
)
391398

392399
tensor_array_scatter_helper_var = self._create_global_var(tensor_array_scatter_helper_name)
393400
ta = Var("ta", self.list(self.tensor_type_var()))

0 commit comments

Comments
 (0)