@@ -1720,6 +1720,25 @@ def get_padding_offset(bsz, max_seq_len, seq_lens_this_time):
                 inputs = self.get_arg(api_config, 0, "x")
                 self.numpy_tensor = numpy.random.randint(0, inputs.shape[axis], size=self.shape).astype(self.dtype)

+        elif api_config.api_name in {"paddle.Tensor.index_put", "paddle.index_put"}:
+            if self.check_arg(api_config, 1, 'indices') and not self.get_arg(api_config, 3, "accumulate"):
+                # NOTE(zrr1999): When accumulate is False, torch's behavior is undefined if indices
+                # contain duplicate elements, so generate unique indices here.
+                inputs = self.get_arg(api_config, 0, "x")
+                value = self.get_arg(api_config, 2, "value")
+                inputs_numel = inputs.numel()
+                value_numel = value.numel()
+                if inputs_numel < value_numel:
+                    raise ValueError(
+                        f"Invalid input for paddle.index_put: inputs.numel()={inputs_numel} < value.numel()={value_numel} when accumulate=False."
+                    )
+                inputs_shape = inputs.shape
+                value_shape = value.shape
+
+                flat_indices = numpy.random.choice(inputs_numel, size=value_numel, replace=False)
+                indices = [index.reshape(value_shape) for index in numpy.unravel_index(flat_indices, inputs_shape)]
+                self.numpy_tensor = numpy.stack(indices).astype(self.dtype)
+
         elif api_config.api_name == "paddle.Tensor.tile":
             if index == 1 or key == 'repeat_times':
                 self.numpy_tensor = numpy.random.randint(1, 128, size=self.shape).astype(self.dtype)
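
For reference, here is a minimal numpy-only sketch of the index-generation idea in the added branch. The shapes below are illustrative placeholders, not taken from any real api_config: sampling flat positions without replacement and unraveling them yields one index array per dimension of x, all in bounds and with no duplicates, which is what index_put with accumulate=False requires.

    import numpy

    # Illustrative placeholder shapes (assumed, not from the diff).
    inputs_shape = (4, 5)   # shape of x
    value_shape = (3,)      # shape of value

    inputs_numel = int(numpy.prod(inputs_shape))
    value_numel = int(numpy.prod(value_shape))

    # Sample flat positions without replacement so no element is addressed twice,
    # then convert each flat position into one coordinate per dimension of x.
    flat_indices = numpy.random.choice(inputs_numel, size=value_numel, replace=False)
    indices = [idx.reshape(value_shape) for idx in numpy.unravel_index(flat_indices, inputs_shape)]

    # Together the per-dimension arrays select value_numel distinct elements of x.
    x = numpy.zeros(inputs_shape)
    x[tuple(indices)] = numpy.arange(1, value_numel + 1)
    print(indices)
    print(x)

Sampling without replacement is what keeps the accumulate=False comparison well defined, since duplicate indices would make the result order dependent.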