-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[0-size Tensor Job2 No.21-23] Add 0-size Tensor support for send_u_recv #73806
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
39f1403
dfa6229
a482fb3
8c51359
2661bc7
309ddef
2df57d4
895cbe1
384bcd2
fd73ea9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ | |
| #include "paddle/phi/backends/cpu/cpu_context.h" | ||
| #include "paddle/phi/core/kernel_registry.h" | ||
| #include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h" | ||
| #include "paddle/phi/kernels/full_kernel.h" | ||
|
|
||
| namespace phi { | ||
|
|
||
|
|
@@ -154,6 +155,28 @@ void SendURecvKernel(const Context& dev_ctx, | |
| DenseTensor* dst_count) { | ||
| auto index_type = src_index.dtype(); | ||
| auto& out_size_data = out_size.GetData(); | ||
|
|
||
| if (x.numel() == 0 || src_index.numel() == 0 || dst_index.numel() == 0) { | ||
| if (out_size_data[0] <= 0) { | ||
| out->Resize(x.dims()); | ||
| } else { | ||
| out->Resize(common::make_ddim(out_size_data)); | ||
| } | ||
|
Comment on lines
+160
to
+164
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. out_size 出现负数是输入的问题还是shape推导的问题?不应在kernel层面重新处理shape。应该在infermeta的时候就检查好或者保证推导正确,不应该到kernel层面dim中还出现负数。
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 目前正常非0-size的逻辑也是在kernel层面重新处理了shape,所有为了统一,暂时也按照该方式处理 |
||
| if (reduce_op == "MEAN") { | ||
| int64_t input_size = | ||
| out_size_data[0] <= 0 ? x.dims()[0] : out_size_data[0]; | ||
| dst_count->Resize({input_size}); | ||
| } | ||
| phi::Full<T, Context>( | ||
| dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out); | ||
| phi::Full<int32_t, Context>( | ||
| dev_ctx, | ||
| phi::IntArray(common::vectorize(dst_count->dims())), | ||
| 0, | ||
| dst_count); | ||
| return; | ||
| } | ||
|
|
||
| if (index_type == phi::DataType::INT32) { | ||
| GraphSendRecvOpKernelLaunchHelper<Context, T, int32_t>(dev_ctx, | ||
| x, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议不为0时保留检查
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的