Commit bc238f5 — "Fix"
Parent: 3efb8db

File tree

7 files changed: +84 additions, −8 deletions

paddle/phi/infermeta/binary.cc

Lines changed: 14 additions & 6 deletions

@@ -2010,12 +2010,14 @@ void GatherInferMeta(const MetaTensor& x,
   auto index_dims = index.dims();
 
   if (index_dims.size() == 2) {
-    PADDLE_ENFORCE_EQ(
-        index_dims[1],
-        1,
-        common::errors::InvalidArgument(
-            "The last dim of index should be 1 when it is 2D, but we get %d",
-            index_dims[1]));
+    if (index_dims[1] != 0) {
+      PADDLE_ENFORCE_EQ(
+          index_dims[1],
+          1,
+          common::errors::InvalidArgument("The last dim of index should be 0 "
+                                          "or 1 when it is 2D, but we get %d",
+                                          index_dims[1]));
+    }
   } else {
     PADDLE_ENFORCE_EQ(
         index_dims.size() == 1 || index_dims.size() == 0,
@@ -2084,13 +2086,19 @@ void GatherInferMeta(const MetaTensor& x,
   if (axis.FromTensor() || axis_v == 0) {
     // if axis.FromTensor(), we can not obtain correct shape of output
     int batch_size = static_cast<int>(index_dims[0]);
+    if (index_dims.size() == 2 && index_dims[1] == 0) {
+      batch_size = 0;
+    }
     phi::DDim output_dims(input_dim);
     output_dims[0] = batch_size;
     out->set_dims(output_dims);
     out->set_dtype(x.dtype());
     out->share_lod(x);
   } else {
     int index_size = static_cast<int>(index_dims[0]);
+    if (index_dims.size() == 2 && index_dims[1] == 0) {
+      index_size = 0;
+    }
     std::vector<int> out_dim_vec;
     for (int i = 0; i < axis_v; i++) {
       out_dim_vec.push_back(input_dim[i]);  // NOLINT

paddle/phi/kernels/cpu/gather_grad_kernel.cc

Lines changed: 8 additions & 0 deletions

@@ -16,6 +16,7 @@
 
 #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/gather.h"
 #include "paddle/phi/kernels/funcs/scatter.h"
@@ -29,6 +30,13 @@ void GatherGradKernel(const Context& dev_ctx,
                       const DenseTensor& out_grad,
                       const Scalar& axis,
                       DenseTensor* x_grad) {
+  if (out_grad.numel() == 0) {
+    if (x_grad) {
+      phi::Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
+    }
+    return;
+  }
   const auto& index_type = index.dtype();
   auto axis_v = axis.to<int>();
   if (axis_v < 0) {

paddle/phi/kernels/cpu/gather_kernel.cc

Lines changed: 4 additions & 0 deletions

@@ -26,6 +26,10 @@ void GatherKernel(const Context& dev_ctx,
                   const DenseTensor& index,
                   const Scalar& axis,
                   DenseTensor* out) {
+  if (out && out->numel() == 0) {
+    dev_ctx.template Alloc<T>(out);
+    return;
+  }
   const auto& index_type = index.dtype();
   auto axis_v = axis.to<int>();
   if (axis_v < 0) {

paddle/phi/kernels/gpu/gather_grad_kernel.cu

Lines changed: 9 additions & 1 deletion

@@ -15,11 +15,11 @@
 #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/gather.cu.h"
 #include "paddle/phi/kernels/funcs/scatter.cu.h"
 #include "paddle/phi/kernels/gather_kernel.h"
-
 namespace phi {
 
 template <typename T, typename Context>
@@ -29,6 +29,14 @@ void GatherGradKernel(const Context& dev_ctx,
                       const DenseTensor& out_grad,
                       const Scalar& axis,
                       DenseTensor* x_grad) {
+  // x [4, 2], index [2, 0], out [2, 0], x_grad [4, 2]
+  if (out_grad.numel() == 0) {
+    if (x_grad) {
+      phi::Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
+    }
+    return;
+  }
   const auto& index_type = index.dtype();
   auto axis_v = axis.to<int>();
   if (axis_v < 0) {

paddle/phi/kernels/gpu/gather_kernel.cu

Lines changed: 4 additions & 0 deletions

@@ -27,6 +27,10 @@ void GatherKernel(const Context& dev_ctx,
                   const DenseTensor& index,
                   const Scalar& axis,
                   DenseTensor* out) {
+  if (out && out->numel() == 0) {
+    dev_ctx.template Alloc<T>(out);
+    return;
+  }
   const auto& index_type = index.dtype();
   auto axis_v = axis.to<int>();
   if (axis_v < 0) {

paddle/phi/kernels/xpu/gather_grad_kernel.cc

Lines changed: 5 additions & 1 deletion

@@ -16,7 +16,7 @@
 
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
-
+#include "paddle/phi/kernels/full_kernel.h"
 namespace phi {
 
 template <typename T, typename Context>
@@ -34,6 +34,10 @@ void GatherGradKernel(const Context& dev_ctx,
   const auto& index_type = index.dtype();
 
   if (out_grad.numel() == 0) {
+    if (x_grad) {
+      phi::Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
+    }
     return;
   }
 

test/legacy_test/test_gather_op.py

Lines changed: 40 additions & 0 deletions

@@ -948,6 +948,46 @@ def test_gather_backward(self):
         np.testing.assert_allclose(res_list[0], res_list[1])
 
 
+class TestGatherOp_ZeroSize(OpTest):
+    def setUp(self):
+        self.op_type = "gather"
+        self.python_api = paddle.gather
+        self.public_python_api = paddle.gather
+        self.config()
+        self.init_inputs_and_outputs()
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', check_pir=True)
+
+    def config(self):
+        self.x_shape = (3, 0, 4)
+        self.config_dtype()
+        self.index = [2]
+        self.index_type = "int32"
+
+    def config_dtype(self):
+        self.x_type = "float64"
+
+    def init_inputs_and_outputs(self):
+        xnp = np.random.random(self.x_shape).astype(self.x_type)
+        self.inputs = {
+            'X': xnp,
+            'Index': np.array(self.index).astype(self.index_type),
+        }
+        self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
+
+
+class TestGatherOp_ZeroSize2(TestGatherOp_ZeroSize):
+    def config(self):
+        self.x_shape = (10, 20)
+        self.config_dtype()
+        self.index = [2, 0]
+        self.index_type = "int32"
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()

Comments (0)