Skip to content

Commit abfc92d

Browse files
committed
Merge branch 'b50' into b51
2 parents dc31a7a + a05d7f8 commit abfc92d

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

test/legacy_test/test_logaddexp.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,5 +79,37 @@ def test_api_int64(self):
7979
self.api_case()
8080

8181

82+
class TestLogsumexpAPI_ZeroSize(unittest.TestCase):
83+
def setUp(self):
84+
self.place = (
85+
paddle.CUDAPlace(0)
86+
if paddle.base.core.is_compiled_with_cuda()
87+
else paddle.CPUPlace()
88+
)
89+
90+
def api_case(self):
91+
self.x = np.random.uniform(-1, 1, self.xshape).astype(self.dtype)
92+
self.y = np.random.uniform(-1, 1, self.yshape).astype(self.dtype)
93+
out_ref = ref_logaddexp(self.x, self.y)
94+
95+
paddle.disable_static(self.place)
96+
x = paddle.to_tensor(self.x)
97+
y = paddle.to_tensor(self.y)
98+
x.stop_gradient = False
99+
y.stop_gradient = False
100+
out = paddle.logaddexp(x, y)
101+
np.testing.assert_allclose(out.numpy(), out_ref, atol=1e-06)
102+
103+
loss = paddle.sum(out)
104+
loss.backward()
105+
np.testing.assert_allclose(x.grad.shape, x.shape)
106+
107+
def test_api(self):
108+
self.xshape = [1, 2, 3, 0]
109+
self.yshape = [1, 2, 3, 1]
110+
self.dtype = np.float32
111+
self.api_case()
112+
113+
82114
if __name__ == '__main__':
83115
unittest.main()

0 commit comments

Comments
 (0)