Commit e4d5ffe

zhengshengning authored and co63oc committed
[0-size Tensor Job2 No.99] Add 0-size Tensor support for take_along_axis (PaddlePaddle#73832)
* fix
* add test of TakeAlongAxis
* fix test case
1 parent 98989ae commit e4d5ffe

File tree

7 files changed, +59 -0 lines changed

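Before the per-file hunks: the forward kernels now return right after allocating the output when out->numel() == 0, and the grad kernels do the same for x_grad. The snippet below is a minimal eager-mode sketch of a fully zero-size call that would take those early returns; the shapes are illustrative assumptions of mine rather than part of the commit (the committed test instead pairs a (2, 0, 5) input with non-empty (2, 3, 5) indices and drives the kernels through OpTest).

import numpy as np
import paddle

# Illustrative zero-size case (assumed shapes, not from the commit): axis 1 has
# length 0 in both the input and the indices, so the forward output and the
# input gradient both have numel() == 0 and the new early returns are taken.
x = paddle.to_tensor(np.zeros((2, 0, 5), dtype="float64"), stop_gradient=False)
indices = paddle.to_tensor(np.zeros((2, 0, 5), dtype="int64"))

out = paddle.take_along_axis(x, indices, axis=1, broadcast=False)
print(out.shape)  # [2, 0, 5]

out.sum().backward()
print(x.grad.shape)  # [2, 0, 5]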

paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc

Lines changed: 4 additions & 0 deletions
@@ -34,6 +34,10 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
   x_grad->Resize(x.dims());
   dev_ctx.template Alloc<T>(x_grad);
 
+  if (x_grad->numel() == 0) {
+    return;
+  }
+
   // Set to zero tensor.
   phi::funcs::SetConstant<Context, T> functor;
   functor(dev_ctx, x_grad, static_cast<T>(0));

paddle/phi/kernels/cpu/take_along_axis_kernel.cc

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,10 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
   out->Resize(index.dims());
   dev_ctx.template Alloc<T>(out);
 
+  if (out->numel() == 0) {
+    return;
+  }
+
   const auto& index_type = index.dtype();
   if (index_type == DataType::INT32) {
     phi::funcs::cpu_gather_kernel<T, int32_t>(

paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu

Lines changed: 4 additions & 0 deletions
@@ -35,6 +35,10 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
   x_grad->Resize(x.dims());
   dev_ctx.template Alloc<T>(x_grad);
 
+  if (x_grad->numel() == 0) {
+    return;
+  }
+
   // Set to zero tensor.
   phi::funcs::SetConstant<Context, T> functor;
   functor(dev_ctx, x_grad, static_cast<T>(0));

paddle/phi/kernels/gpu/take_along_axis_kernel.cu

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,10 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
   out->Resize(index.dims());
   dev_ctx.template Alloc<T>(out);
 
+  if (out->numel() == 0) {
+    return;
+  }
+
   const auto& index_type = index.dtype();
   if (index_type == DataType::INT32) {
     phi::funcs::gpu_gather_kernel<T, int32_t>(

paddle/phi/kernels/xpu/take_along_axis_grad_kernel.cc

Lines changed: 4 additions & 0 deletions
@@ -29,6 +29,10 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
   using XPUType = typename XPUTypeTrait<T>::Type;
   dev_ctx.template Alloc<T>(x_grad);
 
+  if (x_grad->numel() == 0) {
+    return;
+  }
+
   const auto& index_dtype = index.dtype();
   bool index_dtype_match =
       index_dtype == DataType::INT32 || index_dtype == DataType::INT64;

paddle/phi/kernels/xpu/take_along_axis_kernel.cc

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,10 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
   out->Resize(index.dims());
   dev_ctx.template Alloc<T>(out);
 
+  if (out->numel() == 0) {
+    return;
+  }
+
   if (x.numel() == 0 || index.numel() == 0) return;
 
   const auto& index_dtype = index.dtype();

test/legacy_test/test_take_along_axis_op.py

Lines changed: 35 additions & 0 deletions
@@ -428,6 +428,41 @@ def test_static_shape_take_along_axis(self):
         _ = static_f(x, ind, axis=0, broadcast=False)
 
 
+class TestTakeAlongAxis_ZeroSize(OpTest):
+    def setUp(self):
+        self.python_api = paddle.take_along_axis
+        self.op_type = "take_along_axis"
+        self.dtype = "float64"
+        self.check_pir = True
+
+        x = np.zeros((2, 0, 5)).astype(self.dtype)
+        indices = np.zeros((2, 3, 5)).astype("int64")
+
+        self.inputs = {'Input': x, 'Index': indices}
+        self.attrs = {'Axis': 1}
+
+        output = np.zeros((2, 3, 5)).astype(self.dtype)
+        self.outputs = {'Result': output}
+
+    def test_check_output(self):
+        self.check_output_with_place(
+            paddle.CPUPlace(), check_pir=self.check_pir
+        )
+        if core.is_compiled_with_cuda():
+            self.check_output_with_place(
+                core.CUDAPlace(0), check_pir=self.check_pir
+            )
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            paddle.CPUPlace(), ['Input'], 'Result', check_pir=self.check_pir
+        )
+        if core.is_compiled_with_cuda():
+            self.check_grad_with_place(
+                core.CUDAPlace(0), ['Input'], 'Result', check_pir=self.check_pir
+            )
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()
