File tree Expand file tree Collapse file tree 2 files changed +19
-6
lines changed
python/paddle/nn/functional Expand file tree Collapse file tree 2 files changed +19
-6
lines changed Original file line number Diff line number Diff line change 1414
1515from __future__ import annotations
1616
17- import functools
1817import math
19- import operator
2018from typing import TYPE_CHECKING , Literal , overload
2119
2220import paddle
@@ -110,10 +108,6 @@ def dice_loss(
110108 assert (
111109 input .shape [:- 1 ] == label .shape [:- 1 ]
112110 ), "All dimensions should be equal except the last one."
113- assert (
114- functools .reduce (operator .mul , input .shape ) != 0
115- and functools .reduce (operator .mul , label .shape ) != 0
116- ), "Any dimension of input and label cannot be equal to 0."
117111
118112 label = paddle .squeeze (label , [- 1 ])
119113 label = paddle .nn .functional .one_hot (label , input .shape [- 1 ])
Original file line number Diff line number Diff line change 1414
1515import unittest
1616
17+ import numpy as np
18+ from op_test import get_places
19+
20+ import paddle
21+
1722num_classes = 4
1823eps = 1e-6
1924
2025
26+ class TestDiceLossOpApi_ZeroSize (unittest .TestCase ):
27+ def test_api_with_dygraph (self ):
28+ for place in get_places ():
29+ paddle .disable_static (place )
30+ input = paddle .randn ([0 , 2 ]).astype (paddle .float64 )
31+ input .stop_gradient = False
32+ label = paddle .randn ([0 , 1 ]).astype (paddle .int64 )
33+ label .stop_gradient = False
34+ out = paddle .nn .functional .dice_loss (input , label , 1e-5 )
35+ np .testing .assert_allclose (out .numpy (), paddle .nan )
36+ out .sum ().backward ()
37+ np .testing .assert_allclose (input .grad .shape , input .shape )
38+
39+
2140if __name__ == "__main__" :
2241 unittest .main ()
You can’t perform that action at this time.
0 commit comments