
Commit cb0dabd

Refine get_name
1 parent 7424a66 commit cb0dabd

1 file changed: +24 -17 lines changed


python/tvm/relay/prelude.py

Lines changed: 24 additions & 17 deletions
@@ -73,7 +73,7 @@ def get_tensor_array_shape(expr, dtype, prelude):
     return None
 
 
-def _get_name_static(canonical, dtype, shape, batch_dim=None):
+def _get_name_static(canonical, dtype, shape, batch_dim=None, extra_shapes=None):
     """Get name for static shape tensor array op
 
     By design, static ADT tensor in TVM has type name in the format
@@ -100,14 +100,12 @@ def _get_name_static(canonical, dtype, shape, batch_dim=None):
     name : String
         The tensor array op name
     """
-    dim_names = []
-    for dim in shape:
-        if isinstance(dim, Any):
-            dim_names.append("any")
-        else:
-            dim_names.append(str(dim))
+    shape_str = _to_str(shape)
 
-    shape_str = "_".join(dim_names)
+    if extra_shapes is not None:
+        for n, s in extra_shapes.items():
+            extra_shape_str = "_{}_{}".format(n, _to_str(s))
+            shape_str += extra_shape_str
 
     if len(shape_str) == 0:
         shape_str = "scalar"
@@ -120,6 +118,16 @@ def _get_name_static(canonical, dtype, shape, batch_dim=None):
     return "{}_{}_batch{}_{}".format(canonical, dtype, str(batch_dim), shape_str)
 
 
+def _to_str(shape):
+    dim_names = []
+    for dim in shape:
+        if isinstance(dim, Any):
+            dim_names.append("any")
+        else:
+            dim_names.append(str(dim))
+    return "_".join(dim_names)
+
+
 class StaticTensorArrayOps(object):
     """Contains tensor array related ops for fixed rank tensor array"""
 
@@ -131,9 +139,9 @@ def __init__(self, prelude, dtype, shape, batch_dim=None):
         self.batch_dim = batch_dim
         self.list, self.cons, self.nil = self.prelude.mod.get_type("List")
 
-    def get_name(self, canonical):
+    def get_name(self, canonical, extra_shapes=None):
         """Get name corresponding to the canonical name"""
-        return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim)
+        return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim, extra_shapes)
 
     def get_global_var(self, canonical):
         """Get global corresponding to the canonical name"""
@@ -408,16 +416,15 @@ def define_tensor_array_scatter(self, indices_shape=None, force_update=False):
         # When this operator has already been registered, only update
         # when force_update is set. This should be used only when we need to
        # redefine this op for static indices shape.
-        tensor_array_scatter_name = self.get_name("tensor_array_scatter")
+
+        extra_shapes = {"indices": indices_shape} if indices_shape is not None else None
+        tensor_array_scatter_name = self.get_name("tensor_array_scatter", extra_shapes)
         if hasattr(self.prelude, tensor_array_scatter_name) and not force_update:
             return
 
-        tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper")
-
-        if indices_shape:
-            # Put indices_shape into variable name
-            tensor_array_scatter_name += "_" + str(indices_shape)
-            tensor_array_scatter_helper_name += "_" + str(indices_shape)
+        tensor_array_scatter_helper_name = self.get_name(
+            "tensor_array_scatter_helper", extra_shapes
+        )
 
         tensor_array_scatter_helper_var = self._create_global_var(tensor_array_scatter_helper_name)
         ta = Var("ta", self.list(self.tensor_type_var()))
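For context: the refinement folds auxiliary static shapes (here the indices shape of tensor_array_scatter) into the operator name through the shared _to_str helper, instead of appending a raw "_" + str(indices_shape) string. Below is a minimal sketch of the resulting shape-suffix encoding; the Any class and the shape_suffix helper are stand-ins for illustration only (the real code uses Relay's Any dimension inside _get_name_static), not the full prelude code.

class Any:
    """Stand-in for tvm.relay.Any, a dynamic (unknown) dimension."""

def _to_str(shape):
    # Each dimension becomes one token; dynamic dims are encoded as "any".
    return "_".join("any" if isinstance(dim, Any) else str(dim) for dim in shape)

def shape_suffix(shape, extra_shapes=None):
    # Mirrors the refined _get_name_static logic for the shape part only.
    shape_str = _to_str(shape)
    if extra_shapes is not None:
        for name, extra in extra_shapes.items():
            shape_str += "_{}_{}".format(name, _to_str(extra))
    return shape_str if shape_str else "scalar"

# shape_suffix((2, Any()), {"indices": (5,)}) -> "2_any_indices_5"

So an op redefined for a static indices shape now gets a deterministic, well-formed name component rather than a suffix like "_(5,)".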
