Skip to content

Commit 9b40438

Browse files
[0-size Tensor Job2 No.96] Add 0-size Tensor support for reshape (#73844)
* fix 0-size reshape func * add onednn reshape test case
1 parent 1f27316 commit 9b40438

File tree

4 files changed

+36
-1
lines changed

4 files changed

+36
-1
lines changed

paddle/phi/kernels/onednn/reshape_grad_kernel.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ limitations under the License. */
1111

1212
#include "paddle/phi/backends/onednn/onednn_reuse.h"
1313
#include "paddle/phi/core/kernel_registry.h"
14+
#include "paddle/phi/kernels/full_kernel.h"
1415

1516
namespace phi {
1617

@@ -19,6 +20,12 @@ void ReshapeGradKernel(const Context& dev_ctx,
1920
const DenseTensor& x,
2021
const DenseTensor& out_grad,
2122
DenseTensor* x_grad) {
23+
if ((x_grad && x_grad->numel() == 0) || out_grad.numel() == 0) {
24+
phi::Full<T, Context>(
25+
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
26+
return;
27+
}
28+
2229
auto out_grad_vec_dims = out_grad.dims().size() != 0
2330
? common::vectorize(out_grad.dims())
2431
: std::vector<int64_t>{1};

paddle/phi/kernels/onednn/reshape_kernel.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ void ReshapeKernel(const Context& dev_ctx,
154154
const DenseTensor& x,
155155
const IntArray& shape,
156156
DenseTensor* out) {
157+
if (x.numel() == 0) {
158+
dev_ctx.Alloc(out, x.dtype());
159+
}
157160
auto x_dims = x.dims();
158161
ExecuteReshape<T, Context>(dev_ctx, x, shape, x_dims, out);
159162
}

paddle/phi/kernels/reshape_grad_kernel.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ void ReshapeGradKernel<phi::XPUContext>(const XPUContext& dev_ctx,
4646
const DenseTensor& x,
4747
const DenseTensor& out_grad,
4848
DenseTensor* x_grad) {
49+
if (x_grad->numel() == 0) {
50+
dev_ctx.Alloc(x_grad, x_grad->dtype());
51+
return;
52+
}
4953
auto x_dims = x_grad->dims();
5054
dev_ctx.Alloc(x_grad, out_grad.dtype());
5155
auto* src_ptr = out_grad.data();

test/legacy_test/test_reshape_op.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,17 @@
1515
import unittest
1616

1717
import numpy as np
18-
from op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci
18+
from op_test import (
19+
OpTest,
20+
OpTestTool,
21+
convert_float_to_uint16,
22+
skip_check_grad_ci,
23+
)
1924

2025
import paddle
2126
from paddle import base
27+
from paddle.base import core
28+
from paddle.base.framework import _current_expected_place
2229
from paddle.static import Program, program_guard
2330

2431

@@ -92,6 +99,20 @@ def init_data(self):
9299
self.inferred_shape = ()
93100

94101

102+
@OpTestTool.skip_if(
103+
not (isinstance(_current_expected_place(), core.CPUPlace)),
104+
"GPU is not supported",
105+
)
106+
class TestReshapeOp_ZeroDim4(OpTest):
107+
def init_kernel_type(self):
108+
self.use_onednn = True
109+
110+
def init_data(self):
111+
self.ori_shape = (1,)
112+
self.new_shape = ()
113+
self.inferred_shape = ()
114+
115+
95116
class TestReshapeOp_ZeroSize(OpTest):
96117
def init_data(self):
97118
self.ori_shape = (0, 2)

0 commit comments

Comments
 (0)