Skip to content

Commit b78a40c

Browse files
authored
[0-size Tensor Job2 No.47、49、88、90]Add 0-size Tensor support for amin (#73333)
* Fix; * Fix; * Fix; * Fix
1 parent c83a7c1 commit b78a40c

File tree

6 files changed

+39
-5
lines changed

6 files changed

+39
-5
lines changed

paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ void ReduceAMaxGradKernel(const Context& dev_ctx,
2828
bool keep_dim,
2929
bool reduce_all,
3030
DenseTensor* x_grad) {
31+
if (x_grad && x_grad->numel() == 0) {
32+
dev_ctx.template Alloc<T>(x_grad);
33+
return;
34+
}
3135
reduce_all = recompute_reduce_all(x, dims, reduce_all);
3236
ReduceGradKernel<Context, T, funcs::AMaxOrAMinGradFunctor>(
3337
dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);

paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ void ReduceAMinGradKernel(const Context& dev_ctx,
2828
bool keep_dim,
2929
bool reduce_all,
3030
DenseTensor* x_grad) {
31+
if (x_grad && x_grad->numel() == 0) {
32+
dev_ctx.template Alloc<T>(x_grad);
33+
return;
34+
}
3135
reduce_all = recompute_reduce_all(x, dims, reduce_all);
3236
ReduceGradKernel<Context, T, funcs::AMaxOrAMinGradFunctor>(
3337
dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);

paddle/phi/kernels/funcs/reduce_function.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,12 +1460,10 @@ void ReduceKernelImpl(const Context& dev_ctx,
14601460
const std::vector<int64_t>& dims,
14611461
bool keep_dim,
14621462
bool reduce_all) {
1463-
PADDLE_ENFORCE_GT(input.numel(),
1464-
0,
1465-
common::errors::InvalidArgument(
1466-
"Tensor need be reduced must not empty."));
1467-
14681463
dev_ctx.template Alloc<OutT>(output);
1464+
if (input.numel() == 0) {
1465+
return;
1466+
}
14691467

14701468
if (reduce_all) {
14711469
// Flatten and reduce 1-D tensor

paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
3131
bool keep_dim,
3232
bool reduce_all,
3333
DenseTensor* x_grad) {
34+
if (x_grad && x_grad->numel() == 0) {
35+
dev_ctx.template Alloc<T>(x_grad);
36+
return;
37+
}
3438
using XPUDataType = typename XPUTypeTrait<T>::Type;
3539
reduce_all = recompute_reduce_all(x, dims_arr, reduce_all);
3640
auto dims = dims_arr.GetData();

paddle/phi/kernels/xpu/reduce_min_grad_kernel.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ void ReduceMinGradKernel(const Context& dev_ctx,
3131
bool keep_dim,
3232
bool reduce_all,
3333
DenseTensor* x_grad) {
34+
if (x_grad && x_grad->numel() == 0) {
35+
dev_ctx.template Alloc<T>(x_grad);
36+
return;
37+
}
3438
reduce_all = recompute_reduce_all(x, dims_arr, reduce_all);
3539
auto dims = dims_arr.GetData();
3640

test/legacy_test/test_max_min_amax_amin_op.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def _choose_paddle_func(self, func, x):
9292

9393
def test_static_graph(self):
9494
def _test_static_graph(func):
95+
paddle.enable_static()
9596
startup_program = base.Program()
9697
train_program = base.Program()
9798
with base.program_guard(startup_program, train_program):
@@ -107,6 +108,7 @@ def _test_static_graph(func):
107108
fetch_list=[out],
108109
)
109110
self.assertTrue((np.array(res[0]) == self.np_out[func]).all())
111+
paddle.disable_static()
110112

111113
_test_static_graph('amax')
112114
_test_static_graph('amin')
@@ -264,5 +266,23 @@ def _test_dygraph(func):
264266
_test_dygraph('min')
265267

266268

269+
class TestMaxMinAmaxAminAPI_ZeroSize(TestMaxMinAmaxAminAPI):
270+
def init_case(self):
271+
self.x_np = np.random.randn(1, 0, 10).astype(np.float32)
272+
self.shape = [1, 0, 10]
273+
self.dtype = 'float32'
274+
self.axis = 0
275+
self.keepdim = False
276+
277+
278+
class TestMaxMinAmaxAminAPI_ZeroSize2(TestMaxMinAmaxAminAPI):
279+
def init_case(self):
280+
self.x_np = np.random.randn(1, 0, 10).astype(np.float32)
281+
self.shape = [1, 0, 10]
282+
self.dtype = 'float32'
283+
self.axis = -1
284+
self.keepdim = True
285+
286+
267287
if __name__ == '__main__':
268288
unittest.main()

0 commit comments

Comments (0)