diff --git a/paddle/phi/kernels/impl/channel_shuffle_grad_kernel_impl.h b/paddle/phi/kernels/impl/channel_shuffle_grad_kernel_impl.h index b8406b91431033..98a1875b6442cf 100644 --- a/paddle/phi/kernels/impl/channel_shuffle_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/channel_shuffle_grad_kernel_impl.h @@ -31,6 +31,9 @@ void ChannelShuffleGradKernel(const Context& dev_ctx, auto* dout = &out_grad; auto* dx = x_grad; dev_ctx.template Alloc(dx); + if (dx && dx->numel() == 0) { + return; + } bool channel_last = (data_format == "NHWC"); const auto& do_dims = dout->dims(); const auto& dx_dims = dx->dims(); diff --git a/paddle/phi/kernels/impl/channel_shuffle_kernel_impl.h b/paddle/phi/kernels/impl/channel_shuffle_kernel_impl.h index 7e31e02851591e..20b76d12ccebbf 100644 --- a/paddle/phi/kernels/impl/channel_shuffle_kernel_impl.h +++ b/paddle/phi/kernels/impl/channel_shuffle_kernel_impl.h @@ -30,6 +30,9 @@ void ChannelShuffleKernel(const Context& dev_ctx, DenseTensor* out) { auto* in = &x; dev_ctx.template Alloc(out); + if (out && out->numel() == 0) { + return; + } bool channel_last = (data_format == "NHWC"); const auto& in_dims = in->dims(); const auto& o_dims = out->dims(); diff --git a/test/legacy_test/test_channel_shuffle.py b/test/legacy_test/test_channel_shuffle.py index 3b480e34eee985..eee123e9d507cb 100644 --- a/test/legacy_test/test_channel_shuffle.py +++ b/test/legacy_test/test_channel_shuffle.py @@ -45,8 +45,9 @@ class TestChannelShuffleOp(OpTest): def setUp(self): self.op_type = "channel_shuffle" self.init_dtype() + self.init_shape() self.init_data_format() - n, c, h, w = 2, 9, 4, 4 + n, c, h, w = self.shape self.python_api = paddle.nn.functional.channel_shuffle if self.format == "NCHW": @@ -63,6 +64,9 @@ def setUp(self): self.outputs = {'Out': npresult} self.attrs = {'groups': groups, "data_format": self.format} + def init_shape(self): + self.shape = [2, 9, 4, 4] + def init_dtype(self): self.dtype = 'float64' @@ -81,6 +85,11 @@ def init_data_format(self): self.format = "NHWC" +class TestChannelLast_ZeroSize(TestChannelShuffleOp): + def init_shape(self): + self.shape = [2, 9, 0, 4] + + class TestChannelShuffleAPI(unittest.TestCase): def setUp(self): self.x_2_np = np.random.random([2, 4, 4, 9]).astype("float64")