Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2502,11 +2502,6 @@ void IndexAddInferMeta(const MetaTensor& x,
index_dim,
index_dim.size()));

PADDLE_ENFORCE_EQ(index_dim[0] != 0,
true,
common::errors::InvalidArgument(
"The length of Input(Index) can't be 0."));

// Note, add_value does not support broadcast now.
PADDLE_ENFORCE_EQ(input_dim.size() == add_value_dim.size(),
true,
Expand Down
12 changes: 12 additions & 0 deletions paddle/phi/kernels/cpu/index_add_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/cpu/index_select_impl.h"
#include "paddle/phi/kernels/full_kernel.h"

namespace phi {

Expand All @@ -28,6 +29,17 @@ void IndexAddGradKernel(const Context& dev_ctx,
int axis,
DenseTensor* x_grad,
DenseTensor* add_value_grad) {
if (out_grad.numel() == 0) {
dev_ctx.template Alloc<T>(x_grad);
if (add_value_grad) {
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(add_value_grad->dims())),
0,
add_value_grad);
}
return;
}
if (axis < 0) {
axis += out_grad.dims().size();
}
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/kernels/cpu/index_add_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ void IndexAddInner(const Context& dev_ctx,
// copy x to output.
// todo(@limin29): inplace do not need copy.
phi::Copy(dev_ctx, *input, dev_ctx.GetPlace(), false, output);
if (index.numel() == 0) return;

auto slice_size = 1;
for (auto i = axis + 1; i < input_dim_size; i++) {
Expand Down Expand Up @@ -107,6 +108,10 @@ void IndexAddBaseKernel(const Context& dev_ctx,
int axis,
const DenseTensor& add_value,
DenseTensor* output) {
if (output && output->numel() == 0) {
dev_ctx.template Alloc<T>(output);
return;
}
const auto& index_type = index.dtype();
if (axis < 0) {
axis += x.dims().size();
Expand Down
13 changes: 13 additions & 0 deletions paddle/phi/kernels/gpu/index_add_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpu/index_select_impl.h"

Expand All @@ -34,6 +35,18 @@ void IndexAddGradKernel(const Context& dev_ctx,
int dim,
DenseTensor* x_grad,
DenseTensor* add_value_grad) {
if (out_grad.numel() == 0) {
dev_ctx.template Alloc<T>(x_grad);
if (add_value_grad) {
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(add_value_grad->dims())),
0,
add_value_grad);
}
return;
}

// x.shape == out.shape in index_grad op
auto input_dim = out_grad.dims();
auto add_value_dim = add_value.dims();
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/kernels/gpu/index_add_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ void IndexAddKernel(const Context& dev_ctx,
const DenseTensor& add_value,
int axis,
DenseTensor* output) {
if (output && output->numel() == 0) {
dev_ctx.template Alloc<T>(output);
return;
}
auto input_dim = x.dims();
auto output_dim = output->dims();
auto add_value_dim = add_value.dims();
Expand Down Expand Up @@ -84,6 +88,7 @@ void IndexAddKernel(const Context& dev_ctx,
// copy input to output.
// todo(@limin29): inplace do not need copy.
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, output);
if (index.numel() == 0) return;

if (FLAGS_cudnn_deterministic) {
VLOG(2) << "Run grad kernel of index_add with single thread.";
Expand Down
13 changes: 13 additions & 0 deletions paddle/phi/kernels/xpu/index_add_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/index_select_kernel.h"

namespace phi {
Expand All @@ -29,6 +30,18 @@ void IndexAddGradKernel(const Context& dev_ctx,
int dim,
DenseTensor* x_grad,
DenseTensor* add_value_grad) {
if (out_grad.numel() == 0) {
dev_ctx.template Alloc<T>(x_grad);
if (add_value_grad) {
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(add_value_grad->dims())),
0,
add_value_grad);
}
return;
}

if (dim < 0) {
dim += out_grad.dims().size();
}
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/kernels/xpu/index_add_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ void IndexAddKernel(const Context& dev_ctx,
DataTypeToString(DataType::INT32),
DataTypeToString(DataType::INT64)));

if (out && out->numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
if (index.numel() == 0) {
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
return;
}

using XPUType = typename XPUTypeTrait<T>::Type;
auto input_dim = x.dims();
int dim = axis >= 0 ? axis : axis + input_dim.size();
Expand Down
44 changes: 44 additions & 0 deletions test/legacy_test/test_index_add_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,5 +458,49 @@ def config(self):

# self.assertRaises(ValueError, test_add_value_broadcast)


class TestIndexAddOp_ZeroSize(OpTest):
def setUp(self):
self.python_api = raw_index_add
self.op_type = "index_add"
self.init_dtype_type()
index_np = np.random.randint(
low=-self.x_shape[self.axis],
high=self.x_shape[self.axis],
size=self.index_size,
)
x_np = np.random.random(self.x_shape).astype(self.x_type)
add_value_np = np.random.random(self.add_value_shape).astype(
self.x_type
)

self.inputs = {'X': x_np, 'Index': index_np, 'AddValue': add_value_np}
self.attrs = {'axis': self.axis}
out = compute_index_add_ref(
self.axis,
self.x_shape,
x_np,
self.add_value_shape,
add_value_np,
self.index_size,
index_np,
)
self.outputs = {'Out': out}

def init_dtype_type(self):
self.axis = 0
self.x_type = np.float64
self.index_type = np.int64
self.x_shape = (101, 0)
self.index_size = 3
self.add_value_shape = (3, 0)

def test_check_output(self):
self.check_output(atol=1e-2, check_pir=True)

def test_check_grad_normal(self):
self.check_grad(['X', 'AddValue'], 'Out', check_pir=True)


if __name__ == '__main__':
unittest.main()
Loading