diff --git a/paddle/phi/kernels/cpu/pad3d_grad_kernel.cc b/paddle/phi/kernels/cpu/pad3d_grad_kernel.cc index ecfec05dda25be..c6112fbca9bf37 100644 --- a/paddle/phi/kernels/cpu/pad3d_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/pad3d_grad_kernel.cc @@ -375,6 +375,7 @@ void Pad3dGradKernel(const Context& dev_ctx, auto d_out_dims = d_out->dims(); const T* d_out_data = d_out->data(); T* d_in_data = dev_ctx.template Alloc(d_in); + if (x.numel() == 0) return; phi::funcs::SetConstant()(dev_ctx, d_in, static_cast(0)); const int pad_left = static_cast(pads[0]); diff --git a/paddle/phi/kernels/cpu/pad3d_kernel.cc b/paddle/phi/kernels/cpu/pad3d_kernel.cc index f99a5582ecfbcc..cb247640484e91 100644 --- a/paddle/phi/kernels/cpu/pad3d_kernel.cc +++ b/paddle/phi/kernels/cpu/pad3d_kernel.cc @@ -17,6 +17,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/full_kernel.h" namespace phi { @@ -406,6 +407,11 @@ void Pad3dKernel(const Context& dev_ctx, auto out_dims = out->dims(); T* out_data = dev_ctx.template Alloc(out); + if (x.numel() == 0) { + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(out->dims())), pad_value, out); + return; + } int channels = static_cast(in_dims[1]); int in_depth = static_cast(in_dims[2]); diff --git a/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu b/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu index c3492c5560cdc3..b7494549024d79 100644 --- a/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu @@ -347,6 +347,7 @@ void Pad3dGradKernel(const Context& dev_ctx, auto d_out_dims = d_out->dims(); const T* d_out_data = d_out->data(); T* d_in_data = dev_ctx.template Alloc(d_in); + if (x.numel() == 0) return; phi::funcs::SetConstant()(dev_ctx, d_in, static_cast(0)); diff --git a/paddle/phi/kernels/gpu/pad3d_kernel.cu b/paddle/phi/kernels/gpu/pad3d_kernel.cu index 30cc5bcb4d2570..556548ada5c34f 100644 --- a/paddle/phi/kernels/gpu/pad3d_kernel.cu +++ b/paddle/phi/kernels/gpu/pad3d_kernel.cu @@ -20,7 +20,7 @@ #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" - +#include "paddle/phi/kernels/full_kernel.h" namespace phi { using phi::PADDLE_CUDA_NUM_THREADS; @@ -359,6 +359,11 @@ void Pad3dKernel(const Context& dev_ctx, } out->Resize(out_dims); T* out_data = dev_ctx.template Alloc(out); + if (x.numel() == 0) { + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(out->dims())), pad_value, out); + return; + } int64_t channels = in_dims[1]; int64_t in_depth = in_dims[2]; diff --git a/paddle/phi/kernels/impl/pad_grad_kernel_impl.h b/paddle/phi/kernels/impl/pad_grad_kernel_impl.h index 875045f8c8e8f0..eb6352aaea87e8 100644 --- a/paddle/phi/kernels/impl/pad_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/pad_grad_kernel_impl.h @@ -28,6 +28,7 @@ void PadGradKernel(const Context& dev_ctx, return; } dev_ctx.template Alloc(d_x); + if (d_x->numel() == 0) return; int rank = d_out.dims().size(); phi::funcs::PaddingGradFunctor( rank, dev_ctx, paddings, d_out, d_x); diff --git a/paddle/phi/kernels/impl/pad_kernel_impl.h b/paddle/phi/kernels/impl/pad_kernel_impl.h index 7737882c42fd92..217ff8fff74ff0 100644 --- a/paddle/phi/kernels/impl/pad_kernel_impl.h +++ b/paddle/phi/kernels/impl/pad_kernel_impl.h @@ -18,6 +18,7 @@ #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/padding.h" namespace phi { template @@ -27,6 +28,15 @@ void PadKernel(const Context& dev_ctx, const Scalar& pad_value, DenseTensor* out) { dev_ctx.template Alloc(out); + if (x.numel() == 0) { + if (out) { + phi::Full(dev_ctx, + phi::IntArray(common::vectorize(out->dims())), + pad_value, + out); + } + return; + } int rank = x.dims().size(); funcs::PaddingFunctor( rank, dev_ctx, paddings, pad_value.to(), x, out); diff --git a/paddle/phi/kernels/xpu/pad3d_grad_kernel.cc b/paddle/phi/kernels/xpu/pad3d_grad_kernel.cc index b42a3180dcf8a6..c0ec47b722fb98 100644 --- a/paddle/phi/kernels/xpu/pad3d_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/pad3d_grad_kernel.cc @@ -37,6 +37,7 @@ void Pad3dGradKernel(const Context& dev_ctx, auto d_in_dims = common::vectorize(d_in->dims()); const T* d_out_data = d_out->data(); T* d_in_data = dev_ctx.template Alloc(d_in); + if (x.numel() == 0) return; bool is_ncdhw = true; if (data_format == "NDHWC") { diff --git a/paddle/phi/kernels/xpu/pad3d_kernel.cc b/paddle/phi/kernels/xpu/pad3d_kernel.cc index a74d1018303c79..b01bfa974afded 100644 --- a/paddle/phi/kernels/xpu/pad3d_kernel.cc +++ b/paddle/phi/kernels/xpu/pad3d_kernel.cc @@ -17,6 +17,7 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/full_kernel.h" namespace phi { @@ -51,6 +52,11 @@ void Pad3dKernel(const Context& dev_ctx, } T* out_data = dev_ctx.template Alloc(out); + if (x.numel() == 0) { + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(out->dims())), pad_value, out); + return; + } const int64_t num = in_dims[0]; // n int64_t channels = in_dims[1]; // c diff --git a/paddle/phi/kernels/xpu/pad_grad_kernel.cc b/paddle/phi/kernels/xpu/pad_grad_kernel.cc index fb84a2f09440f2..2d7a0db907ed66 100644 --- a/paddle/phi/kernels/xpu/pad_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/pad_grad_kernel.cc @@ -29,6 +29,7 @@ void PadGradKernel(const Context& dev_ctx, std::vector pad_left, pad_right; std::vector out_shape = common::vectorize(d_out.dims()); dev_ctx.template Alloc(d_x); + if (d_x && d_x->numel() == 0) return; for (size_t i = 0; i < paddings.size() / 2; ++i) { pad_left.push_back(-paddings[i * 2]); @@ -58,6 +59,7 @@ void PadGradKernel, XPUContext>( std::vector pad_left, pad_right; std::vector out_shape = common::vectorize(d_out.dims()); dev_ctx.template Alloc(d_x); + if (d_x && d_x->numel() == 0) return; for (size_t i = 0; i < paddings.size() / 2; ++i) { pad_left.push_back(-paddings[i * 2]); diff --git a/paddle/phi/kernels/xpu/pad_kernel.cc b/paddle/phi/kernels/xpu/pad_kernel.cc index b7353d4d6743ae..eb86c0a05fc105 100644 --- a/paddle/phi/kernels/xpu/pad_kernel.cc +++ b/paddle/phi/kernels/xpu/pad_kernel.cc @@ -17,6 +17,7 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/complex_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" namespace phi { template @@ -27,6 +28,15 @@ void PadKernel(const Context& dev_ctx, DenseTensor* out) { using XPUType = typename XPUTypeTrait::Type; dev_ctx.template Alloc(out); + if (x.numel() == 0) { + if (out) { + phi::Full(dev_ctx, + phi::IntArray(common::vectorize(out->dims())), + pad_value, + out); + return; + } + } std::vector pad_left, pad_right; std::vector xshape = common::vectorize(x.dims()); diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 15b158763572bb..432be87fa457b0 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1959,7 +1959,9 @@ def pad( ], f"mode should be one of constant, reflect, replicate, circular, but got {mode}." x_dim = len(x.shape) - + if in_dynamic_mode(): + if isinstance(pad, (Variable, paddle.Tensor)) and pad.size == 0: + return x.clone() if ( mode == "constant" and isinstance(pad, (list, tuple)) diff --git a/test/legacy_test/test_pad3d_op.py b/test/legacy_test/test_pad3d_op.py index c6402fab176aca..46c3ab42ab99f3 100644 --- a/test/legacy_test/test_pad3d_op.py +++ b/test/legacy_test/test_pad3d_op.py @@ -1196,8 +1196,9 @@ def test_replicate_1(): self.assertRaises(Exception, test_reflect_1) self.assertRaises(Exception, test_reflect_2) self.assertRaises(Exception, test_reflect_3) - self.assertRaises(Exception, test_circular_1) - self.assertRaises(Exception, test_replicate_1) + # comment out because pad3d support 0-size now. + # self.assertRaises(Exception, test_circular_1) + # self.assertRaises(Exception, test_replicate_1) paddle.enable_static() diff --git a/test/legacy_test/test_pad_op.py b/test/legacy_test/test_pad_op.py index 1c000d7ac79bbd..05b9efe14b60f9 100644 --- a/test/legacy_test/test_pad_op.py +++ b/test/legacy_test/test_pad_op.py @@ -17,7 +17,7 @@ import unittest import numpy as np -from op_test import OpTest, convert_float_to_uint16 +from op_test import OpTest, convert_float_to_uint16, get_places sys.path.append("../deprecated/legacy_test") from test_attribute_var import UnittestBase @@ -561,6 +561,53 @@ def init_case(self): self.pad_value = 0.5 +class TestPadOp_ZeroSize(unittest.TestCase): + def init_case(self): + self.shape = [0, 16] + self.paddings = [(0, 1), (2, 3)] + self.paddings_empty_tensor = False + self.pad_value = 0.5 + + def test_dygraph(self): + self.init_case() + for place in get_places(): + paddle.disable_static(place) + x_np = np.random.random(self.shape).astype('float32') + paddings_np = self.paddings.copy() + x = paddle.to_tensor(x_np) + x.stop_gradient = False + paddings = list(np.array(self.paddings).flatten()) + if self.paddings_empty_tensor: + paddings = paddle.to_tensor(paddings) + # output the same as x + out_np = x_np + else: + out_np = np.pad( + x_np, + paddings_np, + mode="constant", + constant_values=self.pad_value, + ) + out = paddle.nn.functional.pad( + x, + paddings, + mode='constant', + value=self.pad_value, + pad_from_left_axis=True, + ) + np.testing.assert_array_equal(out, out_np) + out.sum().backward() + np.testing.assert_allclose(x.grad.numpy(), np.ones(self.shape)) + + +class TestPadOp_ZeroSize2(TestPadOp_ZeroSize): + def init_case(self): + self.shape = [4, 6, 6] + self.paddings = [] + self.paddings_empty_tensor = True + self.pad_value = 0.5 + + if __name__ == "__main__": # paddle.enable_static() unittest.main() diff --git a/test/xpu/test_pad3d_op_xpu.py b/test/xpu/test_pad3d_op_xpu.py index 4e98f554e12445..59dd708f063898 100644 --- a/test/xpu/test_pad3d_op_xpu.py +++ b/test/xpu/test_pad3d_op_xpu.py @@ -836,7 +836,8 @@ def test_replicate_1(): self.assertRaises(Exception, test_reflect_1) self.assertRaises(Exception, test_reflect_2) self.assertRaises(Exception, test_reflect_3) - self.assertRaises(Exception, test_replicate_1) + # comment out because pad3d support 0-size now. + # self.assertRaises(Exception, test_replicate_1) paddle.enable_static() class TestPadDataformatError(unittest.TestCase):