Skip to content

Commit 1d3518f

Browse files
authored
[0-size Tensor Job2 No.56] Add 0-size Tensor support for paddle.nn.functional.dice_loss[fluid_ops]
1 parent 0a23433 commit 1d3518f

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

python/paddle/nn/functional/loss.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414

1515
from __future__ import annotations
1616

17-
import functools
1817
import math
19-
import operator
2018
from typing import TYPE_CHECKING, Literal, overload
2119

2220
import paddle
@@ -110,10 +108,6 @@ def dice_loss(
110108
assert (
111109
input.shape[:-1] == label.shape[:-1]
112110
), "All dimensions should be equal except the last one."
113-
assert (
114-
functools.reduce(operator.mul, input.shape) != 0
115-
and functools.reduce(operator.mul, label.shape) != 0
116-
), "Any dimension of input and label cannot be equal to 0."
117111

118112
label = paddle.squeeze(label, [-1])
119113
label = paddle.nn.functional.one_hot(label, input.shape[-1])

test/legacy_test/test_nn_dice_loss.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,28 @@
1414

1515
import unittest
1616

17+
import numpy as np
18+
from op_test import get_places
19+
20+
import paddle
21+
1722
num_classes = 4
1823
eps = 1e-6
1924

2025

26+
class TestDiceLossOpApi_ZeroSize(unittest.TestCase):
27+
def test_api_with_dygraph(self):
28+
for place in get_places():
29+
paddle.disable_static(place)
30+
input = paddle.randn([0, 2]).astype(paddle.float64)
31+
input.stop_gradient = False
32+
label = paddle.randn([0, 1]).astype(paddle.int64)
33+
label.stop_gradient = False
34+
out = paddle.nn.functional.dice_loss(input, label, 1e-5)
35+
np.testing.assert_allclose(out.numpy(), paddle.nan)
36+
out.sum().backward()
37+
np.testing.assert_allclose(input.grad.shape, input.shape)
38+
39+
2140
if __name__ == "__main__":
2241
unittest.main()

0 commit comments

Comments
 (0)