Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions paddle/phi/kernels/impl/channel_shuffle_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ void ChannelShuffleGradKernel(const Context& dev_ctx,
auto* dout = &out_grad;
auto* dx = x_grad;
dev_ctx.template Alloc<T>(dx);
if (dx && dx->numel() == 0) {
return;
}
bool channel_last = (data_format == "NHWC");
const auto& do_dims = dout->dims();
const auto& dx_dims = dx->dims();
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/kernels/impl/channel_shuffle_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ void ChannelShuffleKernel(const Context& dev_ctx,
DenseTensor* out) {
auto* in = &x;
dev_ctx.template Alloc<T>(out);
if (out && out->numel() == 0) {
return;
}
bool channel_last = (data_format == "NHWC");
const auto& in_dims = in->dims();
const auto& o_dims = out->dims();
Expand Down
11 changes: 10 additions & 1 deletion test/legacy_test/test_channel_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ class TestChannelShuffleOp(OpTest):
def setUp(self):
self.op_type = "channel_shuffle"
self.init_dtype()
self.init_shape()
self.init_data_format()
n, c, h, w = 2, 9, 4, 4
n, c, h, w = self.shape
self.python_api = paddle.nn.functional.channel_shuffle

if self.format == "NCHW":
Expand All @@ -63,6 +64,9 @@ def setUp(self):
self.outputs = {'Out': npresult}
self.attrs = {'groups': groups, "data_format": self.format}

def init_shape(self):
self.shape = [2, 9, 4, 4]

def init_dtype(self):
self.dtype = 'float64'

Expand All @@ -81,6 +85,11 @@ def init_data_format(self):
self.format = "NHWC"


class TestChannelLast_ZeroSize(TestChannelShuffleOp):
def init_shape(self):
self.shape = [2, 9, 0, 4]


class TestChannelShuffleAPI(unittest.TestCase):
def setUp(self):
self.x_2_np = np.random.random([2, 4, 4, 9]).astype("float64")
Expand Down