Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions paddle/phi/kernels/gpu/index_add_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ __global__ void index_add_cuda_kernel(const T* input,
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx =
(index[dim_idx] < 0 ? index[dim_idx] + index_dim_size : index[dim_idx]);
if (src_dim_idx < 0 || src_dim_idx >= index_dim_size) {
printf("Index out of bounds: index[%d] = %d, index_dim_size = %d\n",
dim_idx,
src_dim_idx,
index_dim_size);
return;
}
int64_t input_idx =
idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
phi::CudaAtomicAdd(&output[input_idx], add_value[idx]);
Expand All @@ -60,6 +67,20 @@ void IndexAddKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(output);
return;
}
if (x.numel() == 0) {
if (output->numel() > 0) {
dev_ctx.template Alloc<T>(output);
}
return;
}
if (index.numel() == 0) {
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, output);
return;
}
if (add_value.numel() == 0) {
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, output);
return;
}
auto input_dim = x.dims();
auto output_dim = output->dims();
auto add_value_dim = add_value.dims();
Expand All @@ -73,12 +94,12 @@ void IndexAddKernel(const Context& dev_ctx,

auto* in_data = x.data<T>();
T* out_data = dev_ctx.template Alloc<T>(output);
PADDLE_ENFORCE_NOT_NULL(
out_data,
errors::InvalidArgument("The output tensor memory is not allocated."));
auto* add_value_data = add_value.data<T>();

int64_t numel = add_value.numel();
if (numel == 0) {
return;
}
auto stream = dev_ctx.stream();

unsigned int block_dim = PADDLE_CUDA_NUM_THREADS;
Expand All @@ -88,7 +109,6 @@ void IndexAddKernel(const Context& dev_ctx,
// copy input to output.
// todo(@limin29): inplace do not need copy.
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, output);
if (index.numel() == 0) return;

if (FLAGS_cudnn_deterministic) {
VLOG(2) << "Run grad kernel of index_add with single thread.";
Expand Down
Loading