Skip to content

Commit c06c450

Browse files
committed
Fix
1 parent 425c14d commit c06c450

18 files changed

+153
-21
lines changed

paddle/phi/infermeta/binary.cc

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,21 +1368,6 @@ void DistInferMeta(const MetaTensor& x,
13681368
const MetaTensor& y,
13691369
float p,
13701370
MetaTensor* out) {
1371-
auto x_dims = x.dims();
1372-
auto y_dims = y.dims();
1373-
1374-
PADDLE_ENFORCE_NE(common::product(x_dims),
1375-
0,
1376-
common::errors::InvalidArgument(
1377-
"The Input(X) has not been initialized properly. The "
1378-
"shape of Input(X) = [%s].",
1379-
x_dims));
1380-
PADDLE_ENFORCE_NE(common::product(y_dims),
1381-
0,
1382-
common::errors::InvalidArgument(
1383-
"The Input(Y) has not been initialized properly. The "
1384-
"shape of Input(Y) = [%s].",
1385-
y_dims));
13861371
out->set_dims(common::make_ddim({}));
13871372
out->set_dtype(x.dtype());
13881373
}

paddle/phi/kernels/cpu/diagonal_grad_kernel.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ void DiagonalGradKernel(const Context& dev_ctx,
2828
int axis1,
2929
int axis2,
3030
DenseTensor* in_grad) {
31+
if (in_grad->numel() == 0) {
32+
dev_ctx.template Alloc<T>(in_grad);
33+
return;
34+
}
3135
const auto* dout = &out_grad;
3236
const T* dout_data = dout->data<T>();
3337
auto dout_dim = common::vectorize(dout->dims());

paddle/phi/kernels/cpu/diagonal_kernel.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
#include "paddle/phi/backends/cpu/cpu_context.h"
1818
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/kernels/full_kernel.h"
1920
#include "paddle/phi/kernels/funcs/diagonal.h"
20-
2121
namespace phi {
2222

2323
template <typename T, typename Context>
@@ -27,6 +27,12 @@ void DiagonalKernel(const Context& dev_ctx,
2727
int axis1,
2828
int axis2,
2929
DenseTensor* out) {
30+
if (x.numel() == 0) {
31+
phi::Full<T, Context>(
32+
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
33+
return;
34+
}
35+
3036
auto* input = &x;
3137
const T* input_data = input->data<T>();
3238
auto input_dim = common::vectorize(input->dims());

paddle/phi/kernels/cpu/kthvalue_grad_kernel.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ void KthvalueGradKernel(const Context& dev_ctx,
5656
auto in_dims = x.dims();
5757
auto out_dims = indices.dims();
5858
T* x_grad_data = dev_ctx.template Alloc<T>(d_x);
59+
if (d_x && d_x->numel() == 0) {
60+
return;
61+
}
5962

6063
// For 0D Tensor
6164
if (in_dims.size() == 0) {

paddle/phi/kernels/cpu/kthvalue_kernel.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717
#include "paddle/phi/backends/cpu/cpu_context.h"
1818
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/kernels/full_kernel.h"
1920
#include "paddle/phi/kernels/funcs/eigen/common.h"
2021
#include "paddle/phi/kernels/funcs/math_function.h"
21-
2222
namespace phi {
2323
template <typename T, typename Type>
2424
static void getKthvalue(Type input_height,
@@ -80,6 +80,13 @@ void KthvalueKernel(const Context& dev_ctx,
8080
bool keepdim,
8181
DenseTensor* output,
8282
DenseTensor* indices) {
83+
if (x.numel() == 0) {
84+
phi::Full<T, Context>(
85+
dev_ctx, phi::IntArray(common::vectorize(output->dims())), NAN, output);
86+
phi::Full<int64_t, Context>(
87+
dev_ctx, phi::IntArray(common::vectorize(indices->dims())), 0, indices);
88+
return;
89+
}
8390
const auto& in_dims = x.dims();
8491
if (axis < 0) axis += in_dims.size();
8592

paddle/phi/kernels/dist_grad_kernel.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "paddle/phi/backends/cpu/cpu_context.h"
2020
#include "paddle/phi/core/kernel_registry.h"
2121
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
22+
#include "paddle/phi/kernels/full_kernel.h"
2223
#include "paddle/phi/kernels/p_norm_grad_kernel.h"
2324
#include "paddle/phi/kernels/reduce_sum_kernel.h"
2425
#include "paddle/phi/kernels/scale_kernel.h"
@@ -56,6 +57,28 @@ void DistGradKernel(const Context& dev_ctx,
5657
return;
5758
}
5859

60+
if ((x_grad && x_grad->numel() == 0) || (y_grad && y_grad->numel() == 0)) {
61+
if (x_grad) {
62+
dev_ctx.template Alloc<T>(x_grad);
63+
if (x_grad->numel() != 0) {
64+
phi::Full<T, Context>(dev_ctx,
65+
phi::IntArray(common::vectorize(x_grad->dims())),
66+
0,
67+
x_grad);
68+
}
69+
}
70+
if (y_grad) {
71+
dev_ctx.template Alloc<T>(y_grad);
72+
if (y_grad->numel() != 0) {
73+
phi::Full<T, Context>(dev_ctx,
74+
phi::IntArray(common::vectorize(y_grad->dims())),
75+
0,
76+
y_grad);
77+
}
78+
}
79+
return;
80+
}
81+
5982
auto t = Subtract<T, Context>(dev_ctx, x, y);
6083
DenseTensor x_grad_tmp;
6184
x_grad_tmp.Resize(t.dims());

paddle/phi/kernels/dist_kernel.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
#include "paddle/phi/backends/cpu/cpu_context.h"
1818
#include "paddle/phi/core/kernel_registry.h"
1919
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
20+
#include "paddle/phi/kernels/full_kernel.h"
2021
#include "paddle/phi/kernels/p_norm_kernel.h"
21-
2222
namespace phi {
2323

2424
template <typename T, typename Context>
@@ -27,6 +27,11 @@ void DistKernel(const Context& dev_ctx,
2727
const DenseTensor& y,
2828
float p,
2929
DenseTensor* out) {
30+
if (x.numel() == 0 || y.numel() == 0) {
31+
phi::Full<T, Context>(
32+
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
33+
return;
34+
}
3035
auto t = Subtract<T, Context>(dev_ctx, x, y);
3136
PNormKernel<T, Context>(dev_ctx, t, p, -1, 1e-12, false, true, out);
3237
}

paddle/phi/kernels/gpu/diagonal_grad_kernel.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ void DiagonalGradKernel(const Context& dev_ctx,
3131
int axis1,
3232
int axis2,
3333
DenseTensor* in_grad) {
34+
if (in_grad->numel() == 0) {
35+
dev_ctx.template Alloc<T>(in_grad);
36+
return;
37+
}
3438
const auto* dout = &out_grad;
3539
const auto* dout_data = dout->data<T>();
3640
auto dout_dim = dout->dims().Get();

paddle/phi/kernels/gpu/diagonal_kernel.cu

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
#include "paddle/phi/backends/gpu/gpu_primitives.h"
1818
#include "paddle/phi/core/kernel_registry.h"
1919
#include "paddle/phi/core/tensor_utils.h"
20+
#include "paddle/phi/kernels/full_kernel.h"
2021
#include "paddle/phi/kernels/funcs/diagonal.h"
21-
2222
namespace phi {
2323
using phi::PADDLE_CUDA_NUM_THREADS;
2424
template <typename T, typename Context>
@@ -28,6 +28,11 @@ void DiagonalKernel(const Context& dev_ctx,
2828
int axis1,
2929
int axis2,
3030
DenseTensor* out) {
31+
if (x.numel() == 0) {
32+
phi::Full<T, Context>(
33+
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
34+
return;
35+
}
3136
auto* input = &x;
3237
const auto* input_data = input->data<T>();
3338
auto input_dim = input->dims().Get();

paddle/phi/kernels/gpu/dist_kernel.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "paddle/phi/core/kernel_registry.h"
2020
#include "paddle/phi/kernels/dist_kernel.h"
2121
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
22+
#include "paddle/phi/kernels/full_kernel.h"
2223
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
2324
#include "paddle/phi/kernels/gpu/reduce.h"
2425
#include "paddle/phi/kernels/legacy/reduce_max_kernel.h"
@@ -123,6 +124,12 @@ void DistKernel(const Context& dev_ctx,
123124
const DenseTensor& y,
124125
float p,
125126
DenseTensor* out) {
127+
if (x.numel() == 0 || y.numel() == 0) {
128+
phi::Full<T, Context>(
129+
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
130+
return;
131+
}
132+
126133
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
127134
DenseTensor intermediate;
128135
const T* x_ptr = x.data<T>();

0 commit comments

Comments
 (0)