Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions paddle/phi/kernels/onednn/reshape_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ limitations under the License. */

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"

namespace phi {

Expand All @@ -19,6 +20,12 @@ void ReshapeGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
if ((x_grad && x_grad->numel() == 0) || out_grad.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
return;
}

auto out_grad_vec_dims = out_grad.dims().size() != 0
? common::vectorize(out_grad.dims())
: std::vector<int64_t>{1};
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/kernels/onednn/reshape_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ void ReshapeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& shape,
DenseTensor* out) {
if (x.numel() == 0) {
dev_ctx.Alloc(out, x.dtype());
}
auto x_dims = x.dims();
ExecuteReshape<T, Context>(dev_ctx, x, shape, x_dims, out);
}
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/kernels/reshape_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ void ReshapeGradKernel<phi::XPUContext>(const XPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
if (x_grad->numel() == 0) {
dev_ctx.Alloc(x_grad, x_grad->dtype());
return;
}
auto x_dims = x_grad->dims();
dev_ctx.Alloc(x_grad, out_grad.dtype());
auto* src_ptr = out_grad.data();
Expand Down
23 changes: 22 additions & 1 deletion test/legacy_test/test_reshape_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@
import unittest

import numpy as np
from op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci
from op_test import (
OpTest,
OpTestTool,
convert_float_to_uint16,
skip_check_grad_ci,
)

import paddle
from paddle import base
from paddle.base import core
from paddle.base.framework import _current_expected_place
from paddle.static import Program, program_guard


Expand Down Expand Up @@ -92,6 +99,20 @@ def init_data(self):
self.inferred_shape = ()


@OpTestTool.skip_if(
not (isinstance(_current_expected_place(), core.CPUPlace)),
"GPU is not supported",
)
class TestReshapeOp_ZeroDim4(OpTest):
def init_kernel_type(self):
self.use_onednn = True

def init_data(self):
self.ori_shape = (1,)
self.new_shape = ()
self.inferred_shape = ()


class TestReshapeOp_ZeroSize(OpTest):
def init_data(self):
self.ori_shape = (0, 2)
Expand Down