Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move some activation to phi #40727

Merged
merged 50 commits into from
Mar 28, 2022
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
3fc0d19
update
phlrain Mar 8, 2022
4e23ac6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 15, 2022
c552d1a
add forward case
phlrain Mar 16, 2022
def3363
update
phlrain Mar 16, 2022
c7c81fe
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 17, 2022
6c7a03b
update; test=develop
phlrain Mar 17, 2022
4be77e5
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 17, 2022
21beb08
add some grad kernel; test=develop
phlrain Mar 17, 2022
77812c0
move gpu kernel; test=develop
phlrain Mar 17, 2022
8b8c770
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 17, 2022
7834cf1
update
phlrain Mar 18, 2022
66b6a5b
update;
phlrain Mar 18, 2022
7a23e76
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 19, 2022
aa9c1bb
update test;
phlrain Mar 19, 2022
685432f
fix selected rows bug;
phlrain Mar 19, 2022
e4ffcc8
add mix vector include ;
phlrain Mar 19, 2022
377d2a0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 19, 2022
aa5c3b0
add mixed vector depen; test=develop
phlrain Mar 19, 2022
aca7bb9
add logit grad signature;
phlrain Mar 19, 2022
c5c8aa0
polish code
phlrain Mar 23, 2022
0f829f4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 23, 2022
0b23b8e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 23, 2022
a19cbd4
fix bug;
phlrain Mar 23, 2022
f7855a5
add namespace for abs
phlrain Mar 24, 2022
de125ec
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 24, 2022
b2c1c68
revert code
phlrain Mar 24, 2022
b6e99a4
not move softsign
phlrain Mar 24, 2022
e0faca8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 24, 2022
6ef93fe
revmove duplate register;
phlrain Mar 25, 2022
2d0c111
fix softsign bug
phlrain Mar 25, 2022
4e92e26
polish code
phlrain Mar 25, 2022
002599d
format
phlrain Mar 25, 2022
88597a5
format
phlrain Mar 25, 2022
609bdee
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 25, 2022
8b13e86
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 25, 2022
9b735e6
fix bug
phlrain Mar 25, 2022
d8bb56e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 25, 2022
80c2d39
remove cmake dep
phlrain Mar 25, 2022
0306098
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 25, 2022
95e1451
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 26, 2022
e084fb3
add square sqrt selected rows support
phlrain Mar 27, 2022
14910f8
update
phlrain Mar 27, 2022
4f1dcbe
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 27, 2022
28cbd53
remove clip norm
phlrain Mar 27, 2022
3917bc1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 27, 2022
e8e2d71
add standalone executor sqrt dep
phlrain Mar 27, 2022
2b354db
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 27, 2022
7904904
standalone exec denp sqrt
phlrain Mar 27, 2022
9f61ab9
remove sqrt op in cmkaelist
phlrain Mar 27, 2022
ece6173
open some case
phlrain Mar 27, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 37 additions & 69 deletions paddle/fluid/operators/activation_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1496,6 +1496,33 @@ REGISTER_ACTIVATION_OP(hard_sigmoid, HardSigmoid, HardSigmoidFunctor,
HardSigmoidGradFunctor);
REGISTER_ACTIVATION_OP(logsigmoid, LogSigmoid, LogSigmoidFunctor,
LogSigmoidGradFunctor);
REGISTER_ACTIVATION_OP(expm1, Expm1, Expm1Functor, Expm1GradFunctor);
REGISTER_ACTIVATION_OP(softplus, Softplus, SoftplusFunctor,
SoftplusGradFunctor);
REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
REGISTER_ACTIVATION_OP(stanh, STanh, STanhFunctor, STanhGradFunctor);
REGISTER_ACTIVATION_OP(reciprocal, Reciprocal, ReciprocalFunctor,
ReciprocalGradFunctor);

REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(square,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::SquareFunctor<float>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::SquareFunctor<double>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::SquareFunctor<int>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::SquareFunctor<int64_t>>);
REGISTER_OP_CPU_KERNEL(
square_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::SquareGradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::SquareGradFunctor<double>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::SquareGradFunctor<int>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::SquareGradFunctor<int64_t>>);
REGISTER_ACTIVATION_OP(log2, Log2, Log2Functor, Log2GradFunctor);
REGISTER_ACTIVATION_OP(log10, Log10, Log10Functor, Log10GradFunctor);
REGISTER_ACTIVATION_OP(log1p, Log1p, Log1pFunctor, Log1pGradFunctor);
Expand Down Expand Up @@ -1630,12 +1657,7 @@ REGISTER_OPERATOR(logit, ops::LogitOp, ops::LogitOpMaker,
ops::LogitGradOpMaker<paddle::framework::OpDesc>,
ops::LogitGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(logit_grad, ops::LogitGradOp);
REGISTER_OP_CPU_KERNEL(
logit, ops::LogitKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogitKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
logit_grad, ops::LogitGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogitGradKernel<paddle::platform::CPUDeviceContext, double>);

/* ========================================================================== */

/* ======================== celu register ============================
Expand Down Expand Up @@ -1684,7 +1706,6 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);

REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sqrt和square有SelcetedRows类型的输入,在selectedrows未迁移的情况下删掉这里会不会出现问题

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我恢复下

REGISTER_OP_CPU_KERNEL(
sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::SqrtGradGradFunctor<float>>,
Expand Down Expand Up @@ -1712,7 +1733,6 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad<ops::RsqrtGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);

REGISTER_ACTIVATION_CPU_KERNEL(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
rsqrt_grad_grad,
ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
Expand Down Expand Up @@ -1741,25 +1761,6 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);

REGISTER_OP_CPU_KERNEL(square,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::SquareFunctor<float>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::SquareFunctor<double>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::SquareFunctor<int>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::SquareFunctor<int64_t>>);
REGISTER_OP_CPU_KERNEL(
square_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::SquareGradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::SquareGradFunctor<double>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::SquareGradFunctor<int>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::SquareGradFunctor<int64_t>>);

REGISTER_OP_CPU_KERNEL(
square_grad_grad,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
Expand Down Expand Up @@ -1798,52 +1799,19 @@ REGISTER_OPERATOR(
REGISTER_OPERATOR(exp_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInferer);

REGISTER_OP_CPU_KERNEL(exp,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ExpFunctor<float>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ExpFunctor<double>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ExpFunctor<int>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ExpFunctor<int64_t>>);
REGISTER_OP_CPU_KERNEL(
exp_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::ExpGradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::ExpGradFunctor<double>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::ExpGradFunctor<int>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::ExpGradFunctor<int64_t>>);
/* ========================================================================== */

/* ========================== expm1 register ============================ */
REGISTER_OPERATOR(
expm1, ops::ActivationOp, ops::Expm1OpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpMaker<ops::Expm1GradFunctor<float>::FwdDeps(),
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::Expm1GradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
std::conditional<ops::CanInplaceAct<ops::Expm1GradFunctor<float>>(),
ops::ActFwdInplaceInferer, void>::type);
REGISTER_OPERATOR(expm1_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInferer);
// REGISTER_OPERATOR(
// expm1, ops::ActivationOp, ops::Expm1OpMaker,
// ops::ActivationOpInferVarType,
// ops::ActivationGradOpMaker<ops::Expm1GradFunctor<float>::FwdDeps(),
// paddle::framework::OpDesc>,
// ops::ActivationGradOpMaker<ops::Expm1GradFunctor<float>::FwdDeps(),
// paddle::imperative::OpBase>,
// std::conditional<ops::CanInplaceAct<ops::Expm1GradFunctor<float>>(),
// ops::ActFwdInplaceInferer, void>::type);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

无用注释删掉

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是不是没删

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


REGISTER_OP_CPU_KERNEL(expm1,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::Expm1Functor<float>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::Expm1Functor<double>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::Expm1Functor<plat::float16>>);
REGISTER_OP_CPU_KERNEL(
expm1_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::Expm1GradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::Expm1GradFunctor<double>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::Expm1GradFunctor<plat::float16>>);
/* ========================================================================== */

/* ========================== Log register ==================================*/
Expand Down
Loading