Merged
22 changes: 14 additions & 8 deletions paddle/phi/infermeta/multiary.cc
@@ -5331,10 +5331,13 @@ void SendUERecvInferMeta(const MetaTensor& x,
dst_index_dims.size()));
}

PADDLE_ENFORCE_EQ(src_index_dims[0],
dst_index_dims[0],
common::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
if (src_index_dims[0] != 0) {
PADDLE_ENFORCE_EQ(
src_index_dims[0],
dst_index_dims[0],
common::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
}

Contributor

Suggest keeping the check when the first dim is not 0:

  if (src_index_dims[0] != 0) {
    PADDLE_ENFORCE_EQ(
        src_index_dims[0],
        dst_index_dims[0],
        common::errors::InvalidArgument(
            "Src_index and Dst_index should have the same shape."));
  }

Contributor Author

OK.

auto y_dims = y.dims();
PADDLE_ENFORCE_EQ(
@@ -5416,10 +5419,13 @@ void SendUVInferMeta(const MetaTensor& x,
dst_index_dims.size()));
}

PADDLE_ENFORCE_EQ(src_index_dims[0],
dst_index_dims[0],
common::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
if (src_index_dims[0] != 0) {
PADDLE_ENFORCE_EQ(
src_index_dims[0],
dst_index_dims[0],
common::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
}

// Infer out's shape according to x and y(need broadcasting condition)
out->set_dtype(x.dtype());
11 changes: 7 additions & 4 deletions paddle/phi/infermeta/ternary.cc
@@ -2517,10 +2517,13 @@ void SendURecvInferMeta(const MetaTensor& x,
dst_index_dims.size()));
}

PADDLE_ENFORCE_EQ(src_index_dims[0],
dst_index_dims[0],
common::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
if (src_index_dims[0] != 0) {
PADDLE_ENFORCE_EQ(
src_index_dims[0],
dst_index_dims[0],
common::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
}

auto dims = x.dims();
std::vector<int64_t> dims_ = common::vectorize(dims);
9 changes: 9 additions & 0 deletions paddle/phi/kernels/cpu/send_u_recv_grad_kernel.cc
@@ -19,6 +19,7 @@

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h"
#include "paddle/phi/kernels/full_kernel.h"

namespace phi {

@@ -128,6 +129,14 @@ void SendURecvGradKernel(const Context& dev_ctx,
const std::string& reduce_op,
DenseTensor* x_grad) {
auto index_type = src_index.dtype();

if (out_grad.numel() == 0 || x.numel() == 0 || src_index.numel() == 0 ||
dst_index.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
return;
}

if (index_type == phi::DataType::INT32) {
GraphSendRecvGradOpKernelLaunchHelper<Context, T, int32_t>(
dev_ctx,
23 changes: 23 additions & 0 deletions paddle/phi/kernels/cpu/send_u_recv_kernel.cc
@@ -22,6 +22,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h"
#include "paddle/phi/kernels/full_kernel.h"

namespace phi {

@@ -154,6 +155,28 @@ void SendURecvKernel(const Context& dev_ctx,
DenseTensor* dst_count) {
auto index_type = src_index.dtype();
auto& out_size_data = out_size.GetData();

if (x.numel() == 0 || src_index.numel() == 0 || dst_index.numel() == 0) {
if (out_size_data[0] <= 0) {
out->Resize(x.dims());
} else {
out->Resize(common::make_ddim(out_size_data));
}
Comment on lines +160 to +164
Contributor

Is the negative out_size an input problem or a shape-inference problem? The shape should not be re-derived at the kernel level; it should be checked during infermeta (or the inference guaranteed to be correct there), so that no negative value is still present in the dims by the time the kernel runs.

Contributor Author

The existing non-0-size path also re-derives the shape at the kernel level, so for consistency this case is handled the same way for now.
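For illustration, a minimal sketch of the infermeta-side handling the reviewer describes (an assumption for this discussion, not code from this PR; `out_size_value` is an assumed name for the resolved out_size scalar, and `out` is the output MetaTensor). The PR itself keeps the kernel-level handling for consistency with the existing path:

  // Hypothetical sketch: resolve the leading output dim during shape
  // inference so out->dims() never carries a negative value into the kernel.
  int64_t rows = out_size_value <= 0 ? x.dims()[0] : out_size_value;
  std::vector<int64_t> out_dims_vec = common::vectorize(x.dims());
  out_dims_vec[0] = rows;
  out->set_dims(common::make_ddim(out_dims_vec));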

if (reduce_op == "MEAN") {
int64_t input_size =
out_size_data[0] <= 0 ? x.dims()[0] : out_size_data[0];
dst_count->Resize({input_size});
}
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
phi::Full<int32_t, Context>(
dev_ctx,
phi::IntArray(common::vectorize(dst_count->dims())),
0,
dst_count);
return;
}

if (index_type == phi::DataType::INT32) {
GraphSendRecvOpKernelLaunchHelper<Context, T, int32_t>(dev_ctx,
x,
11 changes: 11 additions & 0 deletions paddle/phi/kernels/cpu/send_ue_recv_grad_kernel.cc
@@ -23,6 +23,7 @@
#include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h"
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
@@ -458,6 +459,16 @@ void SendUERecvGradKernel(const Context& dev_ctx,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto index_type = src_index.dtype();

if (out_grad.numel() == 0 || x.numel() == 0 || y.numel() == 0 ||
src_index.numel() == 0 || dst_index.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(y_grad->dims())), 0, y_grad);
return;
}

if (index_type == phi::DataType::INT32) {
GraphSendUERecvGradOpKernelLaunchHelper<Context, T, int32_t>(
dev_ctx,
25 changes: 25 additions & 0 deletions paddle/phi/kernels/cpu/send_ue_recv_kernel.cc
@@ -22,6 +22,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"

namespace phi {
@@ -256,6 +257,30 @@ void SendUERecvKernel(const Context& dev_ctx,
DenseTensor* dst_count) {
auto index_type = src_index.dtype();
auto& out_size_data = out_size.GetData();

if (x.numel() == 0 || y.numel() == 0 || src_index.numel() == 0 ||
dst_index.numel() == 0) {
std::vector<int64_t> dims_ = common::vectorize(out->dims());
if (out_size_data[0] <= 0) {
dims_[0] = x.dims()[0];
} else {
dims_[0] = out_size_data[0];
}
if (reduce_op == "MEAN") {
int64_t input_size =
out_size_data[0] <= 0 ? x.dims()[0] : out_size_data[0];
dst_count->Resize({input_size});
}
out->Resize(common::make_ddim(dims_));
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
phi::Full<int, Context>(dev_ctx,
phi::IntArray(common::vectorize(dst_count->dims())),
0,
dst_count);
return;
}

if (index_type == phi::DataType::INT32) {
GraphSendUERecvOpKernelLaunchHelper<Context, T, int32_t>(dev_ctx,
x,
11 changes: 11 additions & 0 deletions paddle/phi/kernels/cpu/send_uv_grad_kernel.cc
@@ -18,6 +18,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
@@ -241,6 +242,16 @@ void SendUVGradKernel(const Context& dev_ctx,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto index_type = src_index.dtype();

if (out_grad.numel() == 0 || x.numel() == 0 || y.numel() == 0 ||
src_index.numel() == 0 || dst_index.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(y_grad->dims())), 0, y_grad);
return;
}

if (index_type == phi::DataType::INT32) {
GraphSendUVGradOpKernelLaunchHelper<Context, T, int32_t>(dev_ctx,
x,
9 changes: 9 additions & 0 deletions paddle/phi/kernels/cpu/send_uv_kernel.cc
@@ -18,6 +18,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"

namespace phi {
@@ -105,6 +106,14 @@ void SendUVKernel(const Context& dev_ctx,
const std::string& message_op,
DenseTensor* out) {
auto index_type = src_index.dtype();

if (x.numel() == 0 || y.numel() == 0 || src_index.numel() == 0 ||
dst_index.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
return;
}

if (index_type == phi::DataType::INT32) {
GraphSendUVOpKernelLaunchHelper<Context, T, int32_t>(
dev_ctx, x, y, src_index, dst_index, message_op, out);
9 changes: 9 additions & 0 deletions paddle/phi/kernels/gpu/send_u_recv_grad_kernel.cu
@@ -20,6 +20,7 @@
#include "paddle/common/hostdevice.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"

namespace phi {
@@ -105,6 +106,14 @@ void SendURecvGradKernel(const Context& dev_ctx,
const std::string& reduce_op,
DenseTensor* x_grad) {
auto index_type = src_index.dtype();

if (out_grad.numel() == 0 || x.numel() == 0 || src_index.numel() == 0 ||
dst_index.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
return;
}

if (index_type == phi::DataType::INT32) {
GraphSendRecvGradOpCUDAKernelLaunchHelper<Context, T, int32_t>(
dev_ctx,
23 changes: 23 additions & 0 deletions paddle/phi/kernels/gpu/send_u_recv_kernel.cu
@@ -24,6 +24,7 @@
#include "paddle/common/hostdevice.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"

@@ -152,6 +153,28 @@ void SendURecvKernel(const Context& dev_ctx,
DenseTensor* dst_count) {
auto index_type = src_index.dtype();
auto& out_size_data = out_size.GetData();

if (x.numel() == 0 || src_index.numel() == 0 || dst_index.numel() == 0) {
if (out_size_data[0] <= 0) {
out->Resize(x.dims());
} else {
out->Resize(common::make_ddim(out_size_data));
}
if (reduce_op == "MEAN") {
int64_t input_size =
out_size_data[0] <= 0 ? x.dims()[0] : out_size_data[0];
dst_count->Resize({input_size});
}
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
phi::Full<int32_t, Context>(
dev_ctx,
phi::IntArray(common::vectorize(dst_count->dims())),
0,
dst_count);
return;
}

if (index_type == phi::DataType::INT32) {
GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int32_t>(dev_ctx,
x,
11 changes: 11 additions & 0 deletions paddle/phi/kernels/gpu/send_ue_recv_grad_kernel.cu
@@ -17,6 +17,7 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
@@ -569,6 +570,16 @@ void SendUERecvGradKernel(const Context& dev_ctx,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto index_type = src_index.dtype();

if (out_grad.numel() == 0 || x.numel() == 0 || y.numel() == 0 ||
src_index.numel() == 0 || dst_index.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(y_grad->dims())), 0, y_grad);
return;
}

if (index_type == phi::DataType::INT32) {
GraphSendUERecvGradOpCUDAKernelLaunchHelper<Context, T, int32_t>(
dev_ctx,
25 changes: 25 additions & 0 deletions paddle/phi/kernels/gpu/send_ue_recv_kernel.cu
@@ -23,6 +23,7 @@
#include "paddle/common/hostdevice.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
@@ -282,6 +283,30 @@ void SendUERecvKernel(const Context& dev_ctx,
DenseTensor* dst_count) {
auto index_type = src_index.dtype();
auto& out_size_data = out_size.GetData();

if (x.numel() == 0 || y.numel() == 0 || src_index.numel() == 0 ||
dst_index.numel() == 0) {
std::vector<int64_t> dims_ = common::vectorize(out->dims());
if (out_size_data[0] <= 0) {
dims_[0] = x.dims()[0];
} else {
dims_[0] = out_size_data[0];
}
if (reduce_op == "MEAN") {
int64_t input_size =
out_size_data[0] <= 0 ? x.dims()[0] : out_size_data[0];
dst_count->Resize({input_size});
}
out->Resize(common::make_ddim(dims_));
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
phi::Full<int, Context>(dev_ctx,
phi::IntArray(common::vectorize(dst_count->dims())),
0,
dst_count);
return;
}

if (index_type == phi::DataType::INT32) {
GraphSendUERecvOpCUDAKernelLaunchHelper<Context, T, int32_t>(
dev_ctx,
11 changes: 11 additions & 0 deletions paddle/phi/kernels/gpu/send_uv_grad_kernel.cu
@@ -18,6 +18,7 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
@@ -298,6 +299,16 @@ void SendUVGradKernel(const Context& dev_ctx,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto index_type = src_index.dtype();

if (out_grad.numel() == 0 || x.numel() == 0 || y.numel() == 0 ||
src_index.numel() == 0 || dst_index.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(y_grad->dims())), 0, y_grad);
return;
}

if (index_type == phi::DataType::INT32) {
GraphSendUVGradOpCUDAKernelLaunchHelper<Context, T, int32_t>(dev_ctx,
x,
9 changes: 9 additions & 0 deletions paddle/phi/kernels/gpu/send_uv_kernel.cu
@@ -19,6 +19,7 @@
#include "paddle/common/hostdevice.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
@@ -150,6 +151,14 @@ void SendUVKernel(const Context& dev_ctx,
const std::string& message_op,
DenseTensor* out) {
auto index_type = src_index.dtype();

if (x.numel() == 0 || y.numel() == 0 || src_index.numel() == 0 ||
dst_index.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
return;
}

if (index_type == phi::DataType::INT32) {
GraphSendUVOpCUDAKernelLaunchHelper<Context, T, int32_t>(
dev_ctx, x, y, src_index, dst_index, message_op, out);