15 changes: 0 additions & 15 deletions paddle/phi/infermeta/binary.cc
@@ -1368,21 +1368,6 @@ void DistInferMeta(const MetaTensor& x,
const MetaTensor& y,
float p,
MetaTensor* out) {
auto x_dims = x.dims();
auto y_dims = y.dims();

PADDLE_ENFORCE_NE(common::product(x_dims),
0,
common::errors::InvalidArgument(
"The Input(X) has not been initialized properly. The "
"shape of Input(X) = [%s].",
x_dims));
PADDLE_ENFORCE_NE(common::product(y_dims),
0,
common::errors::InvalidArgument(
"The Input(Y) has not been initialized properly. The "
"shape of Input(Y) = [%s].",
y_dims));
out->set_dims(common::make_ddim({}));
out->set_dtype(x.dtype());
}
4 changes: 4 additions & 0 deletions paddle/phi/kernels/cpu/diagonal_grad_kernel.cc
@@ -28,6 +28,10 @@ void DiagonalGradKernel(const Context& dev_ctx,
int axis1,
int axis2,
DenseTensor* in_grad) {
if (in_grad->numel() == 0) {
dev_ctx.template Alloc<T>(in_grad);
return;
}
const auto* dout = &out_grad;
const T* dout_data = dout->data<T>();
auto dout_dim = common::vectorize(dout->dims());
8 changes: 7 additions & 1 deletion paddle/phi/kernels/cpu/diagonal_kernel.cc
@@ -16,8 +16,8 @@

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/diagonal.h"

namespace phi {

template <typename T, typename Context>
@@ -27,6 +27,12 @@ void DiagonalKernel(const Context& dev_ctx,
int axis1,
int axis2,
DenseTensor* out) {
if (x.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
return;
}

auto* input = &x;
const T* input_data = input->data<T>();
auto input_dim = common::vectorize(input->dims());
3 changes: 3 additions & 0 deletions paddle/phi/kernels/cpu/kthvalue_grad_kernel.cc
@@ -56,6 +56,9 @@ void KthvalueGradKernel(const Context& dev_ctx,
auto in_dims = x.dims();
auto out_dims = indices.dims();
T* x_grad_data = dev_ctx.template Alloc<T>(d_x);
if (d_x && d_x->numel() == 0) {
return;
}

// For 0D Tensor
if (in_dims.size() == 0) {
9 changes: 8 additions & 1 deletion paddle/phi/kernels/cpu/kthvalue_kernel.cc
@@ -16,9 +16,9 @@

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {
template <typename T, typename Type>
static void getKthvalue(Type input_height,
@@ -80,6 +80,13 @@ void KthvalueKernel(const Context& dev_ctx,
bool keepdim,
DenseTensor* output,
DenseTensor* indices) {
if (x.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(output->dims())), NAN, output);
phi::Full<int64_t, Context>(
dev_ctx, phi::IntArray(common::vectorize(indices->dims())), 0, indices);
return;
}
const auto& in_dims = x.dims();
if (axis < 0) axis += in_dims.size();

23 changes: 23 additions & 0 deletions paddle/phi/kernels/dist_grad_kernel.cc
@@ -19,6 +19,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/p_norm_grad_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/scale_kernel.h"
@@ -56,6 +57,28 @@ void DistGradKernel(const Context& dev_ctx,
return;
}

if ((x_grad && x_grad->numel() == 0) || (y_grad && y_grad->numel() == 0)) {
if (x_grad) {
dev_ctx.template Alloc<T>(x_grad);
if (x_grad->numel() != 0) {
phi::Full<T, Context>(dev_ctx,
phi::IntArray(common::vectorize(x_grad->dims())),
0,
x_grad);
}
}
if (y_grad) {
dev_ctx.template Alloc<T>(y_grad);
if (y_grad->numel() != 0) {
phi::Full<T, Context>(dev_ctx,
phi::IntArray(common::vectorize(y_grad->dims())),
0,
y_grad);
}
}
return;
}

auto t = Subtract<T, Context>(dev_ctx, x, y);
DenseTensor x_grad_tmp;
x_grad_tmp.Resize(t.dims());
7 changes: 6 additions & 1 deletion paddle/phi/kernels/dist_kernel.cc
@@ -17,8 +17,8 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/p_norm_kernel.h"

namespace phi {

template <typename T, typename Context>
@@ -27,6 +27,11 @@ void DistKernel(const Context& dev_ctx,
const DenseTensor& y,
float p,
DenseTensor* out) {
if (x.numel() == 0 || y.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
return;
}
auto t = Subtract<T, Context>(dev_ctx, x, y);
PNormKernel<T, Context>(dev_ctx, t, p, -1, 1e-12, false, true, out);
}
4 changes: 4 additions & 0 deletions paddle/phi/kernels/gpu/diagonal_grad_kernel.cu
@@ -31,6 +31,10 @@ void DiagonalGradKernel(const Context& dev_ctx,
int axis1,
int axis2,
DenseTensor* in_grad) {
if (in_grad->numel() == 0) {
dev_ctx.template Alloc<T>(in_grad);
return;
}
const auto* dout = &out_grad;
const auto* dout_data = dout->data<T>();
auto dout_dim = dout->dims().Get();
7 changes: 6 additions & 1 deletion paddle/phi/kernels/gpu/diagonal_kernel.cu
@@ -17,8 +17,8 @@
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/diagonal.h"

namespace phi {
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename Context>
@@ -28,6 +28,11 @@ void DiagonalKernel(const Context& dev_ctx,
int axis1,
int axis2,
DenseTensor* out) {
if (x.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
return;
}
auto* input = &x;
const auto* input_data = input->data<T>();
auto input_dim = input->dims().Get();
7 changes: 7 additions & 0 deletions paddle/phi/kernels/gpu/dist_kernel.cu
@@ -19,6 +19,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/dist_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/legacy/reduce_max_kernel.h"
@@ -123,6 +124,12 @@ void DistKernel(const Context& dev_ctx,
const DenseTensor& y,
float p,
DenseTensor* out) {
if (x.numel() == 0 || y.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
return;
}

using MT = typename phi::dtype::MPTypeTrait<T>::Type;
DenseTensor intermediate;
const T* x_ptr = x.data<T>();
4 changes: 4 additions & 0 deletions paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu
@@ -44,6 +44,10 @@ void KthvalueGradKernel(const Context& dev_ctx,
const auto& in_dims = x.dims();
auto out_dims = indices.dims();
T* x_grad_data = dev_ctx.template Alloc<T>(d_x);
if (d_x && d_x->numel() == 0) {
return;
}

// For 0D Tensor
if (in_dims.size() == 0) {
phi::funcs::set_constant(dev_ctx, d_x, static_cast<T>(1.0));
9 changes: 9 additions & 0 deletions paddle/phi/kernels/gpu/kthvalue_kernel.cu
@@ -16,6 +16,7 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
@@ -160,6 +161,14 @@ void KthvalueKernel(const Context& dev_ctx,
bool keepdim,
DenseTensor* output,
DenseTensor* indices) {
if (x.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(output->dims())), NAN, output);
phi::Full<int64_t, Context>(
dev_ctx, phi::IntArray(common::vectorize(indices->dims())), 0, indices);
Contributor:
Setting output to NaN looks fine. But indices is filled with 0 — has this been compared against torch?

Contributor Author:
indices is of type int64_t, and int64_t has no NaN.
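A minimal sketch (not part of this PR; it assumes only standard Python/NumPy semantics) of why an int64 index tensor has no NaN to fall back on, hence the 0 fill:

```python
import numpy as np

# Floating-point NaN cannot be converted to an integer at all.
try:
    int(float("nan"))
except ValueError as err:
    print(err)  # cannot convert float NaN to integer

# int64 exposes only finite bounds; there is no NaN bit pattern,
# so a degenerate kthvalue result needs an integer sentinel such as 0.
print(np.iinfo(np.int64).min, np.iinfo(np.int64).max)
```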

return;
}

const auto& in_dims = x.dims();
if (axis < 0) axis += in_dims.size();
auto out_dims = output->dims();
3 changes: 3 additions & 0 deletions paddle/phi/kernels/stride/diagonal_grad_kernel.cc
@@ -38,6 +38,9 @@ void DiagonalGradStridedKernel(const Context& dev_ctx,
}
dev_ctx.Alloc(in_grad, in_grad->dtype());
in_grad->set_strides(DenseTensorMeta::calc_strides(in_grad->dims()));
if (in_grad->numel() == 0) {
return;
}
PD_VISIT_ALL_TYPES(in_grad->dtype(), "DiagonalGradStridedKernel", ([&] {
phi::StridedTensorFill<data_t>(*in_grad, 0, in_grad);
}));
3 changes: 3 additions & 0 deletions paddle/phi/kernels/xpu/activation_grad_kernel.cc
@@ -679,6 +679,9 @@ void ExpGradKernel(const Context& dev_ctx,
DenseTensor* dx) {
using XPUType = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(dx);
if (dx && dx->numel() == 0) {
return;
}
const XPUType* y_data = reinterpret_cast<const XPUType*>(out.data<T>());
const XPUType* y_grad = reinterpret_cast<const XPUType*>(dout.data<T>());
XPUType* x_grad = reinterpret_cast<XPUType*>(dx->data<T>());
8 changes: 7 additions & 1 deletion paddle/phi/kernels/xpu/diagonal_kernel.cc
@@ -16,7 +16,7 @@

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"

#include "paddle/phi/kernels/full_kernel.h"
namespace phi {

template <typename T, typename Context>
@@ -26,6 +26,12 @@ void DiagonalKernel(const Context& dev_ctx,
int axis1,
int axis2,
DenseTensor* out) {
if (x.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
return;
}

using XPUType = typename XPUTypeTrait<T>::Type;
T* out_data = dev_ctx.template Alloc<T>(out);
std::vector<int64_t> xshape = common::vectorize<int64_t>(x.dims());
13 changes: 13 additions & 0 deletions test/legacy_test/test_diagonal_op.py
@@ -134,6 +134,19 @@ def init_config(self):
)


class TestDiagonalOp_ZeroSize(TestDiagonalOp):
def init_config(self):
self.case = np.random.randn(0, 2, 4, 4).astype(self.dtype)
self.inputs = {'Input': self.case}
self.attrs = {'offset': -2, 'axis1': 0, 'axis2': 3}
self.target = np.diagonal(
self.inputs['Input'],
offset=self.attrs['offset'],
axis1=self.attrs['axis1'],
axis2=self.attrs['axis2'],
)


class TestDiagonalAPI(unittest.TestCase):
def setUp(self):
self.shape = [10, 3, 4]
33 changes: 33 additions & 0 deletions test/legacy_test/test_dist_op.py
@@ -236,6 +236,39 @@ def init_case(self):
self.p = 1.5


class TestDistOp_ZeroSize1(TestDistOp):
def setUp(self):
self.op_type = 'dist'
self.python_api = paddle.dist
self.attrs = {}
self.init_case()
self.init_data_type()
self.inputs = {
"X": np.random.random(self.x_shape).astype(self.data_type),
"Y": np.random.random(self.y_shape).astype(self.data_type),
}

self.attrs["p"] = self.p
self.outputs = {
"Out": dist(self.inputs["X"], self.inputs["Y"], self.attrs["p"])
}

def test_check_grad(self):
self.check_grad(["X", "Y"], "Out", check_pir=True)

def init_case(self):
self.x_shape = (0, 1, 5, 6)
Contributor:
Could the unit tests add a case where only one of x or y is a 0-size Tensor?

Contributor Author:
Done.
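For reference, a minimal NumPy sketch (assuming the same p-norm definition as this test file's dist() helper) of the mixed case added below, where only x is 0-size:

```python
import numpy as np

# Only x has a zero-size dimension; broadcasting keeps the result empty.
x = np.random.random((0, 1, 5, 6)).astype("float64")
y = np.random.random((1, 5, 6)).astype("float64")
p = 1.0

diff = x - y                 # broadcasts to shape (0, 1, 5, 6), numel == 0
out = np.power(np.sum(np.power(np.abs(diff), p)), 1.0 / p)
print(diff.shape, out)       # (0, 1, 5, 6) 0.0
```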

self.y_shape = (0, 5, 6)
self.p = 1.0


class TestDistOp_ZeroSize2(TestDistOp_ZeroSize1):
def init_case(self):
self.x_shape = (0, 1, 5, 6)
self.y_shape = (1, 5, 6)
self.p = 1.0


class TestDistAPI(unittest.TestCase):
def init_data_type(self):
self.data_type = (
11 changes: 10 additions & 1 deletion test/legacy_test/test_kthvalue_op.py
@@ -43,13 +43,17 @@ def init_args(self):
def init_dtype(self):
self.dtype = np.float64

def init_shape(self):
self.shape = [2, 1, 2, 4, 10]

def setUp(self):
self.op_type = "kthvalue"
self.prim_op_type = "prim"
self.python_api = paddle.kthvalue
self.public_python_api = paddle.kthvalue
self.init_dtype()
self.input_data = np.random.random([2, 1, 2, 4, 10]).astype(self.dtype)
self.init_shape()
self.input_data = np.random.random(self.shape).astype(self.dtype)
self.init_args()
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis}
@@ -77,6 +81,11 @@ def init_dtype(self):
self.dtype = np.float16


class TestKthvalueOp_ZeroSize(TestKthvalueOp):
def init_shape(self):
self.shape = [2, 1, 0, 4, 10]


class TestKthvalueOpWithKeepdim(OpTest):
def init_args(self):
self.k = 2
Expand Down