@@ -2330,19 +2330,49 @@ def one_hot(self, inputs, input_types):
23302330
23312331 def index (self , inputs , input_types ):
23322332 data = inputs [0 ]
2333+ data_shape = self .infer_type (data ).shape
2334+
2335+ axes_adv_idx = [i for i , v in enumerate (inputs [1 ]) if v is not None ]
2336+ axes_rest = [i for i in range (len (data_shape )) if i not in axes_adv_idx ]
2337+
2338+ # check if the adv_index axes are consecutive
2339+ # if consecutive, result must be transposed again at the end
2340+ consecutive = True
2341+ for curr , nxt in zip (axes_adv_idx [:- 1 ], axes_adv_idx [1 :]):
2342+ if nxt - curr != 1 :
2343+ consecutive = False
2344+ break
2345+
23332346 indices_list = []
2347+ axes_order = axes_adv_idx + axes_rest
23342348
2335- for indices in inputs [1 ]:
2336- if self .infer_type (indices ).dtype == "bool" :
2349+ for i in axes_adv_idx :
2350+ inp = inputs [1 ][i ]
2351+ if self .infer_type (inp ).dtype == "bool" :
23372352 # adv_index does not support a mask as the index tensor (it will treat 0/1 as
23382353 # an index rather than a flag).
23392354 # So we use argwhere to turn the mask into indices, which will also take care
23402355 # of the dynamism in the indexing by mask.
2341- indices_list .append (_op .squeeze (_op .transform .argwhere (indices ), axis = [1 ]))
2356+ indices_list .append (_op .squeeze (_op .transform .argwhere (inp ), axis = [1 ]))
23422357 else :
2343- indices_list .append (indices )
2358+ indices_list .append (inp )
2359+
2360+ data_after_adv_index = _op .adv_index ([_op .transpose (data , axes = axes_order )] + indices_list )
23442361
2345- return _op .adv_index ([data ] + indices_list )
2362+ if consecutive :
2363+ num_dims = len (self .infer_type (data_after_adv_index ).shape )
2364+ num_new_dims = num_dims - len (axes_rest )
2365+
2366+ axes_final_order = list (range (num_dims ))
2367+ axes_final_order = (
2368+ axes_final_order [num_new_dims : num_new_dims + axes_adv_idx [0 ]]
2369+ + axes_final_order [:num_new_dims ]
2370+ + axes_final_order [num_new_dims + axes_adv_idx [0 ] :]
2371+ )
2372+
2373+ return _op .transpose (data_after_adv_index , axes = axes_final_order )
2374+ else :
2375+ return data_after_adv_index
23462376
23472377 def meshgrid (self , inputs , input_types ):
23482378 data = inputs [0 ]
0 commit comments