diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index f3392caa15595f..b9dc41a188993e 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -888,19 +888,23 @@ void ConvTransposeInferMeta(const MetaTensor& x,
                 common::make_ddim(output_size).to_str(),
                 i,
                 infer_shape));
-        PADDLE_ENFORCE_LT(
-            output_size[i],
-            infer_shape + strides[i],
-            errors::InvalidArgument(
-                "output_size of Op(ConvTransposeOp) should be less "
-                "than inferred size + stride. But received output_size = [%s], "
-                "whose dim %d is not less than the inferred output size (%d) + "
-                "stride (%d) = %d",
-                common::make_ddim(output_size).to_str(),
-                i,
-                infer_shape,
-                strides[i],
-                infer_shape + strides[i]));
+        if (common::product(x_dims) != 0) {
+          PADDLE_ENFORCE_LT(
+              output_size[i],
+              infer_shape + strides[i],
+              errors::InvalidArgument(
+                  "output_size of Op(ConvTransposeOp) should be less "
+                  "than inferred size + stride. But received output_size = "
+                  "[%s], "
+                  "whose dim %d is not less than the inferred output size (%d) "
+                  "+ "
+                  "stride (%d) = %d",
+                  common::make_ddim(output_size).to_str(),
+                  i,
+                  infer_shape,
+                  strides[i],
+                  infer_shape + strides[i]));
+        }
       }
       output_shape.push_back(output_size[i]);
     } else if (!output_padding.empty()) {
diff --git a/paddle/phi/kernels/gpu/conv_transpose_grad_kernel.cu b/paddle/phi/kernels/gpu/conv_transpose_grad_kernel.cu
index 04e97a6647417a..d968f293be3150 100644
--- a/paddle/phi/kernels/gpu/conv_transpose_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/conv_transpose_grad_kernel.cu
@@ -18,6 +18,7 @@
 #include "paddle/common/layout.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cpu/conv_util.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/gpu/depthwise_conv.h"
 #include "paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h"
@@ -77,7 +78,25 @@ void DepthwiseConv2dTransposeGradKernel(const Context& dev_ctx,
   if (!dx && !dfilter) {
     return;
   }
-
+  // 0-size
+  if (x.numel() == 0) {
+    if (dx) dev_ctx.template Alloc<T>(dx);
+    if (dfilter) {
+      phi::Full<T, Context>(dev_ctx,
+                            phi::IntArray(common::vectorize(dfilter->dims())),
+                            0,
+                            dfilter);
+    }
+    return;
+  }
+  if (filter.numel() == 0) {
+    if (dfilter) dev_ctx.template Alloc<T>(dfilter);
+    if (dx) {
+      phi::Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
+    }
+    return;
+  }
   std::vector<int> paddings_ = paddings;
   std::vector<int> dilations_ = dilations;
diff --git a/paddle/phi/kernels/gpu/conv_transpose_kernel.cu b/paddle/phi/kernels/gpu/conv_transpose_kernel.cu
index 04028343a063f6..1993490398dc74 100644
--- a/paddle/phi/kernels/gpu/conv_transpose_kernel.cu
+++ b/paddle/phi/kernels/gpu/conv_transpose_kernel.cu
@@ -37,6 +37,11 @@ void DepthwiseConv2dTransposeKernel(const Context& dev_ctx,
                                     const std::vector<int>& dilations,
                                     const std::string& data_format,
                                     DenseTensor* out) {
+  if (x.numel() == 0 || filter.numel() == 0) {
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
+    return;
+  }
   const DataLayout data_layout = common::StringToDataLayout(data_format);
   DenseTensor filter_ = filter;
   dev_ctx.template Alloc<T>(out);
diff --git a/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu
index ff89caebf110e5..25051226e59960 100644
--- a/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu
@@ -37,6 +37,7 @@ limitations under the License. */
 #include "paddle/phi/backends/gpu/cuda/cudnn_workspace_helper.h"
 #include "paddle/phi/kernels/gpudnn/conv_cudnn_v7.h"
 #endif
+#include "paddle/phi/kernels/full_kernel.h"

 namespace phi {

@@ -55,6 +56,26 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& dev_ctx,
                                       const std::string& data_format,
                                       DenseTensor* dx,
                                       DenseTensor* dfilter) {
+  // 0-size
+  if (x.numel() == 0) {
+    if (dx) dev_ctx.template Alloc<T>(dx);
+    if (dfilter) {
+      phi::Full<T, Context>(dev_ctx,
+                            phi::IntArray(common::vectorize(dfilter->dims())),
+                            0,
+                            dfilter);
+    }
+    return;
+  }
+  if (filter.numel() == 0) {
+    if (dfilter) dev_ctx.template Alloc<T>(dfilter);
+    if (dx) {
+      phi::Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
+    }
+    return;
+  }
+
   const T* filter_data = filter.data<T>();
   std::vector<int> paddings_ = paddings;
   std::vector<int> dilations_ =
diff --git a/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu b/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu
index 81b374bf26e153..1c6558c85cb47e 100644
--- a/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu
@@ -35,6 +35,7 @@ limitations under the License. */
 #include "paddle/phi/backends/gpu/cuda/cudnn_workspace_helper.h"
 #include "paddle/phi/kernels/gpudnn/conv_cudnn_v7.h"
 #endif
+#include "paddle/phi/kernels/full_kernel.h"

 namespace phi {

@@ -51,6 +52,11 @@ void ConvTransposeRawGPUDNNKernel(const Context& dev_ctx,
                                   const std::vector<int>& dilations,
                                   const std::string& data_format,
                                   DenseTensor* out) {
+  if (x.numel() == 0 || filter.numel() == 0) {
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
+    return;
+  }
   std::vector<int> paddings_ = paddings;
   std::vector<int> dilations_ =
       dilations;  // cudnn v5 does not support dilations
diff --git a/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h b/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h
index 2548093eef73e3..9a21c23666a95d 100644
--- a/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h
@@ -18,6 +18,7 @@
 #include "paddle/common/layout.h"
 #include "paddle/phi/kernels/conv_transpose_grad_kernel.h"
 #include "paddle/phi/kernels/cpu/conv_util.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 #include "paddle/phi/kernels/funcs/im2col.h"
@@ -48,6 +49,26 @@ void ConvTransposeGradRawKernel(const Context& dev_ctx,
     return;
   }

+  // 0-size
+  if (x.numel() == 0) {
+    if (dx) dev_ctx.template Alloc<T>(dx);
+    if (dfilter) {
+      phi::Full<T, Context>(dev_ctx,
+                            phi::IntArray(common::vectorize(dfilter->dims())),
+                            0,
+                            dfilter);
+    }
+    return;
+  }
+  if (filter.numel() == 0) {
+    if (dfilter) dev_ctx.template Alloc<T>(dfilter);
+    if (dx) {
+      phi::Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
+    }
+    return;
+  }
+
   std::vector<int> paddings_ = paddings;
   std::vector<int> dilations_ = dilations;
diff --git a/paddle/phi/kernels/impl/conv_transpose_kernel_impl.h b/paddle/phi/kernels/impl/conv_transpose_kernel_impl.h
index b99677a416e943..aadc5d2b8a0e5c 100644
--- a/paddle/phi/kernels/impl/conv_transpose_kernel_impl.h
+++ b/paddle/phi/kernels/impl/conv_transpose_kernel_impl.h
@@ -18,6 +18,7 @@
 #include "paddle/common/layout.h"
 #include "paddle/phi/kernels/conv_transpose_kernel.h"
 #include "paddle/phi/kernels/cpu/conv_util.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 #include "paddle/phi/kernels/funcs/im2col.h"
@@ -37,6 +38,11 @@ void ConvTransposeRawKernel(const Context& dev_ctx,
                             const std::vector<int>& dilations,
                             const std::string& data_format,
                             DenseTensor* out) {
+  if (x.numel() == 0 || filter.numel() == 0) {
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
+    return;
+  }
   const DataLayout data_layout = common::StringToDataLayout(data_format);
   // The filter will be reshaped, so it should not be constant
   DenseTensor filter_ = filter;
diff --git a/paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc b/paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc
index 2de1c3653179a6..92a24f53f89f96 100644
--- a/paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc
@@ -17,6 +17,7 @@
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cpu/conv_util.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"

 namespace phi {
@@ -40,7 +41,25 @@ void Conv2dTransposeGradKernel(const Context& dev_ctx,
   // that avoids modifying the variable in the Scope.
   DenseTensor filter_ = filter;
   if (!dx && !dfilter) return;
-
+  // 0-size
+  if (x.numel() == 0) {
+    if (dx) dev_ctx.template Alloc<T>(dx);
+    if (dfilter) {
+      phi::Full<T, Context>(dev_ctx,
+                            phi::IntArray(common::vectorize(dfilter->dims())),
+                            0,
+                            dfilter);
+    }
+    return;
+  }
+  if (filter.numel() == 0) {
+    if (dfilter) dev_ctx.template Alloc<T>(dfilter);
+    if (dx) {
+      phi::Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
+    }
+    return;
+  }
   std::vector<int> strides_ = std::vector<int>(strides.begin(), strides.end());
   std::vector<int> paddings_ =
diff --git a/paddle/phi/kernels/xpu/conv_transpose_kernel.cc b/paddle/phi/kernels/xpu/conv_transpose_kernel.cc
index acc1a2e6384241..c4b07af3e2b6dd 100644
--- a/paddle/phi/kernels/xpu/conv_transpose_kernel.cc
+++ b/paddle/phi/kernels/xpu/conv_transpose_kernel.cc
@@ -19,6 +19,7 @@
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cpu/conv_util.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/xpu/conv_utils_xpu.h"
 #include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
 #ifdef PADDLE_WITH_XPU_XRE5
@@ -41,7 +42,11 @@ void Conv2dTransposeKernel(const Context& dev_ctx,
                            const std::string& data_format,
                            DenseTensor* out) {
   using XPUType = typename XPUTypeTrait<T>::Type;
-
+  if (x.numel() == 0 || filter.numel() == 0) {
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
+    return;
+  }
   dev_ctx.template Alloc<T>(out);

   PADDLE_ENFORCE_EQ(
diff --git a/test/legacy_test/test_conv2d_transpose_op.py b/test/legacy_test/test_conv2d_transpose_op.py
index dc11464e6c949d..ff5d98240fa2b6 100644
--- a/test/legacy_test/test_conv2d_transpose_op.py
+++ b/test/legacy_test/test_conv2d_transpose_op.py
@@ -14,10 +14,12 @@
 import os
 import unittest
+from unittest import TestCase

 import numpy as np

 import paddle
+import paddle.base.dygraph as dg
 import paddle.static
 from paddle import nn

@@ -1519,5 +1521,62 @@ def call_func(self, x):
         return out


+class TestFunctionalConv2DTranspose_ZeroSize(TestCase):
+    def init_data(self):
+        self.input = np.random.randn(0, 4, 16, 4)
+        self.filter = np.random.randn(4, 3, 3, 3)
+        self.np_out = np.zeros([0, 3, 18, 6])
+
+    def setUp(self):
+        self.init_data()
+        self.bias = None
+        self.padding = 0
+        self.stride = 1
+        self.dilation = 1
+        self.groups = 1
+        self.data_format = "NCHW"
+        self.places = []
+        if (
+            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
+            in ['1', 'true', 'on']
+            or not base.core.is_compiled_with_cuda()
+        ):
+            self.places.append(base.CPUPlace())
+        if base.core.is_compiled_with_cuda():
+            self.places.append(base.CUDAPlace(0))
+
+    def test_dygraph(self):
+        for place in self.places:
+            with dg.guard(place):
+                input = paddle.to_tensor(self.input)
+                input.stop_gradient = False
+                filter = paddle.to_tensor(self.filter)
+                filter.stop_gradient = False
+                y = paddle.nn.functional.conv2d_transpose(
+                    input,
+                    filter,
+                    self.bias,
+                    padding=self.padding,
+                    stride=self.stride,
+                    dilation=self.dilation,
+                    groups=self.groups,
+                    data_format=self.data_format,
+                )
+                np.testing.assert_allclose(y.numpy(), self.np_out)
+                loss = y.sum()
+                loss.backward()
+                np.testing.assert_allclose(input.grad.shape, input.shape)
+                np.testing.assert_allclose(filter.grad, np.zeros(filter.shape))
+
+
+class TestFunctionalConv2DTranspose_ZeroSize2(
+    TestFunctionalConv2DTranspose_ZeroSize
+):
+    def init_data(self):
+        self.input = np.random.randn(4, 5, 3, 3)
+        self.filter = np.random.randn(5, 0, 4, 4)
+        self.np_out = np.zeros([4, 0, 6, 6])
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/legacy_test/test_functional_conv1d_transpose.py b/test/legacy_test/test_functional_conv1d_transpose.py
index 68e9eaf15e6106..3818c062b14ae6 100644
--- a/test/legacy_test/test_functional_conv1d_transpose.py
+++ b/test/legacy_test/test_functional_conv1d_transpose.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import unittest
 from unittest import TestCase

@@ -20,6 +21,7 @@
 import paddle
 import paddle.base.dygraph as dg
 import paddle.nn.functional as F
+from paddle import base


 class TestFunctionalConv1DError(TestCase):
@@ -30,7 +32,7 @@ def setUp(self):
         self.padding = 0
         self.stride = 1
         self.dilation = 1
-        self.groups = 1
+        self.groups = 0
         self.data_format = "NCL"

     def dygraph_case(self):
@@ -82,16 +84,61 @@ def setUp(self):
         self.data_format = "NCL"


-class TestFunctionalConv1DErrorCase3(TestFunctionalConv1DError):
+class TestFunctionalConv1DTranspose_ZeroSize(TestCase):
+    def init_data(self):
+        self.input = np.random.randn(0, 1, 2)
+        self.filter = np.random.randn(1, 1, 2)
+        self.np_out = np.zeros([0, 1, 3])
+
     def setUp(self):
-        self.input = np.random.randn(6, 0, 6)
-        self.filter = np.random.randn(6, 0, 0)
+        self.init_data()
         self.bias = None
         self.padding = 0
         self.stride = 1
         self.dilation = 1
         self.groups = 1
         self.data_format = "NCL"
+        self.places = []
+        if (
+            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
+            in ['1', 'true', 'on']
+            or not base.core.is_compiled_with_cuda()
+        ):
+            self.places.append(base.CPUPlace())
+        if base.core.is_compiled_with_cuda():
+            self.places.append(base.CUDAPlace(0))
+
+    def test_dygraph(self):
+        for place in self.places:
+            with dg.guard(place):
+                input = paddle.to_tensor(self.input)
+                input.stop_gradient = False
+                filter = paddle.to_tensor(self.filter)
+                filter.stop_gradient = False
+                y = F.conv1d_transpose(
+                    input,
+                    filter,
+                    self.bias,
+                    padding=self.padding,
+                    stride=self.stride,
+                    dilation=self.dilation,
+                    groups=self.groups,
+                    data_format=self.data_format,
+                )
+                np.testing.assert_allclose(y.numpy(), self.np_out)
+                loss = y.sum()
+                loss.backward()
+                np.testing.assert_allclose(input.grad.shape, input.shape)
+                np.testing.assert_allclose(filter.grad, np.zeros(filter.shape))
+
+
+class TestFunctionalConv1DTranspose_ZeroSize2(
+    TestFunctionalConv1DTranspose_ZeroSize
+):
+    def init_data(self):
+        self.input = np.random.randn(2, 3, 2)
+        self.filter = np.random.randn(3, 0, 3)
+        self.np_out = np.zeros([2, 0, 4])


 if __name__ == "__main__":
diff --git a/test/legacy_test/test_functional_conv3d_transpose.py b/test/legacy_test/test_functional_conv3d_transpose.py
index 53ec33cedec125..e51b3e0602c7b6 100644
--- a/test/legacy_test/test_functional_conv3d_transpose.py
+++ b/test/legacy_test/test_functional_conv3d_transpose.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import unittest
 from unittest import TestCase

@@ -223,7 +224,7 @@ def setUp(self):
         self.padding = 0
         self.stride = 1
         self.dilation = 1
-        self.groups = 1
+        self.groups = 0
         self.data_format = "NCDHW"

     def dygraph_case(self):
@@ -267,6 +268,63 @@ def setUp(self):
         self.data_format = "NCDHW"


+class TestFunctionalConv3DTranspose_ZeroSize(TestCase):
+    def init_data(self):
+        self.input = np.random.randn(0, 2, 2, 2, 3)
+        self.filter = np.random.randn(2, 1, 3, 3, 3)
+        self.np_out = np.zeros([0, 1, 4, 4, 5])
+
+    def setUp(self):
+        self.init_data()
+        self.bias = None
+        self.padding = 0
+        self.stride = 1
+        self.dilation = 1
+        self.groups = 1
+        self.data_format = "NCDHW"
+        self.places = []
+        if (
+            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
+            in ['1', 'true', 'on']
+            or not base.core.is_compiled_with_cuda()
+        ):
+            self.places.append(base.CPUPlace())
+        if base.core.is_compiled_with_cuda():
+            self.places.append(base.CUDAPlace(0))
+
+    def test_dygraph(self):
+        for place in self.places:
+            with dg.guard(place):
+                input = paddle.to_tensor(self.input)
+                input.stop_gradient = False
+                filter = paddle.to_tensor(self.filter)
+                filter.stop_gradient = False
+                y = F.conv3d_transpose(
+                    input,
+                    filter,
+                    self.bias,
+                    padding=self.padding,
+                    stride=self.stride,
+                    dilation=self.dilation,
+                    groups=self.groups,
+                    data_format=self.data_format,
+                )
+                np.testing.assert_allclose(y.numpy(), self.np_out)
+                loss = y.sum()
+                loss.backward()
+                np.testing.assert_allclose(input.grad.shape, input.shape)
+                np.testing.assert_allclose(filter.grad, np.zeros(filter.shape))
+
+
+class TestFunctionalConv3DTranspose_ZeroSize2(
+    TestFunctionalConv3DTranspose_ZeroSize
+):
+    def init_data(self):
+        self.input = np.random.randn(2, 3, 1, 1, 1)
+        self.filter = np.random.randn(3, 0, 3, 3, 3)
+        self.np_out = np.zeros([2, 0, 3, 3, 3])
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()
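For context, a minimal usage sketch of the behavior this patch covers, assuming a Paddle build that includes these kernels. The shapes mirror TestFunctionalConv2DTranspose_ZeroSize: a conv2d_transpose call on a 0-size input returns a zero-filled output of the inferred shape, and backward zero-fills the filter gradient instead of raising.

import numpy as np
import paddle
import paddle.nn.functional as F

# N == 0, so x.numel() == 0 and the kernels take the new 0-size early-return path.
x = paddle.to_tensor(
    np.random.randn(0, 4, 16, 4), dtype='float32', stop_gradient=False
)
w = paddle.to_tensor(
    np.random.randn(4, 3, 3, 3), dtype='float32', stop_gradient=False
)

y = F.conv2d_transpose(x, w, stride=1, padding=0)
print(y.shape)  # [0, 3, 18, 6], filled with zeros by the 0-size branch

y.sum().backward()
print(w.grad.numpy().sum())  # 0.0, dfilter is zero-filled
print(x.grad.shape)  # [0, 4, 16, 4], dx is allocated with the input's shape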