diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index f294cc376c21ec..1aab1bd9507ae2 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -14,9 +14,7 @@ from __future__ import annotations -import functools import math -import operator from typing import TYPE_CHECKING, Literal, overload import paddle @@ -110,10 +108,6 @@ def dice_loss( assert ( input.shape[:-1] == label.shape[:-1] ), "All dimensions should be equal except the last one." - assert ( - functools.reduce(operator.mul, input.shape) != 0 - and functools.reduce(operator.mul, label.shape) != 0 - ), "Any dimension of input and label cannot be equal to 0." label = paddle.squeeze(label, [-1]) label = paddle.nn.functional.one_hot(label, input.shape[-1]) diff --git a/test/legacy_test/test_nn_dice_loss.py b/test/legacy_test/test_nn_dice_loss.py index 8734bc05e749dd..317195e45423d3 100644 --- a/test/legacy_test/test_nn_dice_loss.py +++ b/test/legacy_test/test_nn_dice_loss.py @@ -14,9 +14,28 @@ import unittest +import numpy as np +from op_test import get_places + +import paddle + num_classes = 4 eps = 1e-6 +class TestDiceLossOpApi_ZeroSize(unittest.TestCase): + def test_api_with_dygraph(self): + for place in get_places(): + paddle.disable_static(place) + input = paddle.randn([0, 2]).astype(paddle.float64) + input.stop_gradient = False + label = paddle.randn([0, 1]).astype(paddle.int64) + label.stop_gradient = False + out = paddle.nn.functional.dice_loss(input, label, 1e-5) + np.testing.assert_allclose(out.numpy(), paddle.nan) + out.sum().backward() + np.testing.assert_allclose(input.grad.shape, input.shape) + + if __name__ == "__main__": unittest.main()