@@ -73,17 +73,15 @@ def get_tensor_array_shape(expr, dtype, prelude):
7373 return None
7474
7575
def _get_name_static(canonical, dtype, shape, extra_shapes=None):
    """Get name for static shape tensor array op corresponding
    to the canonical name.

    Parameters
    ----------
    canonical : str
        Canonical op name, e.g. "tensor_array_scatter".

    dtype : str
        Data type string embedded into the generated name.

    shape : tuple
        Static shape of the tensor; each dimension is an int or a
        relay ``Any`` (rendered as ``"any"`` — see ``_to_str``).

    extra_shapes : dict, optional
        Mapping of label -> shape. Each entry is appended to the name
        (as ``_{label}_{shape}``) so ops specialized on additional
        static shapes (e.g. scatter indices) get distinct global names.

    Returns
    -------
    str
        ``"{canonical}_{dtype}_{shape_str}"``; ``shape_str`` falls back
        to ``"scalar"`` when the shape is rank-0 and no extra shapes
        were given.
    """
    shape_str = _to_str(shape)

    # Fold any extra named shapes into the name so that distinct
    # specializations of the same canonical op do not collide.
    if extra_shapes is not None:
        for n, s in extra_shapes.items():
            extra_shape_str = "_{}_{}".format(n, _to_str(s))
            shape_str += extra_shape_str

    if len(shape_str) == 0:
        shape_str = "scalar"
    return "{}_{}_{}".format(canonical, dtype, shape_str)
9492
def _to_str(shape):
    """Render a static shape as an underscore-joined string.

    Each dimension that is a relay ``Any`` becomes the literal
    ``"any"``; all other dimensions are stringified as-is. A rank-0
    shape yields the empty string (caller substitutes "scalar").
    """
    dim_names = []
    for dim in shape:
        if isinstance(dim, Any):
            dim_names.append("any")
        else:
            dim_names.append(str(dim))
    return "_".join(dim_names)
101+
102+
95103class StaticTensorArrayOps (object ):
96104 """Contains tensor array related ops for fixed rank tensor array"""
97105
@@ -102,9 +110,9 @@ def __init__(self, prelude, dtype, shape):
102110 self .shape = shape
103111 self .list , self .cons , self .nil = self .prelude .mod .get_type ("List" )
104112
105- def get_name (self , canonical ):
113+ def get_name (self , canonical , extra_shapes = None ):
106114 """Get name corresponding to the canonical name"""
107- return _get_name_static (canonical , self .dtype , self .shape )
115+ return _get_name_static (canonical , self .dtype , self .shape , extra_shapes )
108116
109117 def get_global_var (self , canonical ):
110118 """Get global corresponding to the canonical name"""
@@ -378,16 +386,15 @@ def define_tensor_array_scatter(self, indices_shape=None, force_update=False):
378386 # When this operator has already been registered, only update
379387 # when force_update is set. This should be used only when we need to
380388 # redefine this op for static indices shape.
381- tensor_array_scatter_name = self .get_name ("tensor_array_scatter" )
389+
390+ extra_shapes = {"indices" : indices_shape } if indices_shape is not None else None
391+ tensor_array_scatter_name = self .get_name ("tensor_array_scatter" , extra_shapes )
382392 if hasattr (self .prelude , tensor_array_scatter_name ) and not force_update :
383393 return
384394
385- tensor_array_scatter_helper_name = self .get_name ("tensor_array_scatter_helper" )
386-
387- if indices_shape :
388- # Put indices_shape into variable name
389- tensor_array_scatter_name += "_" + str (indices_shape )
390- tensor_array_scatter_helper_name += "_" + str (indices_shape )
395+ tensor_array_scatter_helper_name = self .get_name (
396+ "tensor_array_scatter_helper" , extra_shapes
397+ )
391398
392399 tensor_array_scatter_helper_var = self ._create_global_var (tensor_array_scatter_helper_name )
393400 ta = Var ("ta" , self .list (self .tensor_type_var ()))
0 commit comments