Skip to content

Commit 5d7d8ef

Browse files
committed
fix index_put
1 parent 11fe163 commit 5d7d8ef

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

tester/api_config/config_analyzer.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,6 +1720,25 @@ def get_padding_offset(bsz, max_seq_len, seq_lens_this_time):
17201720
inputs=self.get_arg(api_config, 0, "x")
17211721
self.numpy_tensor = numpy.random.randint(0,inputs.shape[axis], size=self.shape).astype(self.dtype)
17221722

1723+
elif api_config.api_name in {"paddle.Tensor.index_put", "paddle.index_put"}:
1724+
if self.check_arg(api_config,1,'indices') and not self.get_arg(api_config, 3, "accumulate"):
1725+
# NOTE(zrr1999): If accumulate is False, the behavior is undefined if indices contain duplicate elements in torch.
1726+
1727+
inputs=self.get_arg(api_config, 0, "x")
1728+
value=self.get_arg(api_config, 2, "value")
1729+
inputs_numel = inputs.numel()
1730+
value_numel = value.numel()
1731+
if inputs_numel < value_numel:
1732+
raise ValueError(
1733+
f"Invalid input for paddle.index_put: inputs.numel() < value.numel() when accumulate=False. "
1734+
)
1735+
inputs_shape = inputs.shape
1736+
value_shape = value.shape
1737+
1738+
flat_indices = numpy.random.choice(inputs_numel, size=value_numel, replace=False)
1739+
indices = [index.reshape(value_shape) for index in numpy.unravel_index(flat_indices, inputs_shape)]
1740+
self.numpy_tensor = indices.astype(self.dtype)
1741+
17231742
elif api_config.api_name == "paddle.Tensor.tile":
17241743
if index==1 or key=='repeat_times':
17251744
self.numpy_tensor = numpy.random.randint(1,128, size=self.shape).astype(self.dtype)

0 commit comments

Comments
 (0)