@@ -1715,6 +1715,25 @@ def get_padding_offset(bsz, max_seq_len, seq_lens_this_time):
17151715 inputs = self .get_arg (api_config , 0 , "x" )
17161716 self .numpy_tensor = numpy .random .randint (0 ,inputs .shape [axis ], size = self .shape ).astype (self .dtype )
17171717
1718+ elif api_config .api_name in {"paddle.Tensor.index_put" , "paddle.index_put" }:
1719+ if self .check_arg (api_config ,1 ,'indices' ) and not self .get_arg (api_config , 3 , "accumulate" ):
1720+ # NOTE(zrr1999): If accumulate is False, the behavior is undefined if indices contain duplicate elements in torch.
1721+
1722+ inputs = self .get_arg (api_config , 0 , "x" )
1723+ value = self .get_arg (api_config , 2 , "value" )
1724+ inputs_numel = inputs .numel ()
1725+ value_numel = value .numel ()
1726+ if inputs_numel < value_numel :
1727+ raise ValueError (
1728+ f"Invalid input for paddle.index_put: inputs.numel() < value.numel() when accumulate=False. "
1729+ )
1730+ inputs_shape = inputs .shape
1731+ value_shape = value .shape
1732+
1733+ flat_indices = numpy .random .choice (inputs_numel , size = value_numel , replace = False )
1734+ indices = [index .reshape (value_shape ) for index in numpy .unravel_index (flat_indices , inputs_shape )]
1735+ self .numpy_tensor = indices .astype (self .dtype )
1736+
17181737 elif api_config .api_name == "paddle.Tensor.tile" :
17191738 if index == 1 or key == 'repeat_times' :
17201739 self .numpy_tensor = numpy .random .randint (1 ,128 , size = self .shape ).astype (self .dtype )
0 commit comments