Skip to content

Commit 6003440

Browse files
authored
[Accuracy diff No.79] Fix accuracy diff for paddle.combinations API (#73293)
* fix accuracy for combinations * empty commit
1 parent 4f051c0 commit 6003440

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

python/paddle/tensor/math.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8391,11 +8391,6 @@ def combinations(
83918391
if r == 0:
83928392
return paddle.empty(shape=[0], dtype=x.dtype)
83938393

8394-
if (r > x.shape[0] and not with_replacement) or (
8395-
x.shape[0] == 0 and with_replacement
8396-
):
8397-
return paddle.empty(shape=[0, r], dtype=x.dtype)
8398-
83998394
if r > 1:
84008395
t_l = [x for i in range(r)]
84018396
grids = paddle.meshgrid(t_l)

test/legacy_test/test_combinations.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,22 @@ def test_dygraph(self):
143143
for place in self.place:
144144
paddle.device.set_device(place)
145145
a = paddle.rand([3], dtype='float32')
146+
a.stop_gradient = False
146147
c = paddle.combinations(a, r=4)
147148
expected = convert_combinations_to_array(a.numpy(), r=4)
148149
np.testing.assert_allclose(c, expected)
150+
loss = c.sum().backward()
151+
expected = np.zeros([3], dtype='float32')
152+
np.testing.assert_allclose(a.grad, expected)
153+
154+
a = paddle.rand([0], dtype='float32')
155+
a.stop_gradient = False
156+
c = paddle.combinations(a, r=2, with_replacement=True)
157+
expected = convert_combinations_to_array(a.numpy(), r=2)
158+
np.testing.assert_allclose(c, expected)
159+
loss = c.sum().backward()
160+
expected = np.empty([0], dtype='float32')
161+
np.testing.assert_allclose(a.grad, expected)
149162

150163
# test empty input
151164
a = paddle.empty([random.randint(0, 8)])

0 commit comments

Comments
 (0)