@@ -256,53 +256,118 @@ def _handle_axis_arg(self, config_items, is_tuple=False):
256256 tensor_idx += 1
257257 return tuple (tmp ) if is_tuple else tmp
258258
259+ def _generate_int_indices (self , item_shape , dim_size ):
260+ num_elements = numpy .prod (item_shape ).item ()
261+ if num_elements > dim_size :
262+ indices_flat = numpy .random .randint (
263+ - dim_size , dim_size , size = num_elements
264+ )
265+ else :
266+ indices_flat = numpy .random .choice (
267+ dim_size , size = num_elements , replace = False
268+ )
269+ return indices_flat .reshape (item_shape )
270+
271+
272+ def _generate_constrained_bool_mask (self , shape , num_true ):
273+ mask_size = numpy .prod (shape ).item ()
274+ if mask_size < num_true :
275+ raise ValueError (
276+ f"Cannot generate a mask with { num_true } true values in a { mask_size } element mask"
277+ )
278+ mask_flat = numpy .zeros (mask_size , dtype = "bool" )
279+ true_indices = numpy .random .choice (mask_size , num_true , replace = False )
280+ mask_flat [true_indices ] = True
281+ return mask_flat .reshape (shape )
282+
283+ def _broadcast_or_raise (self , shapes ):
284+ return numpy .broadcast_shapes (* [tuple (s ) for s in shapes ])
285+
259286 def _handle_indices_arg (self , config_items , is_tuple = False ):
260- x = self .paddle_args_config [0 ] if len (self .paddle_args_config ) > 0 else self .paddle_kwargs_config ["x" ]
261- value = self .paddle_args_config [2 ] if len (self .paddle_args_config ) > 2 else self .paddle_kwargs_config ["value" ]
287+ x = (
288+ self .paddle_args_config [0 ]
289+ if len (self .paddle_args_config ) > 0
290+ else self .paddle_kwargs_config ["x" ]
291+ )
292+ value = (
293+ self .paddle_args_config [2 ]
294+ if len (self .paddle_args_config ) > 2
295+ else self .paddle_kwargs_config ["value" ]
296+ )
262297 x_shape = x .shape
263298 value_shape = value .shape
264-
265- tmp = []
266- matched_axis = 0
267- indices_shape_len = 0
299+ int_index_shapes = []
300+ has_bool_index = False
301+ dims_consumed = 0
268302 for item in config_items :
269- if item .dtype != "bool" :
270- matched_axis += 1
271- indices_shape_len = max (indices_shape_len , len (item .shape ))
272-
273- expected = indices_shape_len + len (x_shape ) - matched_axis
274- reduced = expected - len (value_shape )
275- x_shape_index = 0
276- value_shape_index = indices_shape_len
303+ if item .dtype == "bool" :
304+ b_rank = len (item .shape )
305+ has_bool_index = True
306+ dims_consumed += b_rank
307+ else :
308+ int_index_shapes .append (tuple (item .shape ))
309+ dims_consumed += 1
277310
311+ if dims_consumed > len (x_shape ):
312+ raise ValueError (
313+ f"Too many indices: consume { dims_consumed } dims but x has { len (x_shape )} dims"
314+ )
315+ num_true_needed = - 1
316+ num_remaining_dims = len (x_shape ) - dims_consumed
317+ advanced_shape = ()
318+ if int_index_shapes :
319+ try :
320+ advanced_shape = self ._broadcast_or_raise (int_index_shapes )
321+ # give a default 1
322+ if (
323+ has_bool_index
324+ and len (value_shape ) > num_remaining_dims
325+ and advanced_shape [- 1 ] == 1
326+ and value_shape [- num_remaining_dims - 1 ] != 1
327+ ):
328+ advanced_shape = (* advanced_shape [:- 1 ], value_shape [- num_remaining_dims - 1 ])
329+ num_true_needed = advanced_shape [- 1 ]
330+ except Exception :
331+ raise ValueError (
332+ f"Incompatible integer index shapes for broadcasting: { int_index_shapes } "
333+ )
334+ elif has_bool_index :
335+ if len (value_shape ) > num_remaining_dims :
336+ advanced_shape = (value_shape [0 ],)
337+ num_true_needed = value_shape [0 ]
338+ else :
339+ # give a default 1, other valid(not out of bound) shape also can.
340+ advanced_shape = (1 ,)
341+ num_true_needed = 1
342+ res_dims = advanced_shape + tuple (x_shape [dims_consumed :])
343+ try :
344+ # only for checking.
345+ self ._broadcast_or_raise ([value_shape , res_dims ])
346+ except ValueError :
347+ raise ValueError (
348+ f"Value shape { value_shape } cannot be broadcast to the indexed shape { res_dims } ."
349+ )
350+ processed_indices = []
351+ x_dim_cursor = 0
278352 for item in config_items :
279353 if item .dtype == "bool" :
280- true_needed = []
281- for i in range (len (item .shape )):
282- if reduced > 0 :
283- reduced -= 1
284- true_needed .append (1 )
285- else :
286- true_needed .append (value_shape [value_shape_index ])
287- value_shape_index += 1
288- for i in range (len (true_needed ) - 1 , 0 , - 1 ):
289- if true_needed [i ] > item .shape [i ]:
290- true_needed [i - 1 ] *= true_needed [i ] // item .shape [i ]
291- true_needed [i ] = item .shape [i ]
292- mask = numpy .zeros (item .shape , dtype = bool )
293- indices = [
294- numpy .random .choice (dim_size , size = needed , replace = False )
295- for dim_size , needed in zip (item .shape , true_needed )
296- ]
297- mask [numpy .ix_ (* indices )] = True
298- item .numpy_tensor = mask
299- x_shape_index += len (item .shape )
354+ if num_true_needed < 0 :
355+ raise ValueError (
356+ "Cannot determine the number of True elements for the boolean mask."
357+ )
358+ item .numpy_tensor = self ._generate_constrained_bool_mask (
359+ item .shape , num_true_needed
360+ )
361+ x_dim_cursor += len (item .shape )
300362 else :
301- x_dim = x_shape [x_shape_index ]
302- item .numpy_tensor = numpy .random .randint (- x_dim , x_dim , size = item .shape , dtype = item .dtype )
303- x_shape_index += 1
304- tmp .append (item .get_numpy_tensor (self .api_config ))
305- return tuple (tmp ) if is_tuple else tmp
363+ x_dim_to_index = x_shape [x_dim_cursor ]
364+ indices = self ._generate_int_indices (item .shape , x_dim_to_index )
365+ item .numpy_tensor = indices .astype (item .dtype )
366+ x_dim_cursor += 1
367+
368+ processed_indices .append (item .numpy_tensor )
369+
370+ return tuple (processed_indices ) if is_tuple else processed_indices
306371
307372 def gen_numpy_input (self ):
308373 for i , arg_config in enumerate (self .paddle_args_config ):
0 commit comments