@@ -73,7 +73,7 @@ def get_tensor_array_shape(expr, dtype, prelude):
     return None
 
 
-def _get_name_static(canonical, dtype, shape, batch_dim=None):
+def _get_name_static(canonical, dtype, shape, batch_dim=None, extra_shapes=None):
     """Get name for static shape tensor array op
 
     By design, static ADT tensor in TVM has type name in the format
@@ -100,14 +100,12 @@ def _get_name_static(canonical, dtype, shape, batch_dim=None):
     name : String
         The tensor array op name
     """
-    dim_names = []
-    for dim in shape:
-        if isinstance(dim, Any):
-            dim_names.append("any")
-        else:
-            dim_names.append(str(dim))
+    shape_str = _to_str(shape)
 
-    shape_str = "_".join(dim_names)
+    if extra_shapes is not None:
+        for n, s in extra_shapes.items():
+            extra_shape_str = "_{}_{}".format(n, _to_str(s))
+            shape_str += extra_shape_str
 
     if len(shape_str) == 0:
         shape_str = "scalar"
@@ -120,6 +118,16 @@ def _get_name_static(canonical, dtype, shape, batch_dim=None):
     return "{}_{}_batch{}_{}".format(canonical, dtype, str(batch_dim), shape_str)
 
 
+def _to_str(shape):
+    dim_names = []
+    for dim in shape:
+        if isinstance(dim, Any):
+            dim_names.append("any")
+        else:
+            dim_names.append(str(dim))
+    return "_".join(dim_names)
+
+
 class StaticTensorArrayOps(object):
     """Contains tensor array related ops for fixed rank tensor array"""
 
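For reference, the naming scheme after this change can be exercised with a minimal standalone sketch. `Any` here is a stand-in for TVM's dynamic-dimension marker (`tvm.tir.Any`), and `_get_name_static` is reduced to the branches visible in this diff (the real function also folds in `batch_dim` and special-cases names such as `tensor_t`), so the output format is illustrative rather than authoritative:

```python
class Any:
    """Stand-in for tvm.tir.Any, TVM's dynamic-dimension marker."""


def _to_str(shape):
    # Render each dimension, writing "any" for dynamic dimensions.
    return "_".join("any" if isinstance(dim, Any) else str(dim) for dim in shape)


def _get_name_static(canonical, dtype, shape, extra_shapes=None):
    # Reduced sketch of the branches shown in this diff.
    shape_str = _to_str(shape)
    if extra_shapes is not None:
        for n, s in extra_shapes.items():
            # Each extra shape contributes a "_<name>_<dims>" suffix.
            shape_str += "_{}_{}".format(n, _to_str(s))
    if len(shape_str) == 0:
        shape_str = "scalar"
    return "{}_{}_{}".format(canonical, dtype, shape_str)


print(_get_name_static("tensor_array_scatter", "float32", (Any(), 3)))
# tensor_array_scatter_float32_any_3
print(
    _get_name_static(
        "tensor_array_scatter", "float32", (Any(), 3), extra_shapes={"indices": (5,)}
    )
)
# tensor_array_scatter_float32_any_3_indices_5
```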
@@ -131,9 +139,9 @@ def __init__(self, prelude, dtype, shape, batch_dim=None):
         self.batch_dim = batch_dim
         self.list, self.cons, self.nil = self.prelude.mod.get_type("List")
 
-    def get_name(self, canonical):
+    def get_name(self, canonical, extra_shapes=None):
         """Get name corresponding to the canonical name"""
-        return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim)
+        return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim, extra_shapes)
 
     def get_global_var(self, canonical):
         """Get global corresponding to the canonical name"""
@@ -408,11 +416,16 @@ def define_tensor_array_scatter(self, indices_shape=None, force_update=False):
         # When this operator has already been registered, only update
         # when force_update is set. This should be used only when we need to
         # redefine this op for static indices shape.
-        tensor_array_scatter_name = self.get_name("tensor_array_scatter")
+
+        extra_shapes = {"indices": indices_shape} if indices_shape is not None else None
+        tensor_array_scatter_name = self.get_name("tensor_array_scatter", extra_shapes)
         if hasattr(self.prelude, tensor_array_scatter_name) and not force_update:
             return
 
-        tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper")
+        tensor_array_scatter_helper_name = self.get_name(
+            "tensor_array_scatter_helper", extra_shapes
+        )
+
         tensor_array_scatter_helper_var = self._create_global_var(tensor_array_scatter_helper_name)
         ta = Var("ta", self.list(self.tensor_type_var()))
         current = Var("current", scalar_type("int32"))
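The payoff shows up when the same tensor array is scattered with indices of different static shapes: each indices shape now registers its own global, instead of the second definition clobbering the first through `force_update`. A hedged usage sketch, assuming the `register()` entry point and call sequence used by TVM's frontends:

```python
from tvm.relay.prelude import Prelude, StaticTensorArrayOps

prelude = Prelude()
ops = StaticTensorArrayOps(prelude, "float32", (3,))
ops.register()  # define the tensor ADT and basic ops for this dtype/shape

# Two scatter definitions with different static indices shapes now live
# under distinct names (roughly tensor_array_scatter_float32_3_indices_2
# and ..._indices_4), so neither overwrites the other.
ops.define_tensor_array_scatter(indices_shape=(2,), force_update=True)
ops.define_tensor_array_scatter(indices_shape=(4,), force_update=True)
```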