Move GumbelSoftmax OP to phi #39873
Conversation
Thanks for your contribution!
paddle/phi/infermeta/unary.cc
Outdated
PADDLE_ENFORCE_GE(
    axis,
    -rank,
    paddle::platform::errors::InvalidArgument(
Replace with the phi namespace.
paddle/phi/infermeta/unary.cc
Outdated
PADDLE_ENFORCE_LT(
    axis,
    rank,
    paddle::platform::errors::InvalidArgument(
Same as above.
paddle/phi/infermeta/unary.cc
Outdated
PADDLE_ENFORCE_EQ(
    out.dims(),
    dout.dims(),
    paddle::platform::errors::InvalidArgument(
Same as above.
PADDLE_ENFORCE_GT(temperature,
                  0,
                  paddle::platform::errors::InvalidArgument(
Same as above.
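This check guards the division by `temperature` inside the kernel: Gumbel-Softmax computes `softmax((logits + gumbel_noise) / temperature)`, so a non-positive temperature would divide by zero or invert the ordering of the logits. A minimal standalone sketch of the tempered-softmax step in plain C++ (hypothetical helper, not Paddle's actual kernel code):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical sketch: numerically stable softmax over one row with a
// temperature. The kernel's PADDLE_ENFORCE_GT(temperature, 0, ...) exists
// because temperature appears in this denominator.
std::vector<double> TemperedSoftmax(const std::vector<double>& logits,
                                    double temperature) {
  assert(temperature > 0 && "temperature must be positive");
  double max_v = *std::max_element(logits.begin(), logits.end());
  std::vector<double> out(logits.size());
  double sum = 0.0;
  for (size_t i = 0; i < logits.size(); ++i) {
    out[i] = std::exp((logits[i] - max_v) / temperature);  // stable exp
    sum += out[i];
  }
  for (double& v : out) v /= sum;  // normalize to a probability vector
  return out;
}
```

As the temperature approaches zero the output approaches a one-hot vector, which is the regime Gumbel-Softmax uses to approximate discrete sampling.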
default:
  PADDLE_ENFORCE_LE(out->dims().size(),
                    6,
                    paddle::platform::errors::InvalidArgument(
Use the phi namespace.
paddle/phi/infermeta/unary.h
Outdated
int axis,
MetaTensor* out);

void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
Put the backward InferMeta function into the backward.h/.cc files under the infermeta directory.
paddle/phi/infermeta/unary.h
Outdated
int axis,
MetaTensor* out);

void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
Grad InferMeta functions all go into backward.h/.cc.
// generate uniform random number
const int size = size_to_axis * size_from_axis;
std::uniform_real_distribution<T> dist(0.00001, 1);
auto engine = paddle::framework::GetCPURandomEngine(0);
Sync with Qiuliang on this point; Qiuliang should be moving the generator over later as well.
Revised: the engine used in this OP has now been switched over accordingly, while the distribution still uses std.
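The uniform draw above is bounded away from zero (0.00001 rather than 0) because the Gumbel transform applies `log` twice. A standalone sketch of that sampling step, using `std::mt19937_64` in place of Paddle's engine (helper name hypothetical):

```cpp
#include <cmath>
#include <random>
#include <vector>

// Hypothetical sketch: draw U ~ Uniform(eps, 1) and map it to Gumbel(0, 1)
// noise via -log(-log(U)). The 0.00001 lower bound mirrors the kernel's
// distribution and keeps log(U) finite at the left edge.
std::vector<double> SampleGumbelNoise(int size, unsigned seed) {
  std::uniform_real_distribution<double> dist(0.00001, 1.0);
  std::mt19937_64 engine(seed);  // stand-in for the framework's engine
  std::vector<double> noise(size);
  for (int i = 0; i < size; ++i) {
    noise[i] = -std::log(-std::log(dist(engine)));
  }
  return noise;
}
```

Since `std::uniform_real_distribution` samples from the half-open interval [0.00001, 1), both logarithms stay finite for every draw.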
namespace phi {

static inline int CanonicalAxis(const int axis, const int rank) {
The framework has three identical copies of these functions; in another PR I've consolidated them under funcs. Our changes will likely conflict, so let's sync up: #39547
I checked PR #39547 and it doesn't touch the GumbelSoftmax OP code. This PR implements its own standalone copy and doesn't reference those files, so there should be no conflict. After your PR is merged, I'll remove the duplicate implementation in this kernel and include the newly added axis_util instead.
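The duplicated helpers under discussion normalize a possibly-negative axis and split a shape into "outer" and "inner" products so an N-D tensor can be viewed as 2-D. A self-contained sketch modeled on those utilities (signatures assumed, using `std::vector<int>` instead of the framework's dims type):

```cpp
#include <vector>

// Map a negative axis (e.g. -1) to its non-negative equivalent.
static inline int CanonicalAxis(const int axis, const int rank) {
  return axis < 0 ? axis + rank : axis;
}

// Product of dimensions before `axis`: the row count of the 2-D view.
static inline int SizeToAxis(const int axis, const std::vector<int>& dims) {
  int size = 1;
  for (int i = 0; i < axis; ++i) size *= dims[i];
  return size;
}

// Product of dimensions from `axis` onward: the column count of the 2-D view.
static inline int SizeFromAxis(const int axis, const std::vector<int>& dims) {
  int size = 1;
  for (int i = axis; i < static_cast<int>(dims.size()); ++i) size *= dims[i];
  return size;
}
```

For a shape `{2, 3, 4}` with `axis = 1`, this yields a `2 x 12` view, which is exactly the `{size_to_axis, size_from_axis}` resize the kernel performs below.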
const int size_to_axis = SizeToAxis(axis, dx->dims());
const int size_from_axis = SizeFromAxis(axis, dx->dims());
DenseTensor dx_2d, out_2d, dout_2d;
dx_2d.ShareDataWith(*dx).Resize({size_to_axis, size_from_axis});
If you want this ShareDataWith to avoid the CI check, you can change it to copy construction, DenseTensor dx_2d(*dx); with the current implementation the behavior should be the same.
done, thx
The current migration can keep moving forward as-is, but later I'm inclined to put the several softmax kernel implementations together in softmax_kernel.h to reduce the number of files @YuanRisheng
The errors namespace issue and the placement of the backward InferMeta have been fixed @YuanRisheng @MingMingShangTian @zyfncg @chenwhql
PR types
Function optimization
PR changes
OPs
Describe
Move GumbelSoftmax OP to phi.