diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 50c2d0801f0852..6c728898bba753 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -2501,6 +2501,19 @@ void IndexAddInferMeta(const MetaTensor& x,
                        int axis,
                        MetaTensor* output) {
   auto input_dim = x.dims();
+  if (common::product(input_dim) == 0) {
+    output->set_dims(input_dim);
+    output->set_dtype(x.dtype());
+    output->set_layout(x.layout());
+    return;
+  }
+  if (index.dims().size() == 1 && index.dims()[0] == 0) {
+    output->set_dims(input_dim);
+    output->set_dtype(x.dtype());
+    output->set_layout(x.layout());
+    output->share_lod(x);
+    return;
+  }
   auto index_dim = index.dims();
   auto add_value_dim = add_value.dims();
 
@@ -2524,7 +2537,13 @@ void IndexAddInferMeta(const MetaTensor& x,
           "the dimension of Input(Index) is [%d].",
           index_dim,
           index_dim.size()));
-
+  if (common::product(add_value_dim) == 0) {
+    output->set_dims(input_dim);
+    output->set_dtype(x.dtype());
+    output->set_layout(x.layout());
+    output->share_lod(x);
+    return;
+  }
   // Note, add_value does not support broadcast now.
   PADDLE_ENFORCE_EQ(input_dim.size() == add_value_dim.size(),
                     true,
diff --git a/paddle/phi/kernels/gpu/index_add_grad_kernel.cu b/paddle/phi/kernels/gpu/index_add_grad_kernel.cu
index 2b65dbe0f97081..b3690573d52224 100644
--- a/paddle/phi/kernels/gpu/index_add_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/index_add_grad_kernel.cu
@@ -36,7 +36,9 @@ void IndexAddGradKernel(const Context& dev_ctx,
                         DenseTensor* x_grad,
                         DenseTensor* add_value_grad) {
   if (out_grad.numel() == 0) {
-    dev_ctx.template Alloc<T>(x_grad);
+    if (x_grad) {
+      dev_ctx.template Alloc<T>(x_grad);
+    }
     if (add_value_grad) {
       phi::Full<T, Context>(
           dev_ctx,
@@ -46,7 +48,28 @@ void IndexAddGradKernel(const Context& dev_ctx,
     }
     return;
   }
-
+  if (index.numel() == 0) {
+    if (x_grad) {
+      phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
+    }
+    if (add_value_grad) {
+      phi::Full<T, Context>(
+          dev_ctx,
+          phi::IntArray(common::vectorize(add_value_grad->dims())),
+          0,
+          add_value_grad);
+    }
+    return;
+  }
+  if (add_value.numel() == 0) {
+    if (x_grad) {
+      phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
+    }
+    if (add_value_grad) {
+      dev_ctx.template Alloc<T>(add_value_grad);
+    }
+    return;
+  }
   // x.shape == out.shape in index_grad op
   auto input_dim = out_grad.dims();
   auto add_value_dim = add_value.dims();
diff --git a/paddle/phi/kernels/gpu/index_add_kernel.cu b/paddle/phi/kernels/gpu/index_add_kernel.cu
index 1e165fd2dfa17d..a45ce703ec99cc 100644
--- a/paddle/phi/kernels/gpu/index_add_kernel.cu
+++ b/paddle/phi/kernels/gpu/index_add_kernel.cu
@@ -56,8 +56,16 @@ void IndexAddKernel(const Context& dev_ctx,
                     const DenseTensor& add_value,
                     int axis,
                     DenseTensor* output) {
-  if (output && output->numel() == 0) {
-    dev_ctx.template Alloc<T>(output);
+  if (x.numel() == 0) {
+    phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, output);
+    return;
+  }
+  if (index.numel() == 0) {
+    phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, output);
+    return;
+  }
+  if (add_value.numel() == 0) {
+    phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, output);
     return;
   }
   auto input_dim = x.dims();
@@ -76,9 +84,6 @@ void IndexAddKernel(const Context& dev_ctx,
   auto* add_value_data = add_value.data<T>();
 
   int64_t numel = add_value.numel();
-  if (numel == 0) {
-    return;
-  }
   auto stream = dev_ctx.stream();
 
   unsigned int block_dim = PADDLE_CUDA_NUM_THREADS;
@@ -88,7 +93,6 @@ void IndexAddKernel(const Context& dev_ctx,
   // copy input to output.
   // todo(@limin29): inplace do not need copy.
   phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, output);
-  if (index.numel() == 0) return;
 
   if (FLAGS_cudnn_deterministic) {
     VLOG(2) << "Run grad kernel of index_add with single thread.";
diff --git a/test/legacy_test/test_index_add_op.py b/test/legacy_test/test_index_add_op.py
index b3383e1ce14cef..bc3df244420095 100644
--- a/test/legacy_test/test_index_add_op.py
+++ b/test/legacy_test/test_index_add_op.py
@@ -513,5 +513,40 @@ def test_check_grad_normal(self):
         )
 
 
+class TestIndexAdd_ZeroSize2(OpTest):
+    def setUp(self):
+        self.python_api = raw_index_add
+        self.op_type = "index_add"
+        self.prim_op_type = "prim"
+        self.public_python_api = raw_index_add
+        self.init_dtype_type()
+        index_np = np.array([], dtype=self.index_type)
+        x_np = np.random.random(self.x_shape).astype(self.x_type)
+        add_value_np = np.random.random(self.add_value_shape).astype(
+            self.x_type
+        )
+
+        self.inputs = {'X': x_np, 'Index': index_np, 'AddValue': add_value_np}
+        self.attrs = {'axis': self.axis}
+        out = x_np.copy()
+        self.outputs = {'Out': out}
+
+    def init_dtype_type(self):
+        self.x_type = np.float32
+        self.index_type = np.int32
+        self.x_shape = (10,)
+        self.index_size = 0
+        self.axis = 0
+        self.add_value_shape = (0,)
+
+    def test_check_output(self):
+        self.check_output(atol=1e-2, check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'AddValue'], 'Out', check_pir=True, check_prim_pir=True
+        )
+
+
 if __name__ == '__main__':
     unittest.main()
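
Reviewer sketch, not part of the patch: a minimal NumPy model of the zero-size contract this diff establishes. The helper names index_add_ref and index_add_grad_ref are hypothetical, np.add.at stands in for the CUDA kernel's atomic accumulation, and only the zero-size paths touched by this PR are modeled.

    import numpy as np

    def index_add_ref(x, index, axis, add_value):
        # Forward contract from IndexAddKernel: a zero-size x, index, or
        # add_value degenerates the op into a plain copy of x, matching
        # the three new phi::Copy early returns.
        out = x.copy()
        if out.size == 0 or index.size == 0 or add_value.size == 0:
            return out
        out_view = np.moveaxis(out, axis, 0)        # view: writes reach `out`
        val_view = np.moveaxis(add_value, axis, 0)
        np.add.at(out_view, index, val_view)        # unbuffered; repeated indices accumulate
        return out

    def index_add_grad_ref(out_grad, index, add_value):
        # Backward contract from IndexAddGradKernel's new early returns:
        # x_grad is out_grad copied through (x.shape == out.shape), and
        # add_value_grad is zero-filled (phi::Full) or, when it is itself
        # zero-size, merely allocated.
        if out_grad.size == 0 or index.size == 0:
            return out_grad.copy(), np.zeros_like(add_value)
        if add_value.size == 0:
            return out_grad.copy(), np.empty_like(add_value)
        raise NotImplementedError("dense gradient path not sketched here")

    # The empty-index, empty-add_value case exercised by TestIndexAdd_ZeroSize2.
    x = np.random.random((10,)).astype(np.float32)
    index = np.array([], dtype=np.int32)
    add_value = np.zeros((0,), dtype=np.float32)
    assert np.array_equal(index_add_ref(x, index, 0, add_value), x)

This mirrors what the new test asserts: with an empty index the expected output is out = x_np.copy(), and check_grad passes because the grad kernel now guards x_grad and add_value_grad against null pointers and zero-size inputs instead of unconditionally allocating and launching.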