Skip to content

Commit 79bcce1

Browse files
authored
[0-size Tensor No.271] Add 0-size Tensor support for paddle.take_along_axis API. (#73736)
1 parent f9e28c3 commit 79bcce1

File tree

4 files changed

+85
-0
lines changed

4 files changed

+85
-0
lines changed

paddle/phi/kernels/cpu/take_along_axis_kernel.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "paddle/phi/common/data_type.h"
1919
#include "paddle/phi/common/place.h"
2020
#include "paddle/phi/core/kernel_registry.h"
21+
#include "paddle/phi/kernels/full_kernel.h"
2122
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
2223

2324
namespace phi {
@@ -28,6 +29,16 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
2829
const DenseTensor& index,
2930
int axis,
3031
DenseTensor* out) {
32+
if (index.numel() == 0) {
33+
dev_ctx.template Alloc<T>(out);
34+
return;
35+
}
36+
if (x.numel() == 0) {
37+
phi::Full<T, Context>(
38+
dev_ctx, common::vectorize(out->dims()), static_cast<T>(0), out);
39+
return;
40+
}
41+
3142
out->Resize(index.dims());
3243
dev_ctx.template Alloc<T>(out);
3344

paddle/phi/kernels/gpu/take_along_axis_kernel.cu

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "paddle/phi/common/place.h"
1919
#include "paddle/phi/core/kernel_registry.h"
2020
#include "paddle/phi/core/utils/data_type.h"
21+
#include "paddle/phi/kernels/full_kernel.h"
2122
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
2223

2324
namespace phi {
@@ -28,6 +29,16 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
2829
const DenseTensor& index,
2930
int axis,
3031
DenseTensor* out) {
32+
if (index.numel() == 0) {
33+
dev_ctx.template Alloc<T>(out);
34+
return;
35+
}
36+
if (x.numel() == 0) {
37+
phi::Full<T, Context>(
38+
dev_ctx, common::vectorize(out->dims()), static_cast<T>(0), out);
39+
return;
40+
}
41+
3142
out->Resize(index.dims());
3243
dev_ctx.template Alloc<T>(out);
3344

paddle/phi/kernels/xpu/take_along_axis_kernel.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "paddle/common/layout.h"
2020
#include "paddle/phi/backends/xpu/enforce_xpu.h"
2121
#include "paddle/phi/core/kernel_registry.h"
22+
#include "paddle/phi/kernels/full_kernel.h"
2223

2324
namespace phi {
2425

@@ -28,6 +29,16 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
2829
const DenseTensor& index,
2930
int axis,
3031
DenseTensor* out) {
32+
if (index.numel() == 0) {
33+
dev_ctx.template Alloc<T>(out);
34+
return;
35+
}
36+
if (x.numel() == 0) {
37+
phi::Full<T, Context>(
38+
dev_ctx, common::vectorize(out->dims()), static_cast<T>(0), out);
39+
return;
40+
}
41+
3142
out->Resize(index.dims());
3243
dev_ctx.template Alloc<T>(out);
3344

test/legacy_test/test_take_along_axis_op.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,58 @@
2525
paddle.enable_static()
2626

2727

28+
class TestTakeAlongAxis0Size(OpTest):
29+
def setUp(self):
30+
self.python_api = paddle.take_along_axis
31+
self.op_type = "take_along_axis"
32+
self.dtype = "float64"
33+
self.check_pir = True
34+
35+
x = np.zeros((2, 0, 5)).astype(self.dtype)
36+
indices = np.zeros((2, 3, 5)).astype("int64")
37+
38+
self.inputs = {'Input': x, 'Index': indices}
39+
self.attrs = {'Axis': 1}
40+
41+
output = np.zeros((2, 3, 5)).astype(self.dtype)
42+
self.outputs = {'Result': output}
43+
44+
def test_check_output(self):
45+
self.check_output(check_pir=self.check_pir)
46+
47+
def test_check_grad(self):
48+
self.check_grad(['Input'], 'Result', check_pir=self.check_pir)
49+
50+
51+
class TestTakeAlongAxis0Size2(OpTest):
52+
def setUp(self):
53+
self.python_api = paddle.take_along_axis
54+
self.op_type = "take_along_axis"
55+
self.dtype = "float64"
56+
self.check_pir = True
57+
58+
x = np.random.rand(2, 3, 5).astype(self.dtype)
59+
indices = np.zeros((2, 0, 5)).astype("int64")
60+
61+
self.inputs = {'Input': x, 'Index': indices}
62+
self.attrs = {'Axis': 1}
63+
64+
output = np.zeros((2, 0, 5)).astype(self.dtype)
65+
self.outputs = {'Result': output}
66+
67+
def test_check_output(self):
68+
self.check_output(check_pir=self.check_pir)
69+
70+
def test_check_grad(self):
71+
self.grad = np.zeros_like(self.outputs['Result']).astype(self.dtype)
72+
self.check_grad(
73+
['Input'],
74+
'Result',
75+
user_defined_grads=[self.grad],
76+
check_pir=self.check_pir,
77+
)
78+
79+
2880
class TestTakeAlongAxisOp(OpTest):
2981
def setUp(self):
3082
self.init_data()

0 commit comments

Comments
 (0)