Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move some activation to phi #40727

Merged
merged 50 commits into from
Mar 28, 2022

Conversation

phlrain
Copy link
Collaborator

@phlrain phlrain commented Mar 19, 2022

PR types

Breaking changes

PR changes

OPs

Describe

move activation to phi

@paddle-bot-old
Copy link

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = static_cast<T>(0.5) * dout / out;
struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

二阶导是后边的pr迁移吗

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是发生了奇怪结果错误,代码有了,准备分另外一个pr来搞

functor(place, eigen_in, eigen_out, eigen_p, eps);
}
};

template <typename DeviceContext, typename T>
class LogitGradKernel : public framework::OpKernel<T> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kernel已迁,这个可以删掉

Comment on lines 1790 to 1798
// REGISTER_OPERATOR(
// expm1, ops::ActivationOp, ops::Expm1OpMaker,
// ops::ActivationOpInferVarType,
// ops::ActivationGradOpMaker<ops::Expm1GradFunctor<float>::FwdDeps(),
// paddle::framework::OpDesc>,
// ops::ActivationGradOpMaker<ops::Expm1GradFunctor<float>::FwdDeps(),
// paddle::imperative::OpBase>,
// std::conditional<ops::CanInplaceAct<ops::Expm1GradFunctor<float>>(),
// ops::ActFwdInplaceInferer, void>::type);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

无用注释删掉

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是不是没删

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

};

template <typename T>
struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些cuda functor也需要使用USE_PHI_FUNCTOR引过来,这个文件里还有xpu的注册需要用到这些functor

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines 1005 to 1016
// template <typename T>
// struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
// template <typename Device, typename X, typename Out, typename dOut,
// typename dX>
// void operator()(Device d, X x, Out out, dOut dout, dX dx) {
// dx.device(d) =
// dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
// }

// static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX;
// }
// };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

无用注释删掉

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines 44 to 45
// functor(*place, To32BitIndex(x), To32BitIndex(out));
functor(*place, x, out);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里修改的原因是?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

误操作,我回复过去

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -1675,7 +1679,6 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);

REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sqrt和square有SelcetedRows类型的输入,在selectedrows未迁移的情况下删掉这里会不会出现问题

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我恢复下

@DannyIsFunny
Copy link
Contributor

修复infrt CI 失败的问题:
#40826

@PaddlePaddle PaddlePaddle locked and limited conversation to collaborators Mar 25, 2022
@PaddlePaddle PaddlePaddle unlocked this conversation Mar 25, 2022
Comment on lines 1790 to 1798
// REGISTER_OPERATOR(
// expm1, ops::ActivationOp, ops::Expm1OpMaker,
// ops::ActivationOpInferVarType,
// ops::ActivationGradOpMaker<ops::Expm1GradFunctor<float>::FwdDeps(),
// paddle::framework::OpDesc>,
// ops::ActivationGradOpMaker<ops::Expm1GradFunctor<float>::FwdDeps(),
// paddle::imperative::OpBase>,
// std::conditional<ops::CanInplaceAct<ops::Expm1GradFunctor<float>>(),
// ops::ActFwdInplaceInferer, void>::type);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是不是没删

template <typename T, typename Context>
void ClipByNormSparseKernel(const Context& ctx,
const SelectedRows& x,
float max_norm,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

selected rows移动到对应目录

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clip by norm有其他的注册问题,单独提pr处理


} // namespace phi

// PD_REGISTER_KERNEL(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注释移除吗

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clip by norm有其他的注册问题,单独提pr处理

return KernelSignature("sqrt", {"X"}, {}, {"Out"});
}

return KernelSignature("unregistered", {}, {}, {});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果迁移完整的话,应该不需要unregister

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if (ctx.IsDenseTensorInput("X")) {
return KernelSignature("clip_by_norm", {"X"}, {"max_norm"}, {"Out"});
} else {
return KernelSignature("clip_by_norm_sparse", {"X"}, {"max_norm"}, {"Out"});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

selected rows和已有的命名风格保持一致

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clip by norm有其他的注册问题,单独提pr处理

Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@phlrain phlrain merged commit e77a947 into PaddlePaddle:develop Mar 28, 2022
phlrain pushed a commit that referenced this pull request Mar 29, 2022
phlrain added a commit that referenced this pull request Mar 29, 2022
phlrain added a commit that referenced this pull request Mar 30, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants