-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PHI] Move segment_pool to phi. #40099
Conversation
Thanks for your contribution! |
#include "paddle/phi/core/dense_tensor.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/impl/segment_pool_grad_kernel_impl.h" | ||
#include "paddle/phi/kernels/segment_pool_grad_kernel.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
segment_pool_grad_kernel.h建议放在开头
#include "paddle/phi/core/dense_tensor.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/impl/segment_pool_kernel_impl.h" | ||
#include "paddle/phi/kernels/segment_pool_kernel.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
phi::funcs::SetConstant<Context, T> set_zero; | ||
set_zero(dev_ctx, x_grad, static_cast<T>(0)); | ||
|
||
auto index_type = paddle::framework::TransToProtoVarType(segment_ids.type()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里不需要转成proto::VarType了,直接用DataType类型做判断就行
|
||
namespace phi { | ||
|
||
using Tensor = DenseTensor; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using Tensor可以移除
length.Resize(phi::make_ddim({1})); | ||
IndexT* length_data = dev_ctx.template HostAlloc<IndexT>(&length); | ||
|
||
// IndexT* length_data = length.data<IndexT>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注释可以去掉
const std::string& pooltype, | ||
DenseTensor* out, | ||
DenseTensor* summed_ids) { | ||
auto index_type = paddle::framework::TransToProtoVarType(segment_ids.dtype()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不用转ProtoVarType
KernelSignature SegmentPoolOpArgumentMapping( | ||
const ArgumentMappingContext& ctx) { | ||
return KernelSignature( | ||
"segment_pool", {"X", "SegmentIds"}, {"pooltype"}, {"Out", "SummedIds"}); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里前向ArgumentMapping感觉可以不写,用默认的应该也能work
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for unity_build_rule.cmake
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for PADDLE_THROW
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
OPs
Describe
Move segment_pool to phi.