Skip to content

Commit cba3865

Browse files
Merge pull request #510 from zrr1999/acc/index_put
[Accuracy diff No.98] Fix accuracy diff for paddle.index_put API
2 parents ea1dd17 + 5d7d8ef commit cba3865

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

tester/api_config/config_analyzer.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1715,6 +1715,25 @@ def get_padding_offset(bsz, max_seq_len, seq_lens_this_time):
17151715
inputs=self.get_arg(api_config, 0, "x")
17161716
self.numpy_tensor = numpy.random.randint(0,inputs.shape[axis], size=self.shape).astype(self.dtype)
17171717

1718+
elif api_config.api_name in {"paddle.Tensor.index_put", "paddle.index_put"}:
1719+
if self.check_arg(api_config,1,'indices') and not self.get_arg(api_config, 3, "accumulate"):
1720+
# NOTE(zrr1999): If accumulate is False, the behavior is undefined if indices contain duplicate elements in torch.
1721+
1722+
inputs=self.get_arg(api_config, 0, "x")
1723+
value=self.get_arg(api_config, 2, "value")
1724+
inputs_numel = inputs.numel()
1725+
value_numel = value.numel()
1726+
if inputs_numel < value_numel:
1727+
raise ValueError(
1728+
f"Invalid input for paddle.index_put: inputs.numel() < value.numel() when accumulate=False. "
1729+
)
1730+
inputs_shape = inputs.shape
1731+
value_shape = value.shape
1732+
1733+
flat_indices = numpy.random.choice(inputs_numel, size=value_numel, replace=False)
1734+
indices = [index.reshape(value_shape) for index in numpy.unravel_index(flat_indices, inputs_shape)]
1735+
self.numpy_tensor = indices.astype(self.dtype)
1736+
17181737
elif api_config.api_name == "paddle.Tensor.tile":
17191738
if index==1 or key=='repeat_times':
17201739
self.numpy_tensor = numpy.random.randint(1,128, size=self.shape).astype(self.dtype)

0 commit comments

Comments
 (0)