Skip to content

Commit 543d026

Browse files
Merge pull request #270 from NKNaN/bce-fix2
[Accuracy diff No. 110, 111] Fix accuracy diff for paddle.nn.functional.binary_cross_entropy, paddle.nn.functional.binary_cross_entropy_with_logits API
2 parents 9ae3a15 + 2fbb461 commit 543d026

File tree

2 files changed

+15865
-2
lines changed

2 files changed

+15865
-2
lines changed

tester/accuracy.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,13 @@ def test(self):
383383
if self.api_config.api_name == "paddle.scale":
384384
paddle_out_grads = paddle_out_grads[0]
385385
torch_out_grads = torch_out_grads[0]
386-
386+
if self.api_config.api_name == "paddle.nn.functional.binary_cross_entropy":
387+
paddle_out_grads = paddle_out_grads[0]
388+
torch_out_grads = torch_out_grads[0]
389+
if self.api_config.api_name == "paddle.nn.functional.binary_cross_entropy_with_logits":
390+
paddle_out_grads = paddle_out_grads[0]
391+
torch_out_grads = torch_out_grads[0]
392+
387393
if isinstance(paddle_out_grads, paddle.Tensor):
388394
if isinstance(torch_out_grads, torch.Tensor):
389395
try:

0 commit comments

Comments
 (0)