5 changes: 5 additions & 0 deletions paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h
@@ -48,6 +48,11 @@ void KLDivLossGradKernel(const Context& dev_ctx,
                          const std::string& reduction,
                          bool log_target,
                          DenseTensor* d_x) {
  if (d_x->numel() == 0) {
    dev_ctx.template Alloc<T>(d_x);
    return;
  }

  auto& place = *dev_ctx.eigen_device();
  auto* target = &label;
  auto* input_grad = d_x;
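For context, a minimal dygraph sketch (illustrative only, not part of the diff) of the behaviour this early return supports at the Python level, assuming the public paddle.nn.functional.kl_div API: backpropagating through a zero-size input should produce a gradient with the same zero-size shape.

import numpy as np
import paddle

# Zero-size input: the leading dimension is 0, so numel() == 0.
x = paddle.to_tensor(np.zeros((0, 5), dtype="float64"), stop_gradient=False)
target = paddle.to_tensor(np.zeros((0, 5), dtype="float64"))

# 'sum' over a zero-size tensor is 0, so backward() is well defined; the grad
# kernel only allocates d_x, and the zero-size shape is preserved.
loss = paddle.nn.functional.kl_div(x, target, reduction="sum")
loss.backward()
print(x.grad.shape)  # expected: [0, 5]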
7 changes: 6 additions & 1 deletion paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h
@@ -18,8 +18,8 @@
#include "paddle/common/hostdevice.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"

namespace phi {
using Array1 = Eigen::DSizes<int64_t, 1>;
template <typename T>
@@ -48,6 +48,11 @@ void KLDivLossKernel(const Context& dev_ctx,
                      const std::string& reduction,
                      bool log_target,
                      DenseTensor* out) {
  if (x.numel() == 0) {
    phi::Full<T, Context>(
        dev_ctx, phi::IntArray(common::vectorize(out->dims())), NAN, out);
    return;
  }
  auto& place = *(dev_ctx.eigen_device());
  auto* input = &x;
  auto* target = &label;
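For context, a hedged end-to-end sketch of the behaviour the NaN fill above is meant to guarantee (via the public paddle.nn.functional.kl_div; depending on the reduction path, the NaN may equivalently come from the Python-side mean over zero elements):

import numpy as np
import paddle

x = paddle.to_tensor(np.zeros((0, 2, 7, 7), dtype="float64"))
target = paddle.to_tensor(np.zeros((0, 2, 7, 7), dtype="float64"))

out = paddle.nn.functional.kl_div(x, target, reduction="mean")
print(out)        # expected: a 0-D tensor holding NaN (mean over zero elements)

out = paddle.nn.functional.kl_div(x, target, reduction="none")
print(out.shape)  # expected: [0, 2, 7, 7], the zero-size shape is preserved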
7 changes: 6 additions & 1 deletion paddle/phi/kernels/xpu/kldiv_loss_kernel.cc
@@ -15,8 +15,8 @@ limitations under the License. */
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/softmax_kernel.h"

namespace phi {

template <typename T, typename Context>
@@ -31,6 +31,11 @@ void KLDivLossKernel(const Context& dev_ctx,
  if (out->numel() == 0) {
    return;
  }
  if (x.numel() == 0) {
    phi::Full<T, Context>(
        dev_ctx, phi::IntArray(common::vectorize(out->dims())), NAN, out);
    return;
  }

  int r = 0;

39 changes: 39 additions & 0 deletions test/legacy_test/test_hinge_embedding_loss.py
@@ -201,5 +201,44 @@ def test_value_error():
        self.assertRaises(ValueError, test_value_error)


class TestFunctionalHingeEmbeddingLoss_ZeroSize(unittest.TestCase):
    def setUp(self):
        self.margin = 1.0
        self.shape = (0, 10, 5)  # zero size
        self.input_np = np.random.random(size=self.shape).astype(np.float64)
        self.label_np = 2 * np.random.randint(0, 2, size=self.shape) - 1.0

    def run_dynamic_check(self, place=paddle.CPUPlace()):
        paddle.disable_static(place=place)
        input = paddle.to_tensor(self.input_np)
        input.stop_gradient = False
        label = paddle.to_tensor(self.label_np, dtype="float64")

        dy_result = paddle.nn.functional.hinge_embedding_loss(input, label)
        expected = calc_hinge_embedding_loss(self.input_np, self.label_np)
        np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
        self.assertEqual(dy_result.shape, [])

        dy_result = paddle.nn.functional.hinge_embedding_loss(
            input, label, reduction='none'
        )
        expected = calc_hinge_embedding_loss(
            self.input_np, self.label_np, reduction='none'
        )
        np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)

        loss = paddle.sum(dy_result)
        loss.backward()
        self.assertEqual(input.grad.shape, input.shape)

    def test_cpu(self):
        self.run_dynamic_check(place=paddle.CPUPlace())

    def test_gpu(self):
        if not paddle.is_compiled_with_cuda():
            return
        self.run_dynamic_check(place=paddle.CUDAPlace(0))


if __name__ == "__main__":
    unittest.main()
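For reference, a NumPy sketch of the element-wise hinge embedding loss that the calc_hinge_embedding_loss helper used above presumably computes (an assumption based on the documented formula, with the default margin=1.0); on zero-size inputs the 'mean' reduction degenerates to a mean over zero elements:

import numpy as np

def hinge_embedding_ref(x, y, margin=1.0, reduction='mean'):
    # loss_i = x_i                    where y_i == 1
    # loss_i = max(0, margin - x_i)   where y_i == -1
    loss = np.where(y == 1.0, x, np.maximum(0.0, margin - x))
    if reduction == 'mean':
        return np.mean(loss)  # NaN for zero-size inputs
    if reduction == 'sum':
        return np.sum(loss)   # 0.0 for zero-size inputs
    return loss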
46 changes: 46 additions & 0 deletions test/legacy_test/test_kldiv_loss_op.py
@@ -115,6 +115,52 @@ def initTestCase(self):
        self.log_target = True


class TestKLDivLossOp_ZeroSize1(TestKLDivLossOp):
    def setUp(self):
        self.initTestCase()
        self.op_type = 'kldiv_loss'
        self.python_api = kl_div
        self.public_python_api = paddle.nn.functional.kl_div
        x = np.random.uniform(-10, 10, self.x_shape).astype('float64')
        target = np.random.uniform(-10, 10, self.x_shape).astype('float64')

        self.attrs = {
            "reduction": self.reduction,
            "log_target": self.log_target,
        }

        self.inputs = {
            'X': x,
            'Target': target,
        }
        loss = kldiv_loss(x, target, self.reduction, self.log_target)
        self.outputs = {'Loss': loss.astype('float64')}

    def initTestCase(self):
        # return NAN
        self.x_shape = (0, 2, 7, 7)
        self.reduction = 'mean'
        self.log_target = False

    def test_check_output(self):
        self.check_output(check_pir=True, equal_nan=True)

    def test_check_grad(self):
        self.check_grad(
            ['X'],
            'Loss',
            no_grad_set={"Target"},
            check_pir=True,
        )


class TestKLDivLossOp_ZeroSize2(TestKLDivLossOp_ZeroSize1):
    def initTestCase(self):
        self.x_shape = (0, 2, 7, 7)
        self.reduction = 'none'
        self.log_target = False


class TestKLDivLossDygraph(unittest.TestCase):
    def run_kl_loss(self, reduction, shape=(5, 20), log_target=False):
        x = np.random.uniform(-10, 10, shape).astype('float64')
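The equal_nan=True flag in test_check_output above matters because the reference loss itself is NaN when a zero-size input is reduced with 'mean'; a quick NumPy check (illustrative only):

import numpy as np

loss = np.random.uniform(-10, 10, (0, 2, 7, 7))  # stand-in for the element-wise loss
print(np.mean(loss))  # nan, plus a RuntimeWarning about an empty slice
print(loss.size)      # 0, so reduction='none' just returns an empty tensor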
38 changes: 38 additions & 0 deletions test/legacy_test/test_l1_loss.py
@@ -201,6 +201,44 @@ def test_value_error():
        self.assertRaises(ValueError, test_value_error)


class TestClassL1Loss_ZeroSize(unittest.TestCase):
    def setUp(self):
        self.input_np = np.random.random(size=(0, 10, 5)).astype(np.float32)
        self.label_np = np.random.random(size=(0, 10, 5)).astype(np.float32)

    def run_imperative(self):
        input = paddle.to_tensor(self.input_np)
        label = paddle.to_tensor(self.label_np)
        input.stop_gradient = False
        l1_loss = paddle.nn.loss.L1Loss()
        dy_result = l1_loss(input, label)
        expected = np.mean(np.abs(self.input_np - self.label_np))
        np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
        self.assertEqual(dy_result.shape, [])

        l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
        dy_result = l1_loss(input, label)
        expected = np.sum(np.abs(self.input_np - self.label_np))
        np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
        self.assertEqual(dy_result.shape, [])

        loss = paddle.sum(dy_result)
        loss.backward()
        np.testing.assert_allclose(input.grad.shape, input.shape)

    def test_cpu(self):
        paddle.disable_static(place=paddle.base.CPUPlace())
        self.run_imperative()
        paddle.enable_static()

    def test_gpu(self):
        if not base.core.is_compiled_with_cuda():
            return
        paddle.disable_static(place=paddle.base.CUDAPlace(0))
        self.run_imperative()
        paddle.enable_static()


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()