Commit 31b01e3

[0-size Tensor No.160, 162, 164] Add 0-size Tensor support for conv1d_transpose (#73698)

* Fix
* Fix

Parent commit: 1d60828

12 files changed: +291 −21 lines
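Across all backends the change follows one pattern: forward kernels zero-fill the (possibly empty) output and return early, and backward kernels allocate empty gradients and zero-fill non-empty ones. A minimal sketch of the user-visible behavior, assuming a Paddle build that includes this commit; the shapes mirror the new test added below:

import numpy as np
import paddle
import paddle.nn.functional as F

# Transposed conv over an empty batch: 0 elements in, 0 elements out.
x = paddle.to_tensor(np.random.randn(0, 4, 16, 4), stop_gradient=False)
w = paddle.to_tensor(np.random.randn(4, 3, 3, 3), stop_gradient=False)

y = F.conv2d_transpose(x, w, stride=1, padding=0)
print(y.shape)  # [0, 3, 18, 6] -- the batch dim stays 0

y.sum().backward()                      # backward now runs instead of erroring
print(x.grad.shape)                     # [0, 4, 16, 4], itself 0-size
print(np.allclose(w.grad.numpy(), 0))   # True: zero-filled by the new path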

paddle/phi/infermeta/binary.cc

Lines changed: 17 additions & 13 deletions
@@ -888,19 +888,23 @@ void ConvTransposeInferMeta(const MetaTensor& x,
               common::make_ddim(output_size).to_str(),
               i,
               infer_shape));
-      PADDLE_ENFORCE_LT(
-          output_size[i],
-          infer_shape + strides[i],
-          errors::InvalidArgument(
-              "output_size of Op(ConvTransposeOp) should be less "
-              "than inferred size + stride. But received output_size = [%s], "
-              "whose dim %d is not less than the inferred output size (%d) + "
-              "stride (%d) = %d",
-              common::make_ddim(output_size).to_str(),
-              i,
-              infer_shape,
-              strides[i],
-              infer_shape + strides[i]));
+      if (common::product(x_dims) != 0) {
+        PADDLE_ENFORCE_LT(
+            output_size[i],
+            infer_shape + strides[i],
+            errors::InvalidArgument(
+                "output_size of Op(ConvTransposeOp) should be less "
+                "than inferred size + stride. But received output_size = "
+                "[%s], "
+                "whose dim %d is not less than the inferred output size (%d) "
+                "+ "
+                "stride (%d) = %d",
+                common::make_ddim(output_size).to_str(),
+                i,
+                infer_shape,
+                strides[i],
+                infer_shape + strides[i]));
+      }
     }
     output_shape.push_back(output_size[i]);
   } else if (!output_padding.empty()) {
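The guard matters because the strict upper-bound check compares output_size[i] against an inferred size that degenerates once any input dimension is 0 (common::product(x_dims) == 0). A small Python sketch of the usual no-output_padding inference formula; the helper name is hypothetical and the in-tree computation also handles output_padding and asymmetric padding:

def inferred_size(in_size, stride, pad, dilation, k):
    # Standard transposed-conv output size, without output_padding.
    return (in_size - 1) * stride - 2 * pad + dilation * (k - 1) + 1

print(inferred_size(16, 1, 0, 1, 3))  # 18, the H used in the new test
print(inferred_size(0, 1, 0, 1, 3))   # 2: meaningless for a 0-size dim,
                                      # so the LT check is skipped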

paddle/phi/kernels/gpu/conv_transpose_grad_kernel.cu

Lines changed: 20 additions & 1 deletion
@@ -18,6 +18,7 @@
 #include "paddle/common/layout.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cpu/conv_util.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/gpu/depthwise_conv.h"
 #include "paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h"
@@ -77,7 +78,25 @@ void DepthwiseConv2dTransposeGradKernel(const Context& dev_ctx,
   if (!dx && !dfilter) {
     return;
   }
-
+  // 0-size: an empty x or filter contributes nothing to either gradient.
+  if (x.numel() == 0) {
+    if (dx) dev_ctx.template Alloc<T>(dx);
+    if (dfilter) {
+      phi::Full<T, Context>(dev_ctx,
+                            phi::IntArray(common::vectorize(dfilter->dims())),
+                            0,
+                            dfilter);
+    }
+    return;
+  }
+  if (filter.numel() == 0) {
+    if (dfilter) dev_ctx.template Alloc<T>(dfilter);
+    if (dx) {
+      phi::Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
+    }
+    return;
+  }
   std::vector<int> paddings_ = paddings;
   std::vector<int> dilations_ = dilations;
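Why the asymmetry between the two branches: the filter gradient is a reduction over input elements, and a reduction over an empty set is exactly zero, so dfilter must be written with zeros; a gradient for a 0-size tensor has no elements, so a bare allocation suffices. A NumPy illustration of the empty reduction (shapes are illustrative only):

import numpy as np

# Per-sample contributions to the filter gradient, with an empty batch:
contrib = np.random.randn(0, 4, 3, 3)
dfilter = contrib.sum(axis=0)        # reduction over nothing
print(dfilter.shape, dfilter.any())  # (4, 3, 3) False -- all zeros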

paddle/phi/kernels/gpu/conv_transpose_kernel.cu

Lines changed: 5 additions & 0 deletions
@@ -37,6 +37,11 @@ void DepthwiseConv2dTransposeKernel(const Context& dev_ctx,
                                     const std::vector<int>& dilations,
                                     const std::string& data_format,
                                     DenseTensor* out) {
+  if (x.numel() == 0 || filter.numel() == 0) {
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
+    return;
+  }
   const DataLayout data_layout = common::StringToDataLayout(data_format);
   DenseTensor filter_ = filter;
   dev_ctx.template Alloc<T>(out);
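Zero-filling out, rather than only allocating it, is the mathematically consistent result: with no input elements, every output position accumulates an empty sum. The same effect in NumPy, via the matmul a conv lowers to after im2col (an analogy, not this kernel's actual code path):

import numpy as np

a = np.random.randn(6, 0)   # no contributing input elements
b = np.random.randn(0, 4)
print(a @ b)                # a (6, 4) array of zeros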

paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu

Lines changed: 21 additions & 0 deletions
@@ -37,6 +37,7 @@ limitations under the License. */
 #include "paddle/phi/backends/gpu/cuda/cudnn_workspace_helper.h"
 #include "paddle/phi/kernels/gpudnn/conv_cudnn_v7.h"
 #endif
+#include "paddle/phi/kernels/full_kernel.h"

 namespace phi {

@@ -55,6 +56,26 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& dev_ctx,
                                       const std::string& data_format,
                                       DenseTensor* dx,
                                       DenseTensor* dfilter) {
+  // 0-size: an empty x or filter contributes nothing to either gradient.
+  if (x.numel() == 0) {
+    if (dx) dev_ctx.template Alloc<T>(dx);
+    if (dfilter) {
+      phi::Full<T, Context>(dev_ctx,
+                            phi::IntArray(common::vectorize(dfilter->dims())),
+                            0,
+                            dfilter);
+    }
+    return;
+  }
+  if (filter.numel() == 0) {
+    if (dfilter) dev_ctx.template Alloc<T>(dfilter);
+    if (dx) {
+      phi::Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
+    }
+    return;
+  }
+
   const T* filter_data = filter.data<T>();
   std::vector<int> paddings_ = paddings;
   std::vector<int> dilations_ =

paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu

Lines changed: 6 additions & 0 deletions
@@ -35,6 +35,7 @@ limitations under the License. */
 #include "paddle/phi/backends/gpu/cuda/cudnn_workspace_helper.h"
 #include "paddle/phi/kernels/gpudnn/conv_cudnn_v7.h"
 #endif
+#include "paddle/phi/kernels/full_kernel.h"

 namespace phi {

@@ -51,6 +52,11 @@ void ConvTransposeRawGPUDNNKernel(const Context& dev_ctx,
                                   const std::vector<int>& dilations,
                                   const std::string& data_format,
                                   DenseTensor* out) {
+  if (x.numel() == 0 || filter.numel() == 0) {
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
+    return;
+  }
   std::vector<int> paddings_ = paddings;
   std::vector<int> dilations_ =
       dilations;  // cudnn v5 does not support dilations

paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h

Lines changed: 21 additions & 0 deletions
@@ -18,6 +18,7 @@
 #include "paddle/common/layout.h"
 #include "paddle/phi/kernels/conv_transpose_grad_kernel.h"
 #include "paddle/phi/kernels/cpu/conv_util.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 #include "paddle/phi/kernels/funcs/im2col.h"
@@ -48,6 +49,26 @@ void ConvTransposeGradRawKernel(const Context& dev_ctx,
     return;
   }

+  // 0-size: an empty x or filter contributes nothing to either gradient.
+  if (x.numel() == 0) {
+    if (dx) dev_ctx.template Alloc<T>(dx);
+    if (dfilter) {
+      phi::Full<T, Context>(dev_ctx,
+                            phi::IntArray(common::vectorize(dfilter->dims())),
+                            0,
+                            dfilter);
+    }
+    return;
+  }
+  if (filter.numel() == 0) {
+    if (dfilter) dev_ctx.template Alloc<T>(dfilter);
+    if (dx) {
+      phi::Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
+    }
+    return;
+  }
+
   std::vector<int> paddings_ = paddings;
   std::vector<int> dilations_ = dilations;

paddle/phi/kernels/impl/conv_transpose_kernel_impl.h

Lines changed: 6 additions & 0 deletions
@@ -18,6 +18,7 @@
 #include "paddle/common/layout.h"
 #include "paddle/phi/kernels/conv_transpose_kernel.h"
 #include "paddle/phi/kernels/cpu/conv_util.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 #include "paddle/phi/kernels/funcs/im2col.h"
@@ -37,6 +38,11 @@ void ConvTransposeRawKernel(const Context& dev_ctx,
                             const std::vector<int>& dilations,
                             const std::string& data_format,
                             DenseTensor* out) {
+  if (x.numel() == 0 || filter.numel() == 0) {
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
+    return;
+  }
   const DataLayout data_layout = common::StringToDataLayout(data_format);
   // The filter will be reshaped, so it should not be constant
   DenseTensor filter_ = filter;

paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc

Lines changed: 20 additions & 1 deletion
@@ -17,6 +17,7 @@
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cpu/conv_util.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"

 namespace phi {
@@ -40,7 +41,25 @@ void Conv2dTransposeGradKernel(const Context& dev_ctx,
   // that avoids modifying the variable in the Scope.
   DenseTensor filter_ = filter;
   if (!dx && !dfilter) return;
-
+  // 0-size: an empty x or filter contributes nothing to either gradient.
+  if (x.numel() == 0) {
+    if (dx) dev_ctx.template Alloc<T>(dx);
+    if (dfilter) {
+      phi::Full<T, Context>(dev_ctx,
+                            phi::IntArray(common::vectorize(dfilter->dims())),
+                            0,
+                            dfilter);
+    }
+    return;
+  }
+  if (filter.numel() == 0) {
+    if (dfilter) dev_ctx.template Alloc<T>(dfilter);
+    if (dx) {
+      phi::Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
+    }
+    return;
+  }
   std::vector<int64_t> strides_ =
       std::vector<int64_t>(strides.begin(), strides.end());
   std::vector<int64_t> paddings_ =

paddle/phi/kernels/xpu/conv_transpose_kernel.cc

Lines changed: 6 additions & 1 deletion
@@ -19,6 +19,7 @@
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cpu/conv_util.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/xpu/conv_utils_xpu.h"
 #include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
 #ifdef PADDLE_WITH_XPU_XRE5
@@ -41,7 +42,11 @@ void Conv2dTransposeKernel(const Context& dev_ctx,
                            const std::string& data_format,
                            DenseTensor* out) {
   using XPUType = typename XPUTypeTrait<T>::Type;
-
+  if (x.numel() == 0 || filter.numel() == 0) {
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
+    return;
+  }
   dev_ctx.template Alloc<T>(out);

   PADDLE_ENFORCE_EQ(

test/legacy_test/test_conv2d_transpose_op.py

Lines changed: 59 additions & 0 deletions
@@ -14,10 +14,12 @@

 import os
 import unittest
+from unittest import TestCase

 import numpy as np

 import paddle
+import paddle.base.dygraph as dg
 import paddle.static
 from paddle import nn

@@ -1519,5 +1521,62 @@ def call_func(self, x):
         return out


+class TestFunctionalConv2DTranspose_ZeroSize(TestCase):
+    def init_data(self):
+        self.input = np.random.randn(0, 4, 16, 4)
+        self.filter = np.random.randn(4, 3, 3, 3)
+        self.np_out = np.zeros([0, 3, 18, 6])
+
+    def setUp(self):
+        self.init_data()
+        self.bias = None
+        self.padding = 0
+        self.stride = 1
+        self.dilation = 1
+        self.groups = 1
+        self.data_format = "NCHW"
+        self.places = []
+        if (
+            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
+            in ['1', 'true', 'on']
+            or not base.core.is_compiled_with_cuda()
+        ):
+            self.places.append(base.CPUPlace())
+        if base.core.is_compiled_with_cuda():
+            self.places.append(base.CUDAPlace(0))
+
+    def test_dygraph(self):
+        for place in self.places:
+            with dg.guard(place):
+                input = paddle.to_tensor(self.input)
+                input.stop_gradient = False
+                filter = paddle.to_tensor(self.filter)
+                filter.stop_gradient = False
+                y = paddle.nn.functional.conv2d_transpose(
+                    input,
+                    filter,
+                    self.bias,
+                    padding=self.padding,
+                    stride=self.stride,
+                    dilation=self.dilation,
+                    groups=self.groups,
+                    data_format=self.data_format,
+                )
+                np.testing.assert_allclose(y.numpy(), self.np_out)
+                loss = y.sum()
+                loss.backward()
+                np.testing.assert_allclose(input.grad.shape, input.shape)
+                np.testing.assert_allclose(filter.grad, np.zeros(filter.shape))
+
+
+class TestFunctionalConv2DTranspose_ZeroSize2(
+    TestFunctionalConv2DTranspose_ZeroSize
+):
+    def init_data(self):
+        self.input = np.random.randn(4, 5, 3, 3)
+        self.filter = np.random.randn(5, 0, 4, 4)
+        self.np_out = np.zeros([4, 0, 6, 6])
+
+
 if __name__ == '__main__':
     unittest.main()
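To exercise the new zero-size cases against a local build (assuming a Paddle build that includes this commit), the module runs as a plain unittest:

python test/legacy_test/test_conv2d_transpose_op.py -v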
