From 6a4b08338b64492d9d0b7ae80f0386e563cf1d14 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Fri, 4 Mar 2022 11:22:51 +0000 Subject: [PATCH 1/8] move activation op --- cmake/operators.cmake | 2 +- .../ir/mkldnn/mkldnn_inplace_pass_tester.cc | 2 +- .../paddle2cinn/cinn_compiler_test.cc | 2 +- .../fluid/imperative/tests/test_prepare_op.cc | 2 +- .../tensorrt/convert/test_activation_op.cc | 2 +- .../fluid/operators/activation_cudnn_op.cu.cc | 19 +- paddle/fluid/operators/activation_op.cc | 35 +- paddle/fluid/operators/activation_op.h | 586 +++---------- paddle/fluid/operators/activation_op.kps | 454 ++-------- .../operators/mkldnn/test_mkldnn_caching.cc | 2 +- .../mkldnn/test_mkldnn_op_inplace.cc | 2 +- .../operators/mkldnn/test_mkldnn_op_nhwc.cc | 2 +- .../operators/mlu/activation_op_mlu_test.cc | 2 +- .../test_common_infer_shape_functions.cc | 2 +- paddle/phi/kernels/activation_grad_kernel.h | 54 ++ paddle/phi/kernels/activation_kernel.h | 40 + .../phi/kernels/cpu/activation_grad_kernel.cc | 102 +++ paddle/phi/kernels/cpu/activation_kernel.cc | 55 ++ paddle/phi/kernels/funcs/activation_functor.h | 829 ++++++++++++++++++ .../phi/kernels/gpu/activation_grad_kernel.cu | 233 +++++ paddle/phi/kernels/gpu/activation_kernel.cu | 142 +++ paddle/phi/kernels/impl/activation_impl.h | 159 ++++ paddle/phi/ops/compat/activation_sig.cc | 66 ++ 23 files changed, 1901 insertions(+), 893 deletions(-) create mode 100644 paddle/phi/kernels/activation_grad_kernel.h create mode 100644 paddle/phi/kernels/activation_kernel.h create mode 100644 paddle/phi/kernels/cpu/activation_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/activation_kernel.cc create mode 100644 paddle/phi/kernels/funcs/activation_functor.h create mode 100644 paddle/phi/kernels/gpu/activation_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/activation_kernel.cu create mode 100644 paddle/phi/kernels/impl/activation_impl.h create mode 100644 paddle/phi/ops/compat/activation_sig.cc diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 7affd59de162d..ad2dbdca60552 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -478,7 +478,7 @@ function(op_library TARGET) if (${pybind_flag} EQUAL 0) # NOTE(*): activation use macro to regist the kernels, set use_op manually. 
if(${TARGET} STREQUAL "activation") - file(APPEND ${pybind_file} "USE_OP(relu);\n") + file(APPEND ${pybind_file} "USE_OP_ITSELF(relu);\n") elseif(${TARGET} STREQUAL "fake_dequantize") file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n") elseif(${TARGET} STREQUAL "fake_quantize") diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc index 0a95444f852dd..796aa4039c9e8 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc @@ -27,7 +27,7 @@ USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); USE_OP(leaky_relu); USE_OP_DEVICE_KERNEL(leaky_relu, MKLDNN); USE_OP(gelu); -USE_OP(relu); +USE_OP_ITSELF(relu); USE_OP(tanh); USE_OP_DEVICE_KERNEL(tanh, MKLDNN); diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc index e8badab27b9b9..cdccc4c554690 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc @@ -301,5 +301,5 @@ TEST(CinnCompilerTest, Compile) { USE_PASS(build_cinn_pass); USE_PASS(graph_viz_pass); USE_OP(mul); -USE_OP(relu); +USE_OP_ITSELF(relu); USE_OP_ITSELF(elementwise_add); diff --git a/paddle/fluid/imperative/tests/test_prepare_op.cc b/paddle/fluid/imperative/tests/test_prepare_op.cc index f5ca13cb99ad3..17cbe06748234 100644 --- a/paddle/fluid/imperative/tests/test_prepare_op.cc +++ b/paddle/fluid/imperative/tests/test_prepare_op.cc @@ -226,7 +226,7 @@ TEST(test_prepare_op, test_prepare_data_cpu_mkldnn) { } // namespace paddle USE_OP_ITSELF(split); -USE_OP(relu); +USE_OP_ITSELF(relu); #ifdef PADDLE_WITH_MKLDNN USE_OP_DEVICE_KERNEL(relu, MKLDNN); #endif diff --git a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc index f2dc5ba1c7c2c..7f7313fbcb596 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc @@ -52,7 +52,7 @@ TEST(Relu6OpConverter, main) { test_activation("relu6"); } } // namespace inference } // namespace paddle -USE_OP(relu); +USE_OP_ITSELF(relu); USE_OP(sigmoid); USE_OP(tanh); USE_OP(relu6); diff --git a/paddle/fluid/operators/activation_cudnn_op.cu.cc b/paddle/fluid/operators/activation_cudnn_op.cu.cc index 0ac29e6d3ada7..b4a97e24cf292 100644 --- a/paddle/fluid/operators/activation_cudnn_op.cu.cc +++ b/paddle/fluid/operators/activation_cudnn_op.cu.cc @@ -132,7 +132,9 @@ struct CudnnReluGradFunctor : public CudnnActivationGradFunctor { explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx) : CudnnActivationGradFunctor(ctx, 0.0, GPUDNN_ACTIVATION_RELU) {} - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -146,7 +148,9 @@ struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor { : CudnnActivationGradFunctor(ctx, 6.0, GPUDNN_ACTIVATION_CLIPPED_RELU) {} - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -159,7 +163,9 @@ struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor { explicit CudnnSigmoidGradFunctor(const CUDADeviceContext& ctx) : CudnnActivationGradFunctor(ctx, 0.0, GPUDNN_ACTIVATION_SIGMOID) {} - static constexpr 
ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -172,7 +178,9 @@ struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor { explicit CudnnTanhGradFunctor(const CUDADeviceContext& ctx) : CudnnActivationGradFunctor(ctx, 0.0, GPUDNN_ACTIVATION_TANH) {} - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -197,7 +205,8 @@ class CudnnActivationGradKernel public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { - static_assert(Functor::FwdDeps() == kDepOut, "Forward deps must be Out."); + static_assert(Functor::FwdDeps() == ActBwdOpFwdDeps::kDepOut, + "Forward deps must be Out."); const framework::Tensor *X, *Out, *dOut; X = Out = dOut = nullptr; diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 73d65b7c6e7e0..12a629aa3024f 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -34,7 +34,8 @@ using paddle::framework::Tensor; template static constexpr bool CanInplaceAct() { - return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps; + return GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kDepOut || + GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kNoDeps; } #define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \ @@ -921,7 +922,8 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - if (static_cast(kDepValue) & static_cast(kDepX)) { + if (static_cast(kDepValue) & + static_cast(ActBwdOpFwdDeps::kDepX)) { if (ctx->HasOutput("DX")) { ctx->ShareDim("X", "DX"); ctx->ShareLoD("X", "DX"); @@ -931,7 +933,8 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel { ctx->ShareLoD("X", "DDOut"); } } - if (static_cast(kDepValue) & static_cast(kDepOut)) { + if (static_cast(kDepValue) & + static_cast(ActBwdOpFwdDeps::kDepOut)) { if (ctx->HasOutput("DOut")) { ctx->ShareDim("Out", "DOut"); ctx->ShareLoD("Out", "DOut"); @@ -960,13 +963,15 @@ class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - if (static_cast(kDepValue) & static_cast(kDepX)) { + if (static_cast(kDepValue) & + static_cast(ActBwdOpFwdDeps::kDepX)) { if (ctx->HasOutput("DDOut")) { ctx->ShareDim("X", "DDOut"); ctx->ShareLoD("X", "DDOut"); } } - if (static_cast(kDepValue) & static_cast(kDepOut)) { + if (static_cast(kDepValue) & + static_cast(ActBwdOpFwdDeps::kDepOut)) { if (ctx->HasOutput("DDOut")) { ctx->ShareDim("Out", "DDOut"); ctx->ShareLoD("Out", "DDOut"); @@ -987,7 +992,8 @@ class ActivationOpTripleGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - if (static_cast(kDepValue) & static_cast(kDepX)) { + if (static_cast(kDepValue) & + static_cast(ActBwdOpFwdDeps::kDepX)) { if (ctx->HasOutput("DX")) { ctx->ShareDim("X", "DX"); ctx->ShareLoD("X", "DX"); @@ -997,7 +1003,8 @@ class ActivationOpTripleGrad : public framework::OperatorWithKernel { ctx->ShareLoD("X", "DDOut"); } } - if (static_cast(kDepValue) & static_cast(kDepOut)) { + if 
(static_cast(kDepValue) & + static_cast(ActBwdOpFwdDeps::kDepOut)) { if (ctx->HasOutput("D_DOut")) { ctx->ShareDim("Out", "D_DOut"); ctx->ShareLoD("Out", "D_DOut"); @@ -1464,6 +1471,18 @@ namespace plat = paddle::platform; FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP); FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL); +REGISTER_ACTIVATION_OP(cos, Cos, CosFunctor, CosGradFunctor) +REGISTER_ACTIVATION_OP(tan, Tan, TanFunctor, TanGradFunctor); +REGISTER_ACTIVATION_OP(acos, Acos, AcosFunctor, AcosGradFunctor); +REGISTER_ACTIVATION_OP(sin, Sin, SinFunctor, SinGradFunctor); +REGISTER_ACTIVATION_OP(asin, Asin, AsinFunctor, AsinGradFunctor); +REGISTER_ACTIVATION_OP(atan, Atan, AtanFunctor, AtanGradFunctor); +REGISTER_ACTIVATION_OP(sinh, Sinh, SinhFunctor, SinhGradFunctor); +REGISTER_ACTIVATION_OP(cosh, Cosh, CoshFunctor, CoshGradFunctor); +REGISTER_ACTIVATION_OP(asinh, Asinh, AsinhFunctor, AsinhGradFunctor); +REGISTER_ACTIVATION_OP(acosh, Acosh, AcoshFunctor, AcoshGradFunctor); +REGISTER_ACTIVATION_OP(atanh, Atanh, AtanhFunctor, AtanhGradFunctor); + /* ========================== sigmoid register ============================= */ // 1. Register Sigmoid Operator @@ -1584,8 +1603,6 @@ REGISTER_OPERATOR( ops::ActivationOpDoubleGrad2::FwdDeps()>, ops::ActivationDoubleGradOpInplaceInferer); -REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluCPUFunctor, ReluGradFunctor); - REGISTER_OP_CPU_KERNEL( relu_grad_grad, ops::ActivationDoubleGradKernel(kDepValue) & static_cast(kDepOut)) { + if (static_cast(kDepValue) & + static_cast(ActBwdOpFwdDeps::kDepOut)) { out_var = context.InputVar("Out"); PADDLE_ENFORCE_NOT_NULL( out_var, platform::errors::NotFound( @@ -139,7 +138,7 @@ inline void ExtractActivationGradTensor( "Output(Out), variable name = %s", context.OutputName(framework::GradVarName("X")))); - if (static_cast(kDepValue) & static_cast(kDepX)) { + if (static_cast(kDepValue) & static_cast(ActBwdOpFwdDeps::kDepX)) { auto x_var = context.InputVar("X"); PADDLE_ENFORCE_NOT_NULL(x_var, platform::errors::NotFound( "Cannot get the tensor from the " @@ -248,6 +247,24 @@ struct SigmoidFunctor : public BaseActivationFunctor { } }; +#define USE_PHI_FUNCTOR(name) \ + template \ + using name##Functor = phi::funcs::name##Functor; \ + template \ + using name##GradFunctor = phi::funcs::name##GradFunctor; + +USE_PHI_FUNCTOR(Cos) +USE_PHI_FUNCTOR(Tan) +USE_PHI_FUNCTOR(Acos) +USE_PHI_FUNCTOR(Sin) +USE_PHI_FUNCTOR(Asin) +USE_PHI_FUNCTOR(Atan) +USE_PHI_FUNCTOR(Sinh) +USE_PHI_FUNCTOR(Cosh) +USE_PHI_FUNCTOR(Asinh) +USE_PHI_FUNCTOR(Acosh) +USE_PHI_FUNCTOR(Atanh) + template struct SigmoidGradFunctor : public BaseActivationFunctor { template { dx.device(d) = dout * out * (static_cast(1) - out); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; /* @@ -293,7 +312,9 @@ struct SigmoidGradGradFunctor : public BaseActivationFunctor { ddout.device(*d) = (static_cast(1) - out) * out * ddx; } } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; /* @@ -351,7 +372,9 @@ struct SigmoidTripleGradFunctor : public BaseActivationFunctor { (static_cast(1) - static_cast(2) * out) * dout * d_dOutNew; } } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; // silu(x) = x / (1 + exp(-x)) @@ -376,7 +399,7 @@ struct SiluGradFunctor : public 
BaseActivationFunctor { (static_cast(1) + (temp2 / temp1))); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // Originally: logsigmoid(x) = -log (1 + exp(-x)) @@ -414,7 +437,7 @@ struct LogSigmoidGradFunctor : public BaseActivationFunctor { dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp())); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // exp(x) = e^x @@ -434,7 +457,9 @@ struct ExpGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * out; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; // expm1(x) = e^x - 1 @@ -454,20 +479,20 @@ struct Expm1GradFunctor : public BaseActivationFunctor { dx.device(d) = dout * out + dout; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; // relu(x) = max(x, 0) + template -struct ReluCPUFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr([] HOSTDEVICE(T v) { - return v > static_cast(0) ? v : static_cast(0); - }); - } -}; +using ReluCPUFunctor = phi::funcs::ReluCPUFunctor; +template +using ReluGradFunctor = phi::funcs::ReluGradFunctor; +template +using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor; template struct ReluCUDAFunctor : public BaseActivationFunctor { template @@ -476,17 +501,6 @@ struct ReluCUDAFunctor : public BaseActivationFunctor { } }; -template -struct ReluGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * (out > static_cast(0)).template cast(); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } -}; - // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct TanhFunctor : public BaseActivationFunctor { @@ -504,7 +518,9 @@ struct TanhGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * (static_cast(1) - out * out); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -534,7 +550,9 @@ struct TanhGradGradFunctor : public BaseActivationFunctor { ddout.device(*d) = (static_cast(1) - out * out) * ddx; } } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; /* Out @@ -589,7 +607,9 @@ struct TanhTripleGradFunctor : public BaseActivationFunctor { static_cast(2) * out * dout * d_dOutNew; } } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; // tanhshrink(x) = x - tanh(x) @@ -610,7 +630,7 @@ struct TanhShrinkGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * (x.tanh() * x.tanh()); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // tanhshrink(x) = x - tanh(x) @@ -646,7 +666,7 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * (temp1 || temp2).template cast(); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static 
constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0 @@ -682,7 +702,7 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * (temp1 + temp2).template cast(); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // sqrt(x) = x^(1/2) @@ -702,7 +722,9 @@ struct SqrtGradFunctor : public BaseActivationFunctor { dx.device(d) = static_cast(0.5) * dout / out; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; // rsqrt(x) = x^(-1/2) @@ -722,7 +744,9 @@ struct RsqrtGradFunctor : public BaseActivationFunctor { dx.device(d) = static_cast(-0.5) * dout * out * out * out; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; // ceil(x) = ceiling(x) @@ -742,7 +766,9 @@ struct ZeroGradFunctor : public BaseActivationFunctor { dx.device(d) = static_cast(0) * out; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kNoDeps; + } }; // floor(x) = flooring(x) @@ -754,373 +780,6 @@ struct FloorFunctor : public BaseActivationFunctor { } }; -template -struct Sine { - HOSTDEVICE T operator()(const T& val) const { return sin(val); } -}; - -template <> -struct Sine { - HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { - return platform::float16(sin(static_cast(val))); - } -}; - -template -struct Cosine { - HOSTDEVICE T operator()(const T& val) const { return cos(val); } -}; - -template <> -struct Cosine { - HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { - return platform::float16(cos(static_cast(val))); - } -}; - -// cosine'(x) = -sin(x) -template -struct CosGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = -dout * x.unaryExpr(Sine()); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -// cosine(x) = cos(x) -template -struct CosFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Cosine()); - } -}; - -// sine'(x) = cos(x) -template -struct SinGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * x.unaryExpr(Cosine()); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -// sine(x) = sin(x) -template -struct SinFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Sine()); - } -}; - -template -struct Tangent { - HOSTDEVICE T operator()(const T& val) const { return tan(val); } -}; - -template <> -struct Tangent { - HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { - return platform::float16(tan(static_cast(val))); - } -}; - -// Tangent'(x) = -Tangent(x) -template -struct TanGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout / x.unaryExpr(Cosine()).square(); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -// Tangent(x) = 
tan(x) -template -struct TanFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Tangent()); - } -}; - -template -struct Sinh { - HOSTDEVICE T operator()(const T& val) const { return sinh(val); } -}; - -template <> -struct Sinh { - HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { - return platform::float16(sinhf(static_cast(val))); - } -}; - -template -struct Cosh { - HOSTDEVICE T operator()(const T& val) const { return cosh(val); } -}; - -template <> -struct Cosh { - HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { - return platform::float16(coshf(static_cast(val))); - } -}; - -// sinh(x) = sinh(x) -template -struct SinhFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Sinh()); - } -}; - -// cosh(x) = cosh(x) -template -struct CoshFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Cosh()); - } -}; - -// sinh'(x) = cosh(x) -template -struct SinhGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * x.unaryExpr(Cosh()); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -// cosh'(x) = sinh(x) -template -struct CoshGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * x.unaryExpr(Sinh()); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -template -struct Acos { - HOSTDEVICE T operator()(const T& val) const { return acos(val); } -}; - -template <> -struct Acos { - HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { - return platform::float16(acos(static_cast(val))); - } -}; - -// Acos(x) = acos(x) -template -struct AcosFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Acos()); - } -}; - -// acos'(x) = -1/sqrt(1-x^2) -template -struct AcosGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = - -dout * static_cast(1) / (static_cast(1) - x.square()).sqrt(); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -template -struct Asin { - HOSTDEVICE T operator()(const T& val) const { return asin(val); } -}; - -template <> -struct Asin { - HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { - return platform::float16(asin(static_cast(val))); - } -}; - -// Asin(x) = asin(x) -template -struct AsinFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Asin()); - } -}; - -// asin'(x) = 1/sqrt(1-x^2) -template -struct AsinGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = - dout * static_cast(1) / (static_cast(1) - x.square()).sqrt(); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -template -struct Atan { - HOSTDEVICE T operator()(const T& val) const { return atan(val); } -}; - -template <> -struct Atan { - HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { - return platform::float16(atan(static_cast(val))); - } -}; - -// 
Atan(x) = atan(x) -template -struct AtanFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Atan()); - } -}; - -// atan'(x) = 1 / (1 + x^2) -template -struct AtanGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * static_cast(1) / (static_cast(1) + x.square()); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -template -struct Acosh { - HOSTDEVICE T operator()(const T& val) const { return acosh(val); } -}; - -template <> -struct Acosh { - HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { - return platform::float16(acosh(static_cast(val))); - } -}; - -// Acosh(x) = acosh(x) -template -struct AcoshFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Acosh()); - } -}; - -// acosh'(x) = 1/sqrt(x^2 - 1) -template -struct AcoshGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = - dout * static_cast(1) / (x * x - static_cast(1)).sqrt(); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -template -struct Asinh { - HOSTDEVICE T operator()(const T& val) const { return asinh(val); } -}; - -template <> -struct Asinh { - HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { - return platform::float16(asinh(static_cast(val))); - } -}; - -// Asinh(x) = asinh(x) -template -struct AsinhFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Asinh()); - } -}; - -// asinh'(x) = 1/sqrt(x^2 + 1) -template -struct AsinhGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = - dout * static_cast(1) / (x.square() + static_cast(1)).sqrt(); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -template -struct Atanh { - HOSTDEVICE T operator()(const T& val) const { return atanh(val); } -}; - -template <> -struct Atanh { - HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { - return platform::float16(atanh(static_cast(val))); - } -}; - -// Atanh(x) = atanh(x) -template -struct AtanhFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.unaryExpr(Atanh()); - } -}; - -// atanh'(x) = 1/(1 - x^2) -template -struct AtanhGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * static_cast(1) / (static_cast(1) - x.square()); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - // round(x) = [x] template struct RoundFunctor : public BaseActivationFunctor { @@ -1147,7 +806,9 @@ struct ReciprocalGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * static_cast(-1) * out * out; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; // log(x) = natural logarithm of x @@ -1167,7 +828,7 @@ struct LogGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * (static_cast(1) / x); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { 
return ActBwdOpFwdDeps::kDepX; } }; // log2(x) = logarithm to the base 2 of the elements of x @@ -1188,7 +849,7 @@ struct Log2GradFunctor : public BaseActivationFunctor { dx.device(d) = dout * static_cast(1) / (x * static_cast(log(2))); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // log10(x) = logarithm to the base 10 of the elements of x @@ -1209,7 +870,7 @@ struct Log10GradFunctor : public BaseActivationFunctor { dx.device(d) = dout * static_cast(1) / (x * static_cast(log(10))); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // log1p(x) = natural logarithm of x+1 @@ -1229,7 +890,7 @@ struct Log1pGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * (static_cast(1) / (x + static_cast(1))); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // square(x) = x^2 @@ -1249,7 +910,7 @@ struct SquareGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * static_cast(2) * x; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -1285,7 +946,7 @@ struct BReluGradFunctor : public BaseActivationFunctor { .template cast(); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // relu6(x) = min(max(0, x), 6) @@ -1319,7 +980,9 @@ struct Relu6GradFunctor : public BaseActivationFunctor { .template cast(); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; // HardSwish = min(max(0, x+3), 6) * x / 6 @@ -1364,7 +1027,7 @@ struct HardSwishGradFunctor : public BaseActivationFunctor { static_cast(1) * (static_cast(1) - tmp)); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // For numerical stability, using the following formula instead of softplus(x) = @@ -1409,7 +1072,7 @@ struct SoftplusGradFunctor : public BaseActivationFunctor { .select(dout, dout / (static_cast(1) + (-x_beta).exp())); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // mish(x) = x * tanh(softplus(x)) @@ -1449,7 +1112,7 @@ struct MishGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * (tsp + x * (static_cast(1) - tsp * tsp) * gsp); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // softsign(x) = x / (1 + |x|) @@ -1472,7 +1135,7 @@ struct SoftsignGradFunctor : public BaseActivationFunctor { dout * (static_cast(1) / (static_cast(1) + x.abs()).square()); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -1504,7 +1167,9 @@ struct SoftReluGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * (static_cast(1) - (-out).exp()) * temp; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ 
-1539,7 +1204,7 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * (temp1 + temp2).template cast(); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -1573,7 +1238,7 @@ struct ELUGradFunctor : public BaseActivationFunctor { .select(dout, dout * (out + static_cast(alpha))); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -1592,7 +1257,7 @@ struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor { .select(dout, dout * static_cast(alpha) * x.exp()); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -1672,7 +1337,7 @@ struct CELUGradFunctor : public BaseActivationFunctor { dout * (x / static_cast(alpha)).exp() * temp_a_neg * temp_x_neg; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198 @@ -1701,7 +1366,7 @@ struct PowGradFunctor : public BaseActivationFunctor { x.pow(static_cast(factor) - static_cast(1)); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -1766,7 +1431,7 @@ struct STanhGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * a * b * (static_cast(1) - temp); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -1797,7 +1462,7 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * (x > th).template cast(); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -1832,7 +1497,9 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor { static_cast(slope); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -1865,7 +1532,7 @@ struct SwishGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * ((static_cast(beta) * out) + temp2); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; /* @@ -1902,7 +1569,7 @@ inline void ExtractActivationDoubleGradTensor( "Cannot get the tensor from the Variable Output, variable name = %s", ctx.OutputName("DDX"))); - if (static_cast(kDepValue) & static_cast(kDepX)) { + if (static_cast(kDepValue) & static_cast(ActBwdOpFwdDeps::kDepX)) { auto x_var = ctx.InputVar("X"); PADDLE_ENFORCE_NOT_NULL( x_var, platform::errors::NotFound( @@ -1925,7 +1592,8 @@ inline void ExtractActivationDoubleGradTensor( VLOG(10) << "Inplace activation of Op: " << ctx.Type(); *X = *ddX; } - if (static_cast(kDepValue) & static_cast(kDepOut)) { + if (static_cast(kDepValue) & + static_cast(ActBwdOpFwdDeps::kDepOut)) { auto out_var = ctx.InputVar("Out"); PADDLE_ENFORCE_NOT_NULL( out_var, @@ -2000,28 +1668,7 @@ struct AbsGradGradFunctor : public BaseActivationFunctor { ddout.device(*d) = ddx * x.sign(); } } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; 
} -}; - -template -struct ReluGradGradFunctor : public BaseActivationFunctor { - template - void operator()(const Device& dev, const framework::Tensor* X, - const framework::Tensor* Out, const framework::Tensor* ddX, - framework::Tensor* ddOut, framework::Tensor* dOut, - framework::Tensor* dX) const { - auto* d = dev.eigen_device(); - auto ddx = framework::EigenVector::Flatten( - GET_DATA_SAFELY(ddX, "Input", "DDX", "ReluGradGrad")); - auto out = framework::EigenVector::Flatten( - GET_DATA_SAFELY(Out, "Output", "Out", "ReluGradGrad")); - if (ddOut) { - auto ddout = framework::EigenVector::Flatten( - GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ReluGradGrad")); - ddout.device(*d) = ddx * (out > static_cast(0)).template cast(); - } - } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -2050,7 +1697,7 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor { .template cast(); } } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -2088,7 +1735,7 @@ struct ELUGradGradFunctor : public BaseActivationFunctor { .template cast(); } } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -2127,7 +1774,7 @@ struct CELUGradGradFunctor : public BaseActivationFunctor { .template cast(); } } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -2156,7 +1803,9 @@ struct SqrtGradGradFunctor : public BaseActivationFunctor { ddout.device(*d) = ddx * static_cast(0.5) / out; } } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -2185,7 +1834,9 @@ struct RsqrtGradGradFunctor : public BaseActivationFunctor { ddout.device(*d) = ddx * static_cast(-0.5) * out * out * out; } } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -2214,7 +1865,7 @@ struct SquareGradGradFunctor : public BaseActivationFunctor { ddout.device(*d) = ddx * static_cast(2) * x; } } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // TODO(dengkaipeng): double gradient calculation for Square/Sqrt need @@ -2840,7 +2491,7 @@ struct LogGradGradFunctor : public BaseActivationFunctor { } } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; } // namespace operators @@ -2849,20 +2500,9 @@ struct LogGradGradFunctor : public BaseActivationFunctor { #define FOR_EACH_ACTIVATION_OP(__macro) \ __macro(silu, Silu, SiluFunctor, SiluGradFunctor); \ __macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ - __macro(atan, Atan, AtanFunctor, AtanGradFunctor); \ __macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \ __macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \ - __macro(cos, Cos, CosFunctor, CosGradFunctor); \ - __macro(tan, Tan, TanFunctor, TanGradFunctor); \ - __macro(acos, Acos, AcosFunctor, AcosGradFunctor); \ - 
__macro(sin, Sin, SinFunctor, SinGradFunctor); \ - __macro(asin, Asin, AsinFunctor, AsinGradFunctor); \ - __macro(sinh, Sinh, SinhFunctor, SinhGradFunctor); \ - __macro(cosh, Cosh, CoshFunctor, CoshGradFunctor); \ - __macro(asinh, Asinh, AsinhFunctor, AsinhGradFunctor); \ - __macro(acosh, Acosh, AcoshFunctor, AcoshGradFunctor); \ - __macro(atanh, Atanh, AtanhFunctor, AtanhGradFunctor); \ __macro(round, Round, RoundFunctor, ZeroGradFunctor); \ __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ __macro(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); \ diff --git a/paddle/fluid/operators/activation_op.kps b/paddle/fluid/operators/activation_op.kps index e1afb3919f813..54e5a92de7db6 100644 --- a/paddle/fluid/operators/activation_op.kps +++ b/paddle/fluid/operators/activation_op.kps @@ -18,28 +18,6 @@ limitations under the License. */ namespace paddle { namespace operators { -template -struct CudaReluFunctor : public BaseActivationFunctor { - T zero = static_cast(0.0f); - - // relu(x) = max(x, 0) - __device__ __forceinline__ T operator()(const T x) const { - return x > zero ? x : zero; - } -}; - -template -struct CudaReluGradFunctor : public BaseActivationFunctor { - T zero = static_cast(0.0f); - - // dx = dout * (out > 0) - __device__ __forceinline__ T operator()(const T dout, const T out) const { - return out > zero ? dout : zero; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } -}; - template struct CudaLeakyReluFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); @@ -69,7 +47,7 @@ struct CudaLeakyReluGradFunctor : public BaseActivationFunctor { return x > zero ? dout : static_cast(alpha) * dout; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -93,7 +71,9 @@ struct CudaSigmoidGradFunctor : public BaseActivationFunctor { return dout * out * (one - out); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -122,7 +102,7 @@ struct CudaSiluGradFunctor : public BaseActivationFunctor { return static_cast(dout * (temp * (one + x * (one - temp)))); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -159,30 +139,7 @@ struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor { return static_cast(dout * (temp2 / (exp(-temp1) + temp2))); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -template -struct CudaAtanFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // atan(x) = atan(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(atan(x)); - } -}; - -template -struct CudaAtanGradFunctor : public BaseActivationFunctor { - T one = static_cast(1.0f); - - // dx = dout / (1 + x^2) - __device__ __forceinline__ T operator()(const T dout, const T x) const { - return dout / (one + x * x); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -219,7 +176,7 @@ struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor { return (x >= -l && x <= l) ? 
zero : dout; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -262,191 +219,9 @@ struct CudaZeroGradFunctor : public BaseActivationFunctor { return static_cast(0.0f); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; } -}; - -template -struct CudaCosFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // cos(x) = cos(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(cos(x)); - } -}; - -template -struct CudaCosGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // dx = dout * (-sin(x)) - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_x) const { - MPType dout = static_cast(arg_dout); - MPType x = static_cast(arg_x); - return static_cast(-dout * sin(x)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -template -struct CudaSinFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // sin(x) = sin(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(sin(x)); - } -}; - -template -struct CudaSinGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // dx = dout * cos(x) - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_x) const { - MPType dout = static_cast(arg_dout); - MPType x = static_cast(arg_x); - return static_cast(dout * cos(x)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -template -struct CudaTanFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // tan(x) = tan(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(tan(x)); - } -}; - -template -struct CudaTanGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // dx = dout / cos(x)^2 - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_x) const { - MPType dout = static_cast(arg_dout); - MPType x = static_cast(arg_x); - return static_cast(dout / (cos(x) * cos(x))); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -template -struct CudaAsinFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // asin(x) = asin(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(asin(x)); - } -}; - -template -struct CudaAsinGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - - // dx = dout / sqrt(1 - x^2) - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_x) const { - MPType dout = static_cast(arg_dout); - MPType x = static_cast(arg_x); - return static_cast(dout / sqrt(one - x * x)); + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kNoDeps; } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -template -struct CudaAcosFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // acos(x) = acos(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - 
return static_cast(acos(x)); - } -}; - -template -struct CudaAcosGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - - // dx = -dout / sqrt(1 - x^2) - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_x) const { - MPType dout = static_cast(arg_dout); - MPType x = static_cast(arg_x); - return static_cast(-dout / sqrt(one - x * x)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -template -struct CudaCoshFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // cosh(x) = cosh(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(cosh(x)); - } -}; - -template -struct CudaCoshGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // dx = dout * sinh(x) - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_x) const { - MPType dout = static_cast(arg_dout); - MPType x = static_cast(arg_x); - return static_cast(dout * sinh(x)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -template -struct CudaSinhFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // sinh(x) = sinh(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(sinh(x)); - } -}; - -template -struct CudaSinhGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // dx = dout * cosh(x) - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_x) const { - MPType dout = static_cast(arg_dout); - MPType x = static_cast(arg_x); - return static_cast(dout * cosh(x)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template @@ -469,88 +244,11 @@ struct CudaTanhGradFunctor : public BaseActivationFunctor { return dout * (one - out * out); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } -}; - -template -struct CudaAcoshFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // Acosh(x) = acosh(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(acosh(x)); + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; } }; -template -struct CudaAcoshGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - // dx = dout * 1 / sqrt(x^2 - 1) - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_x) const { - MPType dout = static_cast(arg_dout); - MPType x = static_cast(arg_x); - return static_cast(dout * one / sqrt(x * x - one)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -template -struct CudaAsinhFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // Asinh(x) = asinh(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(asinh(x)); - } -}; - -template -struct CudaAsinhGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - - // dx = dout * 1/sqrt(x^2 + 1) - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_x) const { - MPType dout = 
static_cast(arg_dout); - MPType x = static_cast(arg_x); - return static_cast(dout * one / sqrt(x * x + one)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - -template -struct CudaAtanhFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // Atanh(x) = atanh(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(atanh(x)); - } -}; - -template -struct CudaAtanhGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - // dx = dout * 1/(1- x^2) - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_x) const { - MPType dout = static_cast(arg_dout); - MPType x = static_cast(arg_x); - return static_cast(dout * one / (one - x * x)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } -}; - template struct CudaReciprocalFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); @@ -566,7 +264,9 @@ struct CudaReciprocalGradFunctor : public BaseActivationFunctor { return -dout * out * out; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -587,7 +287,9 @@ struct CudaExpGradFunctor : public BaseActivationFunctor { return dout * out; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -608,7 +310,9 @@ struct CudaExpm1GradFunctor : public BaseActivationFunctor { return dout * out + dout; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -629,7 +333,7 @@ struct CudaLogGradFunctor : public BaseActivationFunctor { return dout / x; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -647,7 +351,7 @@ struct CudaSquareGradFunctor : public BaseActivationFunctor { return dout * two * x; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -670,7 +374,9 @@ struct CudaSqrtGradFunctor : public BaseActivationFunctor { return one_half * dout / out; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -693,7 +399,9 @@ struct CudaRsqrtGradFunctor : public BaseActivationFunctor { return minus_one_half * dout * out * out * out; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -717,7 +425,7 @@ struct CudaLog1pGradFunctor : public BaseActivationFunctor { return dout / (one + x); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -741,7 +449,7 @@ struct CudaLog2GradFunctor : public BaseActivationFunctor { return dout / (x * log_two); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -765,7 +473,7 @@ struct CudaLog10GradFunctor : public BaseActivationFunctor { return 
dout / (x * log_ten); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -804,7 +512,7 @@ struct CudaBReluGradFunctor : public BaseActivationFunctor { return (x > t_min_cast && x < t_max_cast) ? dout : zero; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -849,7 +557,9 @@ struct CudaSoftReluGradFunctor : public BaseActivationFunctor { : static_cast(0.0f); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -893,7 +603,7 @@ struct CudaSTanhGradFunctor : public BaseActivationFunctor { return static_cast(dout * a * b * (one - temp * temp)); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -939,7 +649,7 @@ struct CudaSoftplusGradFunctor : public BaseActivationFunctor { return x_beta > t ? arg_dout : static_cast(dout / (one + exp(-x_beta))); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -962,7 +672,7 @@ struct CudaSoftsignGradFunctor : public BaseActivationFunctor { return dout / (temp * temp); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -996,7 +706,9 @@ struct CudaRelu6GradFunctor : public BaseActivationFunctor { return (out > zero && out < t) ? dout : zero; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -1022,7 +734,7 @@ struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor { return static_cast(dout * tanh(x) * tanh(x)); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -1056,7 +768,7 @@ struct CudaHardShrinkGradFunctor : public BaseActivationFunctor { return (x > -t && x < t) ? zero : dout; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -1097,7 +809,9 @@ struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor { return (out > zero && out < one) ? dout * static_cast(slope) : zero; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -1141,7 +855,7 @@ struct CudaSwishGradFunctor : public BaseActivationFunctor { return static_cast(dout * (temp2 + temp3)); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -1190,7 +904,7 @@ struct CudaMishGradFunctor : public BaseActivationFunctor { return static_cast(dout * (tsp + x * (one - tsp * tsp) * gsp)); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -1222,7 +936,7 @@ struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor { return x > static_cast(threshold) ? 
dout : zero; } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -1274,7 +988,7 @@ struct CudaHardSwishGradFunctor : public BaseActivationFunctor { return dout * (temp1 * temp2 * (two * x + o) / s + one - temp2); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -1320,7 +1034,9 @@ struct CudaELUGradFunctor : public BaseActivationFunctor { return static_cast(dout * (out_pos + out_neg * (out + a))); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } }; template @@ -1347,7 +1063,7 @@ struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor { return static_cast(dout * (x_pos + x_neg * (out + a))); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -1429,7 +1145,7 @@ struct CudaCELUGradFunctor : public BaseActivationFunctor { temp_a_neg * temp_x_pos + exp(x / a) * temp_a_neg * temp_x_neg)); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template @@ -1477,13 +1193,14 @@ class ActivationGradCudaKernel std::vector ins = {d_out}; std::vector outs = {d_x}; - if (static_cast(Functor::FwdDeps()) == static_cast(kDepOut)) { + if (static_cast(Functor::FwdDeps()) == + static_cast(ActBwdOpFwdDeps::kDepOut)) { // Only need forward output Out ins.push_back(out); paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, &outs, functor); } else if (static_cast(Functor::FwdDeps()) == - static_cast(kDepX)) { + static_cast(ActBwdOpFwdDeps::kDepX)) { // Only need forward input X ins.push_back(x); paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, @@ -1594,50 +1311,6 @@ REGISTER_OP_CUDA_KERNEL( ops::CELUGradGradFunctor>); /* ========================================================================== */ -/* =========================== relu register ============================ */ -#ifdef PADDLE_WITH_HIP -REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, CudaReluFunctor, - CudaReluGradFunctor); -REGISTER_OP_CUDA_KERNEL( - relu_grad_grad, - ops::ActivationDoubleGradKernel>, - ops::ActivationDoubleGradKernel>, - ops::ActivationDoubleGradKernel>); -#else -REGISTER_OP_CUDA_KERNEL( - relu, ops::ActivationCudaKernel>, - ops::ActivationCudaKernel>, - ops::ActivationCudaKernel>, - ops::ActivationCudaKernel>); -REGISTER_OP_CUDA_KERNEL( - relu_grad, ops::ActivationGradCudaKernel>, - ops::ActivationGradCudaKernel>, - ops::ActivationGradCudaKernel>, - ops::ActivationGradCudaKernel>); -REGISTER_OP_CUDA_KERNEL( - relu_grad_grad, - ops::ActivationDoubleGradKernel>, - ops::ActivationDoubleGradKernel>, - ops::ActivationDoubleGradKernel>, - ops::ActivationDoubleGradKernel>); -#endif -/* ========================================================================== */ - /* =========================== sigmoid register ============================ */ REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor, @@ -1821,21 +1494,10 @@ REGISTER_OP_CUDA_KERNEL( __macro(silu, Silu, CudaSiluFunctor, CudaSiluGradFunctor); \ __macro(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor, \ CudaLogSigmoidGradFunctor); \ - __macro(atan, Atan, CudaAtanFunctor, CudaAtanGradFunctor); \ 
__macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \ CudaSoftShrinkGradFunctor); \ __macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor); \ __macro(floor, Floor, CudaFloorFunctor, CudaZeroGradFunctor); \ - __macro(cos, Cos, CudaCosFunctor, CudaCosGradFunctor); \ - __macro(tan, Tan, CudaTanFunctor, CudaTanGradFunctor); \ - __macro(acos, Acos, CudaAcosFunctor, CudaAcosGradFunctor); \ - __macro(sin, Sin, CudaSinFunctor, CudaSinGradFunctor); \ - __macro(asin, Asin, CudaAsinFunctor, CudaAsinGradFunctor); \ - __macro(sinh, Sinh, CudaSinhFunctor, CudaSinhGradFunctor); \ - __macro(cosh, Cosh, CudaCoshFunctor, CudaCoshGradFunctor); \ - __macro(asinh, Asinh, CudaAsinhFunctor, CudaAsinhGradFunctor); \ - __macro(acosh, Acosh, CudaAcoshFunctor, CudaAcoshGradFunctor); \ - __macro(atanh, Atanh, CudaAtanhFunctor, CudaAtanhGradFunctor); \ __macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor); \ __macro(reciprocal, Reciprocal, CudaReciprocalFunctor, \ CudaReciprocalGradFunctor); \ diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc index 05cd264cf3ec9..23428dd403e9b 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc @@ -29,7 +29,7 @@ USE_OP_ITSELF(elementwise_add); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); USE_OP(elementwise_mul); USE_OP_DEVICE_KERNEL(elementwise_mul, MKLDNN); -USE_OP(relu); +USE_OP_ITSELF(relu); USE_OP_DEVICE_KERNEL(relu, MKLDNN); USE_OP_ITSELF(softmax); USE_OP_DEVICE_KERNEL(softmax, MKLDNN); diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc index c776cf2a7c792..e9dadd5ec937c 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc @@ -27,7 +27,7 @@ USE_OP_ITSELF(elementwise_add); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); -USE_OP(relu); +USE_OP_ITSELF(relu); USE_OP_DEVICE_KERNEL(relu, MKLDNN); USE_OP_ITSELF(softmax); USE_OP_DEVICE_KERNEL(softmax, MKLDNN); diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc index 3791fed23a84f..916f02179b364 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc @@ -27,7 +27,7 @@ USE_OP(pool2d); USE_OP_DEVICE_KERNEL(pool2d, MKLDNN); -USE_OP(relu); +USE_OP_ITSELF(relu); USE_OP_DEVICE_KERNEL(relu, MKLDNN); USE_OP_ITSELF(transpose); USE_OP_DEVICE_KERNEL(transpose, MKLDNN); diff --git a/paddle/fluid/operators/mlu/activation_op_mlu_test.cc b/paddle/fluid/operators/mlu/activation_op_mlu_test.cc index 884521301750c..6e3bd5e43c9c1 100644 --- a/paddle/fluid/operators/mlu/activation_op_mlu_test.cc +++ b/paddle/fluid/operators/mlu/activation_op_mlu_test.cc @@ -22,7 +22,7 @@ limitations under the License. */ namespace fw = paddle::framework; namespace plat = paddle::platform; -USE_OP(relu); +USE_OP_ITSELF(relu); USE_OP_DEVICE_KERNEL(relu, MLU); // relu diff --git a/paddle/fluid/operators/test_common_infer_shape_functions.cc b/paddle/fluid/operators/test_common_infer_shape_functions.cc index a7c7e33f58af6..1de1b590a1311 100644 --- a/paddle/fluid/operators/test_common_infer_shape_functions.cc +++ b/paddle/fluid/operators/test_common_infer_shape_functions.cc @@ -20,7 +20,7 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/common_infer_shape_functions.h" #include "paddle/phi/core/ddim.h" -USE_OP(relu); +USE_OP_ITSELF(relu); USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(softmax); diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h new file mode 100644 index 0000000000000..0fb430c122ba7 --- /dev/null +++ b/paddle/phi/kernels/activation_grad_kernel.h @@ -0,0 +1,54 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/unary.h" + +namespace phi { + +#define DECLARE_ACTIVATION_GRAD_KERNEL_DepX(name) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& dout, \ + DenseTensor* dx); + +#define DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(name) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& out, \ + const DenseTensor& dout, \ + DenseTensor* dx); + +template +void ReluDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& out, + const DenseTensor& ddx, + DenseTensor* ddout); + +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cos) DECLARE_ACTIVATION_GRAD_KERNEL_DepX( + Tan) DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acos) + DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Sin) + DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Asin) + DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atan) + DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Sinh) + DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cosh) + DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Asinh) + DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acosh) + DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atanh) + DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Relu) + +} // namespace phi diff --git a/paddle/phi/kernels/activation_kernel.h b/paddle/phi/kernels/activation_kernel.h new file mode 100644 index 0000000000000..bdf8f4363598f --- /dev/null +++ b/paddle/phi/kernels/activation_kernel.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/unary.h" + +namespace phi { + +#define DECLARE_ACTIVATION_KERNEL(name) \ + template \ + void name##Kernel( \ + const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); + +DECLARE_ACTIVATION_KERNEL(Cos) +DECLARE_ACTIVATION_KERNEL(Tan) +DECLARE_ACTIVATION_KERNEL(Acos) +DECLARE_ACTIVATION_KERNEL(Sin) +DECLARE_ACTIVATION_KERNEL(Asin) +DECLARE_ACTIVATION_KERNEL(Atan) +DECLARE_ACTIVATION_KERNEL(Sinh) +DECLARE_ACTIVATION_KERNEL(Cosh) +DECLARE_ACTIVATION_KERNEL(Asinh) +DECLARE_ACTIVATION_KERNEL(Acosh) +DECLARE_ACTIVATION_KERNEL(Atanh) +DECLARE_ACTIVATION_KERNEL(Relu) + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc new file mode 100644 index 0000000000000..b9cbb5b118438 --- /dev/null +++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc @@ -0,0 +1,102 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/activation_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/activation_impl.h" + +namespace phi { + +#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(name, functor_class) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& dout, \ + DenseTensor* dx) { \ + functor_class functor; \ + ActivationGradImpl( \ + dev_ctx, &x, nullptr, &dout, dx, functor); \ + } + +#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(name, functor_class) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& out, \ + const DenseTensor& dout, \ + DenseTensor* dx) { \ + functor_class functor; \ + ActivationGradImpl( \ + dev_ctx, nullptr, &out, &dout, dx, functor); \ + } + +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX( + Cos, + funcs::CosGradFunctor< + T>) DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, + funcs::TanGradFunctor) + DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX( + Acos, + funcs::AcosGradFunctor< + T>) DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, + funcs::SinGradFunctor) + DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, funcs::AsinGradFunctor) + DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, + funcs::AtanGradFunctor) + DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX( + Sinh, funcs::SinhGradFunctor) + DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX( + Cosh, funcs::CoshGradFunctor) + DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX( + Asinh, funcs::AsinhGradFunctor) + DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX( + Acosh, funcs::AcoshGradFunctor) + DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX( + Atanh, funcs::AtanhGradFunctor) + DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut( + Relu, funcs::ReluGradFunctor) + +} // namespace phi +PD_REGISTER_KERNEL( + cos_grad, CPU, ALL_LAYOUT, phi::CosGradKernel, float, double) {} +PD_REGISTER_KERNEL( + tan_grad, CPU, ALL_LAYOUT, phi::TanGradKernel, float, double) {} 
+PD_REGISTER_KERNEL( + acos_grad, CPU, ALL_LAYOUT, phi::AcosGradKernel, float, double) {} +PD_REGISTER_KERNEL( + sin_grad, CPU, ALL_LAYOUT, phi::SinGradKernel, float, double) {} +PD_REGISTER_KERNEL( + asin_grad, CPU, ALL_LAYOUT, phi::AsinGradKernel, float, double) {} +PD_REGISTER_KERNEL( + atan_grad, CPU, ALL_LAYOUT, phi::AtanGradKernel, float, double) {} +PD_REGISTER_KERNEL( + sinh_grad, CPU, ALL_LAYOUT, phi::SinhGradKernel, float, double) {} +PD_REGISTER_KERNEL( + cosh_grad, CPU, ALL_LAYOUT, phi::CoshGradKernel, float, double) {} +PD_REGISTER_KERNEL( + asinh_grad, CPU, ALL_LAYOUT, phi::AsinhGradKernel, float, double) {} +PD_REGISTER_KERNEL( + acosh_grad, CPU, ALL_LAYOUT, phi::AcoshGradKernel, float, double) {} +PD_REGISTER_KERNEL( + atanh_grad, CPU, ALL_LAYOUT, phi::AtanhGradKernel, float, double) {} +PD_REGISTER_KERNEL( + relu_grad, CPU, ALL_LAYOUT, phi::ReluGradKernel, float, double) {} +PD_REGISTER_KERNEL(relu_double_grad, + CPU, + ALL_LAYOUT, + phi::ReluDoubleGradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc new file mode 100644 index 0000000000000..51883f25183af --- /dev/null +++ b/paddle/phi/kernels/cpu/activation_kernel.cc @@ -0,0 +1,55 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/kernels/activation_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/activation_impl.h" + +namespace phi { + +#define DEFINE_CPU_ACTIVATION_KERNEL(name, functor_class) \ + template \ + void name##Kernel( \ + const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \ + functor_class functor; \ + ActivationImpl(dev_ctx, x, out, functor); \ + } + +DEFINE_CPU_ACTIVATION_KERNEL(Sin, funcs::SinFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Cos, funcs::CosFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Tan, funcs::TanFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Asin, funcs::AsinFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Atan, funcs::AtanFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Acos, funcs::AcosFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Sinh, funcs::SinhFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Cosh, funcs::CoshFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Asinh, funcs::AsinhFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Acosh, funcs::AcoshFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Atanh, funcs::AtanhFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Relu, funcs::ReluCPUFunctor) + +} // namespace phi +PD_REGISTER_KERNEL(sin, CPU, ALL_LAYOUT, phi::SinKernel, float, double) {} +PD_REGISTER_KERNEL(cos, CPU, ALL_LAYOUT, phi::CosKernel, float, double) {} +PD_REGISTER_KERNEL(tan, CPU, ALL_LAYOUT, phi::TanKernel, float, double) {} +PD_REGISTER_KERNEL(acos, CPU, ALL_LAYOUT, phi::AcosKernel, float, double) {} +PD_REGISTER_KERNEL(asin, CPU, ALL_LAYOUT, phi::AsinKernel, float, double) {} +PD_REGISTER_KERNEL(atan, CPU, ALL_LAYOUT, phi::AtanKernel, float, double) {} +PD_REGISTER_KERNEL(sinh, CPU, ALL_LAYOUT, phi::SinhKernel, float, double) {} +PD_REGISTER_KERNEL(cosh, CPU, ALL_LAYOUT, phi::CoshKernel, float, double) {} +PD_REGISTER_KERNEL(asinh, CPU, ALL_LAYOUT, phi::AsinhKernel, float, double) {} +PD_REGISTER_KERNEL(acosh, CPU, ALL_LAYOUT, phi::AcoshKernel, float, double) {} +PD_REGISTER_KERNEL(atanh, CPU, ALL_LAYOUT, phi::AtanhKernel, float, double) {} +PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {} diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h new file mode 100644 index 0000000000000..818d5755f0034 --- /dev/null +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -0,0 +1,829 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
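+
+// This header collects the element-wise activation functors (forward and
+// backward, Eigen-based and CUDA variants) shared by the phi activation
+// kernels. Each backward functor reports via FwdDeps() whether it needs the
+// forward input X (kDepX), the forward output Out (kDepOut), or neither
+// (kNoDeps), so the calling kernels know which tensors to feed it.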
+ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +#include +#ifndef _USE_MATH_DEFINES +#define _USE_MATH_DEFINES +#endif + +#include + +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { +namespace funcs { +enum ActBwdOpFwdDeps { + kNoDeps = 0x00, // Do not need any forward input/output + kDepX = 0x01, // Only need forward input X + kDepOut = 0x02, // Only need forward output Out +}; + +template +struct BaseActivationFunctor { + using ELEMENT_TYPE = T; + + using AttrPair = std::vector>; + + AttrPair GetAttrs() { return AttrPair(); } +}; + +template +struct Sine { + HOSTDEVICE T operator()(const T& val) const { return sin(val); } +}; + +template <> +struct Sine { + HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { + return dtype::float16(sin(static_cast(val))); + } +}; + +template +struct Cosine { + HOSTDEVICE T operator()(const T& val) const { return cos(val); } +}; + +template <> +struct Cosine { + HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { + return dtype::float16(cos(static_cast(val))); + } +}; + +// sine'(x) = cos(x) +template +struct SinGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * x.unaryExpr(Cosine()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +// sine(x) = sin(x) +template +struct SinFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.unaryExpr(Sine()); + } +}; + +// cosine'(x) = -sin(x) +template +struct CosGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = -dout * x.unaryExpr(Sine()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +// cosine(x) = cos(x) +template +struct CosFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.unaryExpr(Cosine()); + } +}; + +template +struct Tangent { + HOSTDEVICE T operator()(const T& val) const { return tan(val); } +}; + +template <> +struct Tangent { + HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { + return dtype::float16(tan(static_cast(val))); + } +}; + +// Tangent'(x) = -Tangent(x) +template +struct TanGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout / x.unaryExpr(Cosine()).square(); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +// Tangent(x) = tan(x) +template +struct TanFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.unaryExpr(Tangent()); + } +}; + +template +struct Sinh { + HOSTDEVICE T operator()(const T& val) const { return sinh(val); } +}; + +template <> +struct Sinh { + HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { + return dtype::float16(sinhf(static_cast(val))); + } +}; + +template +struct Cosh { + HOSTDEVICE T operator()(const T& val) const { return cosh(val); } +}; + +template <> +struct Cosh { + HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { + return dtype::float16(coshf(static_cast(val))); + } +}; + +// 
sinh(x) = sinh(x) +template +struct SinhFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.unaryExpr(Sinh()); + } +}; + +// cosh(x) = cosh(x) +template +struct CoshFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.unaryExpr(Cosh()); + } +}; + +// sinh'(x) = cosh(x) +template +struct SinhGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * x.unaryExpr(Cosh()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +// cosh'(x) = sinh(x) +template +struct CoshGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * x.unaryExpr(Sinh()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template +struct Acos { + HOSTDEVICE T operator()(const T& val) const { return acos(val); } +}; + +template <> +struct Acos { + HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { + return dtype::float16(acos(static_cast(val))); + } +}; + +// Acos(x) = acos(x) +template +struct AcosFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.unaryExpr(Acos()); + } +}; + +// acos'(x) = -1/sqrt(1-x^2) +template +struct AcosGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = + -dout * static_cast(1) / (static_cast(1) - x.square()).sqrt(); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template +struct Asin { + HOSTDEVICE T operator()(const T& val) const { return asin(val); } +}; + +template <> +struct Asin { + HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { + return dtype::float16(asin(static_cast(val))); + } +}; + +// Asin(x) = asin(x) +template +struct AsinFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.unaryExpr(Asin()); + } +}; + +// asin'(x) = 1/sqrt(1-x^2) +template +struct AsinGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = + dout * static_cast(1) / (static_cast(1) - x.square()).sqrt(); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template +struct Atan { + HOSTDEVICE T operator()(const T& val) const { return atan(val); } +}; + +template <> +struct Atan { + HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { + return dtype::float16(atan(static_cast(val))); + } +}; + +// Atan(x) = atan(x) +template +struct AtanFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.unaryExpr(Atan()); + } +}; + +// atan'(x) = 1 / (1 + x^2) +template +struct AtanGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * static_cast(1) / (static_cast(1) + x.square()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template +struct Acosh { + HOSTDEVICE T operator()(const T& val) const { return acosh(val); } +}; + +template <> +struct Acosh { + HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { + return 
dtype::float16(acosh(static_cast(val))); + } +}; + +// Acosh(x) = acosh(x) +template +struct AcoshFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.unaryExpr(Acosh()); + } +}; + +// acosh'(x) = 1/sqrt(x^2 - 1) +template +struct AcoshGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = + dout * static_cast(1) / (x * x - static_cast(1)).sqrt(); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template +struct Asinh { + HOSTDEVICE T operator()(const T& val) const { return asinh(val); } +}; + +template <> +struct Asinh { + HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { + return dtype::float16(asinh(static_cast(val))); + } +}; + +// Asinh(x) = asinh(x) +template +struct AsinhFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.unaryExpr(Asinh()); + } +}; + +// asinh'(x) = 1/sqrt(x^2 + 1) +template +struct AsinhGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = + dout * static_cast(1) / (x.square() + static_cast(1)).sqrt(); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template +struct Atanh { + HOSTDEVICE T operator()(const T& val) const { return atanh(val); } +}; + +template <> +struct Atanh { + HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { + return dtype::float16(atanh(static_cast(val))); + } +}; + +// Atanh(x) = atanh(x) +template +struct AtanhFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.unaryExpr(Atanh()); + } +}; + +// atanh'(x) = 1/(1 - x^2) +template +struct AtanhGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * static_cast(1) / (static_cast(1) - x.square()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +// relu(x) = max(x, 0) +template +struct ReluCPUFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.unaryExpr([] HOSTDEVICE(T v) { + return v > static_cast(0) ? 
v : static_cast(0); + }); + } +}; + +// template +// struct ReluCUDAFunctor : public BaseActivationFunctor { +// template +// void operator()(Device d, X x, Out out) const { +// out.device(d) = x.cwiseMax(static_cast(0)); +// } +// }; + +template +struct ReluGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * (out > static_cast(0)).template cast(); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + +template +struct ReluGradGradFunctor : public BaseActivationFunctor { + template + void operator()(const Device& dev, + const DenseTensor* X, + const DenseTensor* Out, + const DenseTensor* ddX, + DenseTensor* ddOut, + DenseTensor* dOut, + DenseTensor* dX) const { + auto* d = dev.eigen_device(); + auto ddx = EigenVector::Flatten( + GET_DATA_SAFELY(ddX, "Input", "DDX", "ReluGradGrad")); + auto out = EigenVector::Flatten( + GET_DATA_SAFELY(Out, "Output", "Out", "ReluGradGrad")); + if (ddOut) { + auto ddout = EigenVector::Flatten( + GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ReluGradGrad")); + ddout.device(*d) = ddx * (out > static_cast(0)).template cast(); + } + } + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + +#if defined(__NVCC__) || defined(__HIPCC__) +template +struct CudaReluFunctor : public BaseActivationFunctor { + T zero = static_cast(0.0f); + + // relu(x) = max(x, 0) + __device__ __forceinline__ T operator()(const T x) const { + return x > zero ? x : zero; + } +}; + +template +struct CudaReluGradFunctor : public BaseActivationFunctor { + T zero = static_cast(0.0f); + + // dx = dout * (out > 0) + __device__ __forceinline__ T operator()(const T dout, const T out) const { + return out > zero ? 
dout : zero; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + +template +struct CudaCosFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + + // cos(x) = cos(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(cos(x)); + } +}; + +template +struct CudaCosGradFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + + // dx = dout * (-sin(x)) + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + return static_cast(-dout * sin(x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaSinFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + + // sin(x) = sin(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(sin(x)); + } +}; + +template +struct CudaSinGradFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + + // dx = dout * cos(x) + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + return static_cast(dout * cos(x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaTanFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + + // tan(x) = tan(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(tan(x)); + } +}; + +template +struct CudaTanGradFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + + // dx = dout / cos(x)^2 + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + return static_cast(dout / (cos(x) * cos(x))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaAsinFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + + // asin(x) = asin(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(asin(x)); + } +}; + +template +struct CudaAsinGradFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + + // dx = dout / sqrt(1 - x^2) + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + return static_cast(dout / sqrt(one - x * x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaAcosFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + + // acos(x) = acos(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(acos(x)); + } +}; + +template +struct CudaAcosGradFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + + // dx = -dout / sqrt(1 - x^2) + __device__ __forceinline__ T 
operator()(const T arg_dout, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + return static_cast(-dout / sqrt(one - x * x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaCoshFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + + // cosh(x) = cosh(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(cosh(x)); + } +}; + +template +struct CudaCoshGradFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + + // dx = dout * sinh(x) + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + return static_cast(dout * sinh(x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaSinhFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + + // sinh(x) = sinh(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(sinh(x)); + } +}; + +template +struct CudaSinhGradFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + + // dx = dout * cosh(x) + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + return static_cast(dout * cosh(x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaAcoshFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + + // Acosh(x) = acosh(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(acosh(x)); + } +}; + +template +struct CudaAcoshGradFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + // dx = dout * 1 / sqrt(x^2 - 1) + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + return static_cast(dout * one / sqrt(x * x - one)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaAsinhFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + + // Asinh(x) = asinh(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(asinh(x)); + } +}; + +template +struct CudaAsinhGradFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + + // dx = dout * 1/sqrt(x^2 + 1) + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + return static_cast(dout * one / sqrt(x * x + one)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaAtanhFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + + // Atanh(x) = atanh(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(atanh(x)); + } 
+}; + +template +struct CudaAtanhGradFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + // dx = dout * 1/(1- x^2) + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + return static_cast(dout * one / (one - x * x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaAtanFunctor : public BaseActivationFunctor { + using MPType = typename dtype::MPTypeTrait::Type; + + // atan(x) = atan(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(atan(x)); + } +}; + +template +struct CudaAtanGradFunctor : public BaseActivationFunctor { + T one = static_cast(1.0f); + + // dx = dout / (1 + x^2) + __device__ __forceinline__ T operator()(const T dout, const T x) const { + return dout / (one + x * x); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +#endif + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu new file mode 100644 index 0000000000000..cdd2e10893a32 --- /dev/null +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -0,0 +1,233 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/activation_kernel.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/impl/activation_impl.h" + +#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" + +namespace phi { + +template +void ActivationGradGPUImpl(const Context& dev_ctx, + const DenseTensor* x, + const DenseTensor* out, + const DenseTensor* d_out, + DenseTensor* d_x, + const Functor& functor) { + if (static_cast(Functor::FwdDeps()) & + static_cast(funcs::ActBwdOpFwdDeps::kDepOut)) { + PADDLE_ENFORCE_NOT_NULL( + out, errors::NotFound("The input DenseTensor Out can not be nullptr")); + } + PADDLE_ENFORCE_NOT_NULL( + d_out, errors::NotFound("The input DenseTensor dOut can not be nullptr")); + PADDLE_ENFORCE_NOT_NULL( + d_x, errors::NotFound("The output DenseTensor dX can not be nullptr")); + if (!out) { + out = d_out; // fake out + } + if (static_cast(Functor::FwdDeps()) & + static_cast(funcs::ActBwdOpFwdDeps::kDepX)) { + PADDLE_ENFORCE_NOT_NULL( + x, errors::NotFound("The input DenseTensor X can not be nullptr")); + } else { + VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name(); + x = d_x; + } + + dev_ctx.template Alloc(d_x); + + std::vector ins = {d_out}; + std::vector outs = {d_x}; + + if (static_cast(Functor::FwdDeps()) == + static_cast(funcs::ActBwdOpFwdDeps::kDepOut)) { + // Only need forward output Out + ins.push_back(out); + funcs::ElementwiseKernel(dev_ctx, ins, &outs, functor); + } else if (static_cast(Functor::FwdDeps()) == + static_cast(funcs::ActBwdOpFwdDeps::kDepX)) { + // Only need forward input X + ins.push_back(x); + funcs::ElementwiseKernel(dev_ctx, ins, &outs, functor); + } else { + funcs::ElementwiseKernel(dev_ctx, ins, &outs, functor); + } +} + +#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(name, functor_class) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& dout, \ + DenseTensor* dx) { \ + functor_class functor; \ + ActivationGradGPUImpl( \ + dev_ctx, &x, nullptr, &dout, dx, functor); \ + } + +#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(name, functor_class) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& out, \ + const DenseTensor& dout, \ + DenseTensor* dx) { \ + functor_class functor; \ + ActivationGradGPUImpl( \ + dev_ctx, nullptr, &out, &dout, dx, functor); \ + } + +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut( + Relu, + funcs::CudaReluGradFunctor< + T>) DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, + funcs::CudaCosGradFunctor) + DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, funcs::CudaTanGradFunctor) + DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, + funcs::CudaAcosGradFunctor) + DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, + funcs::CudaSinGradFunctor) + DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX( + Asin, funcs::CudaAsinGradFunctor) + DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX( + Atan, funcs::CudaAtanGradFunctor) + DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX( + Sinh, funcs::CudaSinhGradFunctor) + DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX( + Cosh, funcs::CudaCoshGradFunctor) + DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX( + Asinh, funcs::CudaAsinhGradFunctor) + DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX( + Acosh, funcs::CudaAcoshGradFunctor) + DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX( + Atanh, + funcs::CudaAtanhGradFunctor) + +} // namespace phi 
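+
+// Kernel registrations: each GPU grad kernel below covers float, double and
+// float16; relu_grad and relu_double_grad additionally cover bfloat16 on the
+// CUDA build, while the HIP build (PADDLE_WITH_HIP branch) stops at float16.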
+PD_REGISTER_KERNEL(cos_grad, + GPU, + ALL_LAYOUT, + phi::CosGradKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(tan_grad, + GPU, + ALL_LAYOUT, + phi::TanGradKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(acos_grad, + GPU, + ALL_LAYOUT, + phi::AcosGradKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(sin_grad, + GPU, + ALL_LAYOUT, + phi::SinGradKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(asin_grad, + GPU, + ALL_LAYOUT, + phi::AsinGradKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(atan_grad, + GPU, + ALL_LAYOUT, + phi::AtanGradKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(sinh_grad, + GPU, + ALL_LAYOUT, + phi::SinhGradKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(cosh_grad, + GPU, + ALL_LAYOUT, + phi::CoshGradKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(asinh_grad, + GPU, + ALL_LAYOUT, + phi::AsinhGradKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(acosh_grad, + GPU, + ALL_LAYOUT, + phi::AcoshGradKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(atanh_grad, + GPU, + ALL_LAYOUT, + phi::AtanhGradKernel, + float, + double, + phi::dtype::float16) {} +#ifdef PADDLE_WITH_HIP +PD_REGISTER_KERNEL(relu_grad, + GPU, + ALL_LAYOUT, + phi::ReluGradKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(relu_double_grad, + GPU, + ALL_LAYOUT, + phi::ReluDoubleGradKernel, + float, + double, + phi::dtype::float16) {} +#else +PD_REGISTER_KERNEL(relu_grad, + GPU, + ALL_LAYOUT, + phi::ReluGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} +PD_REGISTER_KERNEL(relu_double_grad, + GPU, + ALL_LAYOUT, + phi::ReluDoubleGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} +#endif diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu new file mode 100644 index 0000000000000..057340b22f22d --- /dev/null +++ b/paddle/phi/kernels/gpu/activation_kernel.cu @@ -0,0 +1,142 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/activation_kernel.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/impl/activation_impl.h" + +#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" + +namespace phi { + +template +void ActivationGPUImpl(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out, + const Functor& functor) { + PADDLE_ENFORCE_NOT_NULL(out, + errors::NotFound("Output Out should not be nullptr")); + dev_ctx.template Alloc(out); + std::vector ins = {&x}; + std::vector outs = {out}; + funcs::ElementwiseKernel(dev_ctx, ins, &outs, functor); +} + +#define DEFINE_GPU_ACTIVATION_KERNEL(name, functor_class) \ + template \ + void name##Kernel( \ + const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \ + functor_class functor; \ + ActivationGPUImpl(dev_ctx, x, out, functor); \ + } + +DEFINE_GPU_ACTIVATION_KERNEL(Cos, funcs::CudaCosFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Tan, funcs::CudaTanFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Acos, funcs::CudaAcosFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Sin, funcs::CudaSinFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Asin, funcs::CudaAsinFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Atan, funcs::CudaAtanFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Sinh, funcs::CudaSinhFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Cosh, funcs::CudaCoshFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Asinh, funcs::CudaAsinhFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Acosh, funcs::CudaAcoshFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Atanh, funcs::CudaAtanhFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Relu, funcs::CudaReluFunctor) + +} // namespace phi + +#ifdef PADDLE_WITH_HIP +PD_REGISTER_KERNEL(relu, + GPU, + ALL_LAYOUT, + phi::ReluKernel, + float, + double, + phi::dtype::float16) {} +#else +PD_REGISTER_KERNEL(relu, + GPU, + ALL_LAYOUT, + phi::ReluKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} +#endif +PD_REGISTER_KERNEL( + sin, GPU, ALL_LAYOUT, phi::SinKernel, float, double, phi::dtype::float16) {} +PD_REGISTER_KERNEL( + cos, GPU, ALL_LAYOUT, phi::CosKernel, float, double, phi::dtype::float16) {} +PD_REGISTER_KERNEL( + tan, GPU, ALL_LAYOUT, phi::TanKernel, float, double, phi::dtype::float16) {} +PD_REGISTER_KERNEL(acos, + GPU, + ALL_LAYOUT, + phi::AcosKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(asin, + GPU, + ALL_LAYOUT, + phi::AsinKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(atan, + GPU, + ALL_LAYOUT, + phi::AtanKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(sinh, + GPU, + ALL_LAYOUT, + phi::SinhKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(cosh, + GPU, + ALL_LAYOUT, + phi::CoshKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(asinh, + GPU, + ALL_LAYOUT, + phi::AsinhKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(acosh, + GPU, + ALL_LAYOUT, + phi::AcoshKernel, + float, + double, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(atanh, + GPU, + ALL_LAYOUT, + phi::AtanhKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/impl/activation_impl.h b/paddle/phi/kernels/impl/activation_impl.h new file mode 100644 index 0000000000000..95acf12c19e10 --- /dev/null +++ b/paddle/phi/kernels/impl/activation_impl.h @@ -0,0 +1,159 @@ +// Copyright (c) 2022 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/activation_functor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" + +#include "paddle/fluid/platform/device_context.h" + +namespace phi { + +#define ToString(x) #x + +template +void ActivationImpl(const Context& dev_ctx, + const DenseTensor& X, + DenseTensor* Out, + const Functor& functor) { + PADDLE_ENFORCE_NOT_NULL(Out, + errors::NotFound("Output Out should not be nullptr")); + dev_ctx.template Alloc(Out); + auto x = phi::EigenVector::Flatten( + GET_DATA_SAFELY(&X, "Input", "X", "Activation")); + auto out = phi::EigenVector::Flatten( + GET_DATA_SAFELY(Out, "Output", "Out", "Activation")); + auto* place = dev_ctx.eigen_device(); + // use 32bit index to speed up computation + bool use_32bit_index = out.size() < Eigen::NumTraits::highest(); + bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace()); + if (use_32bit_index && is_gpu_place) { + functor(*place, To32BitIndex(x), To32BitIndex(out)); + } else { + functor(*place, x, out); + } +} + +template +void ActivationGradImpl(const Context& dev_ctx, + const DenseTensor* X, + const DenseTensor* Out, + const DenseTensor* dOut, + DenseTensor* dX, + const Functor& functor) { + if (static_cast(Functor::FwdDeps()) & + static_cast(funcs::ActBwdOpFwdDeps::kDepOut)) { + PADDLE_ENFORCE_NOT_NULL( + Out, errors::NotFound("The input DenseTensor Out can not be nullptr")); + } + PADDLE_ENFORCE_NOT_NULL( + dOut, errors::NotFound("The input DenseTensor dOut can not be nullptr")); + PADDLE_ENFORCE_NOT_NULL( + dX, errors::NotFound("The output DenseTensor dX can not be nullptr")); + if (!Out) { + Out = dOut; // fake out + } + if (static_cast(Functor::FwdDeps()) & + static_cast(funcs::ActBwdOpFwdDeps::kDepX)) { + PADDLE_ENFORCE_NOT_NULL( + X, errors::NotFound("The input DenseTensor X can not be nullptr")); + } else { + VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name(); + X = dX; + } + + dev_ctx.template Alloc(dX); + auto dout = phi::EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "ActivationGrad")); + auto out = phi::EigenVector::Flatten( + GET_DATA_SAFELY(Out, "Input", "Out", "ActivationGrad")); + auto dx = phi::EigenVector::Flatten( + GET_DATA_SAFELY(dX, "Input", "X@GRAD", "ActivationGrad")); + auto x = phi::EigenVector::Flatten( + GET_DATA_SAFELY(X, "Input", "X", "ActivationGrad")); + auto* place = dev_ctx.eigen_device(); + // use 32bit index to speed up computation + bool use_32bit_index = out.size() < Eigen::NumTraits::highest(); + bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace()); + if (use_32bit_index && is_gpu_place) { + functor(*place, + To32BitIndex(x), + To32BitIndex(out), + To32BitIndex(dout), + To32BitIndex(dx)); + } else { + functor(*place, x, out, dout, dx); + } +} + +template +void ActivationDoubleGradImpl(const Context& dev_ctx, + const DenseTensor* X, + const DenseTensor* Out, + const 
DenseTensor* ddX, + DenseTensor* dX, + DenseTensor* dOut, + DenseTensor* ddOut, + const Functor& functor) { + if (static_cast(Functor::FwdDeps()) & + static_cast(funcs::ActBwdOpFwdDeps::kDepX)) { + PADDLE_ENFORCE_NOT_NULL( + X, errors::NotFound("The input DenseTensor X can not be nullptr")); + } else { + VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name(); + X = ddX; + } + if (static_cast(Functor::FwdDeps()) & + static_cast(funcs::ActBwdOpFwdDeps::kDepOut)) { + PADDLE_ENFORCE_NOT_NULL( + Out, errors::NotFound("The input DenseTensor Out can not be nullptr")); + } else { + VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name(); + Out = ddX; + } + + if (ddOut) { + dev_ctx.template Alloc(ddOut); + } + if (dOut) { + dev_ctx.template Alloc(dOut); + } + if (dX) { + dX->Resize(Out->dims()); + dev_ctx.template Alloc(dX); + } + + functor(dev_ctx, X, Out, ddX, ddOut, dOut, dX); +} + +template +void ReluDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& out, + const DenseTensor& ddx, + DenseTensor* ddout) { + funcs::ReluGradGradFunctor relu_double_grad_functor; + ActivationDoubleGradImpl>( + dev_ctx, + nullptr, + &out, + &ddx, + nullptr, + nullptr, + ddout, + relu_double_grad_functor); +} + +} // namespace phi diff --git a/paddle/phi/ops/compat/activation_sig.cc b/paddle/phi/ops/compat/activation_sig.cc new file mode 100644 index 0000000000000..edd6760401ccd --- /dev/null +++ b/paddle/phi/ops/compat/activation_sig.cc @@ -0,0 +1,66 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +#define DefineActGradDepXOpArgMap(func_name, op_name) \ + KernelSignature func_name##GradOpArgumentMapping( \ + const ArgumentMappingContext& ctx) { \ + return KernelSignature( \ + op_name "_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")}); \ + } + +#define DefineActGradDepOutOpArgMap(func_name, op_name) \ + KernelSignature func_name##GradOpArgumentMapping( \ + const ArgumentMappingContext& ctx) { \ + return KernelSignature( \ + op_name "_grad", {"Out", GradVarName("Out")}, {}, {GradVarName("X")}); \ + } + +KernelSignature ReluDoubleGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("relu_double_grad", {"Out", "DDX"}, {}, {"DDOut"}); +} + +DefineActGradDepXOpArgMap(Cos, "cos") DefineActGradDepXOpArgMap(Tan, "tan") + DefineActGradDepXOpArgMap(Acos, "acos") + DefineActGradDepXOpArgMap(Sin, "sin") DefineActGradDepXOpArgMap(Asin, + "asin") + DefineActGradDepXOpArgMap(Atan, "atan") + DefineActGradDepXOpArgMap(Sinh, "sinh") + DefineActGradDepXOpArgMap(Cosh, "cosh") + DefineActGradDepXOpArgMap(Asinh, "asinh") + DefineActGradDepXOpArgMap(Acosh, "acosh") + DefineActGradDepXOpArgMap(Atanh, "atanh") + DefineActGradDepOutOpArgMap(Relu, "relu") +} // namespace phi + +PD_REGISTER_BASE_KERNEL_NAME(relu_grad_grad, relu_double_grad); + +PD_REGISTER_ARG_MAPPING_FN(cos_grad, phi::CosGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(tan_grad, phi::TanGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(acos_grad, phi::AcosGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(sin_grad, phi::SinGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(asin_grad, phi::AsinGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(atan_grad, phi::AtanGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(sinh_grad, phi::SinhGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(cosh_grad, phi::CoshGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(asinh_grad, phi::AsinhGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(acosh_grad, phi::AcoshGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(atanh_grad, phi::AtanhGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(relu_grad, phi::ReluGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(relu_grad_grad, + phi::ReluDoubleGradOpArgumentMapping); From 344111ad7f16576b506334554a79a40626289f0f Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Fri, 4 Mar 2022 11:40:38 +0000 Subject: [PATCH 2/8] adjust code format --- paddle/phi/kernels/activation_grad_kernel.h | 23 ++++++------ .../phi/kernels/cpu/activation_grad_kernel.cc | 37 ++++++------------- .../phi/kernels/gpu/activation_grad_kernel.cu | 37 ++++++------------- 3 files changed, 36 insertions(+), 61 deletions(-) diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h index 0fb430c122ba7..74e626e761b14 100644 --- a/paddle/phi/kernels/activation_grad_kernel.h +++ b/paddle/phi/kernels/activation_grad_kernel.h @@ -39,16 +39,17 @@ void ReluDoubleGradKernel(const Context& dev_ctx, const DenseTensor& ddx, DenseTensor* ddout); -DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cos) DECLARE_ACTIVATION_GRAD_KERNEL_DepX( - Tan) DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acos) - DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Sin) - DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Asin) - DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atan) - DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Sinh) - DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cosh) - DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Asinh) - DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acosh) - 
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atanh) - DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Relu) +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cos) +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Tan) +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acos) +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Sin) +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Asin) +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atan) +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Sinh) +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cosh) +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Asinh) +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acosh) +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atanh) +DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Relu) } // namespace phi diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc index b9cbb5b118438..b549694392d5b 100644 --- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc @@ -42,31 +42,18 @@ namespace phi { dev_ctx, nullptr, &out, &dout, dx, functor); \ } -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX( - Cos, - funcs::CosGradFunctor< - T>) DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, - funcs::TanGradFunctor) - DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX( - Acos, - funcs::AcosGradFunctor< - T>) DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, - funcs::SinGradFunctor) - DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, funcs::AsinGradFunctor) - DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, - funcs::AtanGradFunctor) - DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX( - Sinh, funcs::SinhGradFunctor) - DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX( - Cosh, funcs::CoshGradFunctor) - DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX( - Asinh, funcs::AsinhGradFunctor) - DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX( - Acosh, funcs::AcoshGradFunctor) - DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX( - Atanh, funcs::AtanhGradFunctor) - DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut( - Relu, funcs::ReluGradFunctor) +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, funcs::CosGradFunctor) +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Tan,funcs::TanGradFunctor) +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, funcs::AcosGradFunctor) +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, funcs::SinGradFunctor) +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, funcs::AsinGradFunctor) +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, funcs::AtanGradFunctor) +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Sinh, funcs::SinhGradFunctor) +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, funcs::CoshGradFunctor) +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, funcs::AsinhGradFunctor) +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, funcs::AcoshGradFunctor) +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, funcs::AtanhGradFunctor) +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, funcs::ReluGradFunctor) } // namespace phi PD_REGISTER_KERNEL( diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index cdd2e10893a32..5566319887e08 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -94,31 +94,18 @@ void ActivationGradGPUImpl(const Context& dev_ctx, dev_ctx, nullptr, &out, &dout, dx, functor); \ } -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut( - Relu, - funcs::CudaReluGradFunctor< - T>) DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, - funcs::CudaCosGradFunctor) - DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, funcs::CudaTanGradFunctor) - DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, - funcs::CudaAcosGradFunctor) - DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, - funcs::CudaSinGradFunctor) - 
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX( - Asin, funcs::CudaAsinGradFunctor) - DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX( - Atan, funcs::CudaAtanGradFunctor) - DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX( - Sinh, funcs::CudaSinhGradFunctor) - DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX( - Cosh, funcs::CudaCoshGradFunctor) - DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX( - Asinh, funcs::CudaAsinhGradFunctor) - DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX( - Acosh, funcs::CudaAcoshGradFunctor) - DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX( - Atanh, - funcs::CudaAtanhGradFunctor) +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, funcs::CudaReluGradFunctor) +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, funcs::CudaCosGradFunctor) +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, funcs::CudaTanGradFunctor) +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, funcs::CudaAcosGradFunctor) +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, funcs::CudaSinGradFunctor) +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, funcs::CudaAsinGradFunctor) +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, funcs::CudaAtanGradFunctor) +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sinh, funcs::CudaSinhGradFunctor) +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, funcs::CudaCoshGradFunctor) +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, funcs::CudaAsinhGradFunctor) +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, funcs::CudaAcoshGradFunctor) +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, funcs::CudaAtanhGradFunctor) } // namespace phi PD_REGISTER_KERNEL(cos_grad, From bbdf4a0373525dfb7bb62676f03a974dc9d06608 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Fri, 4 Mar 2022 12:27:08 +0000 Subject: [PATCH 3/8] fix compile bugs --- paddle/phi/kernels/funcs/activation_functor.h | 43 ++++++++++--------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index 818d5755f0034..e70c593b31343 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -28,6 +28,7 @@ #include +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" @@ -539,7 +540,7 @@ struct CudaReluGradFunctor : public BaseActivationFunctor { template struct CudaCosFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; // cos(x) = cos(x) __device__ __forceinline__ T operator()(const T arg_x) const { @@ -550,7 +551,7 @@ struct CudaCosFunctor : public BaseActivationFunctor { template struct CudaCosGradFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; // dx = dout * (-sin(x)) __device__ __forceinline__ T operator()(const T arg_dout, @@ -565,7 +566,7 @@ struct CudaCosGradFunctor : public BaseActivationFunctor { template struct CudaSinFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; // sin(x) = sin(x) __device__ __forceinline__ T operator()(const T arg_x) const { @@ -576,7 +577,7 @@ struct CudaSinFunctor : public BaseActivationFunctor { template struct CudaSinGradFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; // dx = dout * cos(x) __device__ __forceinline__ T operator()(const T 
arg_dout, @@ -591,7 +592,7 @@ struct CudaSinGradFunctor : public BaseActivationFunctor { template struct CudaTanFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; // tan(x) = tan(x) __device__ __forceinline__ T operator()(const T arg_x) const { @@ -602,7 +603,7 @@ struct CudaTanFunctor : public BaseActivationFunctor { template struct CudaTanGradFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; // dx = dout / cos(x)^2 __device__ __forceinline__ T operator()(const T arg_dout, @@ -617,7 +618,7 @@ struct CudaTanGradFunctor : public BaseActivationFunctor { template struct CudaAsinFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; // asin(x) = asin(x) __device__ __forceinline__ T operator()(const T arg_x) const { @@ -628,7 +629,7 @@ struct CudaAsinFunctor : public BaseActivationFunctor { template struct CudaAsinGradFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); // dx = dout / sqrt(1 - x^2) @@ -644,7 +645,7 @@ struct CudaAsinGradFunctor : public BaseActivationFunctor { template struct CudaAcosFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; // acos(x) = acos(x) __device__ __forceinline__ T operator()(const T arg_x) const { @@ -655,7 +656,7 @@ struct CudaAcosFunctor : public BaseActivationFunctor { template struct CudaAcosGradFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); // dx = -dout / sqrt(1 - x^2) @@ -671,7 +672,7 @@ struct CudaAcosGradFunctor : public BaseActivationFunctor { template struct CudaCoshFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; // cosh(x) = cosh(x) __device__ __forceinline__ T operator()(const T arg_x) const { @@ -682,7 +683,7 @@ struct CudaCoshFunctor : public BaseActivationFunctor { template struct CudaCoshGradFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; // dx = dout * sinh(x) __device__ __forceinline__ T operator()(const T arg_dout, @@ -697,7 +698,7 @@ struct CudaCoshGradFunctor : public BaseActivationFunctor { template struct CudaSinhFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; // sinh(x) = sinh(x) __device__ __forceinline__ T operator()(const T arg_x) const { @@ -708,7 +709,7 @@ struct CudaSinhFunctor : public BaseActivationFunctor { template struct CudaSinhGradFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; // dx = dout * cosh(x) __device__ __forceinline__ T operator()(const T arg_dout, @@ -723,7 +724,7 @@ struct CudaSinhGradFunctor : public BaseActivationFunctor { template struct CudaAcoshFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename 
phi::dtype::MPTypeTrait::Type; // Acosh(x) = acosh(x) __device__ __forceinline__ T operator()(const T arg_x) const { @@ -734,7 +735,7 @@ struct CudaAcoshFunctor : public BaseActivationFunctor { template struct CudaAcoshGradFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); // dx = dout * 1 / sqrt(x^2 - 1) __device__ __forceinline__ T operator()(const T arg_dout, @@ -749,7 +750,7 @@ struct CudaAcoshGradFunctor : public BaseActivationFunctor { template struct CudaAsinhFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; // Asinh(x) = asinh(x) __device__ __forceinline__ T operator()(const T arg_x) const { @@ -760,7 +761,7 @@ struct CudaAsinhFunctor : public BaseActivationFunctor { template struct CudaAsinhGradFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); // dx = dout * 1/sqrt(x^2 + 1) @@ -776,7 +777,7 @@ struct CudaAsinhGradFunctor : public BaseActivationFunctor { template struct CudaAtanhFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; // Atanh(x) = atanh(x) __device__ __forceinline__ T operator()(const T arg_x) const { @@ -787,7 +788,7 @@ struct CudaAtanhFunctor : public BaseActivationFunctor { template struct CudaAtanhGradFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); // dx = dout * 1/(1- x^2) __device__ __forceinline__ T operator()(const T arg_dout, @@ -802,7 +803,7 @@ struct CudaAtanhGradFunctor : public BaseActivationFunctor { template struct CudaAtanFunctor : public BaseActivationFunctor { - using MPType = typename dtype::MPTypeTrait::Type; + using MPType = typename phi::dtype::MPTypeTrait::Type; // atan(x) = atan(x) __device__ __forceinline__ T operator()(const T arg_x) const { From 647a6cfec02fd1274c89f6054a5366f3074924f9 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Fri, 4 Mar 2022 14:04:34 +0000 Subject: [PATCH 4/8] fix ci bugs --- .../paddle2cinn/build_cinn_pass_test.cc | 4 ++-- paddle/phi/ops/compat/activation_sig.cc | 23 ++++++++++--------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc index bf9d1baaf394f..47dffd47b7cbb 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc @@ -675,7 +675,7 @@ TEST(BuildCinnPassTest, NoNeedBufferInput) { USE_PASS(build_cinn_pass); USE_OP(mul); -USE_OP(relu); +USE_OP_ITSELF(relu); USE_OP_ITSELF(elementwise_add); -USE_OP(relu_grad); +USE_OP_ITSELF(relu_grad); USE_OP_ITSELF(elementwise_add_grad); diff --git a/paddle/phi/ops/compat/activation_sig.cc b/paddle/phi/ops/compat/activation_sig.cc index edd6760401ccd..e266a816d2c73 100644 --- a/paddle/phi/ops/compat/activation_sig.cc +++ b/paddle/phi/ops/compat/activation_sig.cc @@ -35,17 +35,18 @@ KernelSignature ReluDoubleGradOpArgumentMapping( return KernelSignature("relu_double_grad", {"Out", "DDX"}, {}, {"DDOut"}); } -DefineActGradDepXOpArgMap(Cos, "cos") DefineActGradDepXOpArgMap(Tan, "tan") - 
DefineActGradDepXOpArgMap(Acos, "acos") - DefineActGradDepXOpArgMap(Sin, "sin") DefineActGradDepXOpArgMap(Asin, - "asin") - DefineActGradDepXOpArgMap(Atan, "atan") - DefineActGradDepXOpArgMap(Sinh, "sinh") - DefineActGradDepXOpArgMap(Cosh, "cosh") - DefineActGradDepXOpArgMap(Asinh, "asinh") - DefineActGradDepXOpArgMap(Acosh, "acosh") - DefineActGradDepXOpArgMap(Atanh, "atanh") - DefineActGradDepOutOpArgMap(Relu, "relu") +DefineActGradDepXOpArgMap(Cos, "cos") +DefineActGradDepXOpArgMap(Tan, "tan") +DefineActGradDepXOpArgMap(Acos, "acos") +DefineActGradDepXOpArgMap(Sin, "sin") +DefineActGradDepXOpArgMap(Asin, "asin") +DefineActGradDepXOpArgMap(Atan, "atan") +DefineActGradDepXOpArgMap(Sinh, "sinh") +DefineActGradDepXOpArgMap(Cosh, "cosh") +DefineActGradDepXOpArgMap(Asinh, "asinh") +DefineActGradDepXOpArgMap(Acosh, "acosh") +DefineActGradDepXOpArgMap(Atanh, "atanh") +DefineActGradDepOutOpArgMap(Relu, "relu") } // namespace phi PD_REGISTER_BASE_KERNEL_NAME(relu_grad_grad, relu_double_grad); From c6475f921357739e7e59b72a4ac2f8325312ca6c Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Sat, 5 Mar 2022 02:18:04 +0000 Subject: [PATCH 5/8] code format adjust --- paddle/fluid/operators/activation_op.cc | 8 ------- paddle/phi/kernels/activation_grad_kernel.h | 24 ++++++++++----------- 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 12a629aa3024f..66f1bcc8b6869 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -1603,14 +1603,6 @@ REGISTER_OPERATOR( ops::ActivationOpDoubleGrad2::FwdDeps()>, ops::ActivationDoubleGradOpInplaceInferer); -REGISTER_OP_CPU_KERNEL( - relu_grad_grad, - ops::ActivationDoubleGradKernel>, - ops::ActivationDoubleGradKernel>, - ops::ActivationDoubleGradKernel>); /* ========================================================================== */ /* ======================== leaky relu register ============================ */ diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h index 74e626e761b14..f34e5710ab729 100644 --- a/paddle/phi/kernels/activation_grad_kernel.h +++ b/paddle/phi/kernels/activation_grad_kernel.h @@ -39,17 +39,17 @@ void ReluDoubleGradKernel(const Context& dev_ctx, const DenseTensor& ddx, DenseTensor* ddout); -DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cos) -DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Tan) -DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acos) -DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Sin) -DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Asin) -DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atan) -DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Sinh) -DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cosh) -DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Asinh) -DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acosh) -DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atanh) -DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Relu) +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cos); +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Tan); +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acos); +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Sin); +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Asin); +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atan); +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Sinh); +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cosh); +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Asinh); +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acosh); +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atanh); +DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Relu); } // namespace phi From b0eee13c665f9d9d0f2dc32009f01fb9930062b5 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Sat, 5 Mar 2022 
02:21:48 +0000 Subject: [PATCH 6/8] code format adjust2 --- .../phi/kernels/cpu/activation_grad_kernel.cc | 24 +++++++++---------- .../phi/kernels/gpu/activation_grad_kernel.cu | 24 +++++++++---------- paddle/phi/ops/compat/activation_sig.cc | 24 +++++++++---------- 3 files changed, 36 insertions(+), 36 deletions(-) diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc index b549694392d5b..b6c8c2d9d7889 100644 --- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc @@ -42,18 +42,18 @@ namespace phi { dev_ctx, nullptr, &out, &dout, dx, functor); \ } -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, funcs::CosGradFunctor) -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Tan,funcs::TanGradFunctor) -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, funcs::AcosGradFunctor) -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, funcs::SinGradFunctor) -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, funcs::AsinGradFunctor) -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, funcs::AtanGradFunctor) -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Sinh, funcs::SinhGradFunctor) -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, funcs::CoshGradFunctor) -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, funcs::AsinhGradFunctor) -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, funcs::AcoshGradFunctor) -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, funcs::AtanhGradFunctor) -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, funcs::ReluGradFunctor) +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, funcs::CosGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, funcs::TanGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, funcs::AcosGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, funcs::SinGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, funcs::AsinGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, funcs::AtanGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Sinh, funcs::SinhGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, funcs::CoshGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, funcs::AsinhGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, funcs::AcoshGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, funcs::AtanhGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, funcs::ReluGradFunctor); } // namespace phi PD_REGISTER_KERNEL( diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index 5566319887e08..88b63ed543ff7 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -94,18 +94,18 @@ void ActivationGradGPUImpl(const Context& dev_ctx, dev_ctx, nullptr, &out, &dout, dx, functor); \ } -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, funcs::CudaReluGradFunctor) -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, funcs::CudaCosGradFunctor) -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, funcs::CudaTanGradFunctor) -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, funcs::CudaAcosGradFunctor) -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, funcs::CudaSinGradFunctor) -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, funcs::CudaAsinGradFunctor) -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, funcs::CudaAtanGradFunctor) -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sinh, funcs::CudaSinhGradFunctor) -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, funcs::CudaCoshGradFunctor) -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, funcs::CudaAsinhGradFunctor) 
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, funcs::CudaAcoshGradFunctor) -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, funcs::CudaAtanhGradFunctor) +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, funcs::CudaReluGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, funcs::CudaCosGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, funcs::CudaTanGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, funcs::CudaAcosGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, funcs::CudaSinGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, funcs::CudaAsinGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, funcs::CudaAtanGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sinh, funcs::CudaSinhGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, funcs::CudaCoshGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, funcs::CudaAsinhGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, funcs::CudaAcoshGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, funcs::CudaAtanhGradFunctor); } // namespace phi PD_REGISTER_KERNEL(cos_grad, diff --git a/paddle/phi/ops/compat/activation_sig.cc b/paddle/phi/ops/compat/activation_sig.cc index e266a816d2c73..396830ca20765 100644 --- a/paddle/phi/ops/compat/activation_sig.cc +++ b/paddle/phi/ops/compat/activation_sig.cc @@ -35,18 +35,18 @@ KernelSignature ReluDoubleGradOpArgumentMapping( return KernelSignature("relu_double_grad", {"Out", "DDX"}, {}, {"DDOut"}); } -DefineActGradDepXOpArgMap(Cos, "cos") -DefineActGradDepXOpArgMap(Tan, "tan") -DefineActGradDepXOpArgMap(Acos, "acos") -DefineActGradDepXOpArgMap(Sin, "sin") -DefineActGradDepXOpArgMap(Asin, "asin") -DefineActGradDepXOpArgMap(Atan, "atan") -DefineActGradDepXOpArgMap(Sinh, "sinh") -DefineActGradDepXOpArgMap(Cosh, "cosh") -DefineActGradDepXOpArgMap(Asinh, "asinh") -DefineActGradDepXOpArgMap(Acosh, "acosh") -DefineActGradDepXOpArgMap(Atanh, "atanh") -DefineActGradDepOutOpArgMap(Relu, "relu") +DefineActGradDepXOpArgMap(Cos, "cos"); +DefineActGradDepXOpArgMap(Tan, "tan"); +DefineActGradDepXOpArgMap(Acos, "acos"); +DefineActGradDepXOpArgMap(Sin, "sin"); +DefineActGradDepXOpArgMap(Asin, "asin"); +DefineActGradDepXOpArgMap(Atan, "atan"); +DefineActGradDepXOpArgMap(Sinh, "sinh"); +DefineActGradDepXOpArgMap(Cosh, "cosh"); +DefineActGradDepXOpArgMap(Asinh, "asinh"); +DefineActGradDepXOpArgMap(Acosh, "acosh"); +DefineActGradDepXOpArgMap(Atanh, "atanh"); +DefineActGradDepOutOpArgMap(Relu, "relu"); } // namespace phi PD_REGISTER_BASE_KERNEL_NAME(relu_grad_grad, relu_double_grad); From 10a7c2416f61bf48d1af44462f283d34d7f7abc0 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Sat, 5 Mar 2022 07:51:10 +0000 Subject: [PATCH 7/8] activate ci status --- paddle/phi/kernels/cpu/activation_grad_kernel.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc index b6c8c2d9d7889..c0e249ec0f6ad 100644 --- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc @@ -56,6 +56,7 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, funcs::AtanhGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, funcs::ReluGradFunctor); } // namespace phi + PD_REGISTER_KERNEL( cos_grad, CPU, ALL_LAYOUT, phi::CosGradKernel, float, double) {} PD_REGISTER_KERNEL( From 4abb59be26ee062e66734fa62e01c582f276fc09 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Mon, 7 Mar 2022 06:25:58 +0000 Subject: [PATCH 8/8] modify according to comment --- 
paddle/fluid/operators/activation_op.h | 8 +- .../phi/kernels/cpu/activation_grad_kernel.cc | 3 +- paddle/phi/kernels/funcs/activation_functor.h | 14 +- .../phi/kernels/gpu/activation_grad_kernel.cu | 5 +- paddle/phi/kernels/gpu/activation_kernel.cu | 5 +- .../phi/kernels/impl/activation_grad_impl.h | 133 ++++++++++++++++++ paddle/phi/kernels/impl/activation_impl.h | 109 -------------- 7 files changed, 150 insertions(+), 127 deletions(-) create mode 100644 paddle/phi/kernels/impl/activation_grad_impl.h diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 1c214c7a489f6..4b79397b6cdf2 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -493,13 +493,9 @@ using ReluGradFunctor = phi::funcs::ReluGradFunctor; template using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor; + template -struct ReluCUDAFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.cwiseMax(static_cast(0)); - } -}; +using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor; // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc index c0e249ec0f6ad..fe43ebb816077 100644 --- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc @@ -13,10 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/kernels/activation_grad_kernel.h" + #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/activation_impl.h" +#include "paddle/phi/kernels/impl/activation_grad_impl.h" namespace phi { diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index e70c593b31343..1a36e4e132f41 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -463,13 +463,13 @@ struct ReluCPUFunctor : public BaseActivationFunctor { } }; -// template -// struct ReluCUDAFunctor : public BaseActivationFunctor { -// template -// void operator()(Device d, X x, Out out) const { -// out.device(d) = x.cwiseMax(static_cast(0)); -// } -// }; +template +struct ReluCUDAFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.cwiseMax(static_cast(0)); + } +}; template struct ReluGradFunctor : public BaseActivationFunctor { diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index 88b63ed543ff7..c2995c79a7e8c 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include "paddle/phi/kernels/activation_grad_kernel.h" + #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/activation_kernel.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" -#include "paddle/phi/kernels/impl/activation_impl.h" +#include "paddle/phi/kernels/impl/activation_grad_impl.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu index 057340b22f22d..26752b89e7c34 100644 --- a/paddle/phi/kernels/gpu/activation_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_kernel.cu @@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/phi/kernels/activation_kernel.h" + #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/activation_kernel.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" -#include "paddle/phi/kernels/impl/activation_impl.h" +#include "paddle/phi/kernels/impl/activation_grad_impl.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" diff --git a/paddle/phi/kernels/impl/activation_grad_impl.h b/paddle/phi/kernels/impl/activation_grad_impl.h new file mode 100644 index 0000000000000..80e23d2b8e24b --- /dev/null +++ b/paddle/phi/kernels/impl/activation_grad_impl.h @@ -0,0 +1,133 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/activation_functor.h" + +#include "paddle/fluid/platform/device_context.h" + +namespace phi { + +template +void ActivationGradImpl(const Context& dev_ctx, + const DenseTensor* X, + const DenseTensor* Out, + const DenseTensor* dOut, + DenseTensor* dX, + const Functor& functor) { + if (static_cast(Functor::FwdDeps()) & + static_cast(funcs::ActBwdOpFwdDeps::kDepOut)) { + PADDLE_ENFORCE_NOT_NULL( + Out, errors::NotFound("The input DenseTensor Out can not be nullptr")); + } + PADDLE_ENFORCE_NOT_NULL( + dOut, errors::NotFound("The input DenseTensor dOut can not be nullptr")); + PADDLE_ENFORCE_NOT_NULL( + dX, errors::NotFound("The output DenseTensor dX can not be nullptr")); + if (!Out) { + Out = dOut; // fake out + } + if (static_cast(Functor::FwdDeps()) & + static_cast(funcs::ActBwdOpFwdDeps::kDepX)) { + PADDLE_ENFORCE_NOT_NULL( + X, errors::NotFound("The input DenseTensor X can not be nullptr")); + } else { + VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name(); + X = dX; + } + + dev_ctx.template Alloc(dX); + auto dout = phi::EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "ActivationGrad")); + auto out = phi::EigenVector::Flatten( + GET_DATA_SAFELY(Out, "Input", "Out", "ActivationGrad")); + auto dx = phi::EigenVector::Flatten( + GET_DATA_SAFELY(dX, "Input", "X@GRAD", "ActivationGrad")); + auto x = phi::EigenVector::Flatten( + GET_DATA_SAFELY(X, "Input", "X", "ActivationGrad")); + auto* place = dev_ctx.eigen_device(); + // use 32bit index to speed up computation + bool use_32bit_index = out.size() < Eigen::NumTraits::highest(); + bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace()); + if (use_32bit_index && is_gpu_place) { + functor(*place, + To32BitIndex(x), + To32BitIndex(out), + To32BitIndex(dout), + To32BitIndex(dx)); + } else { + functor(*place, x, out, dout, dx); + } +} + +template +void ActivationDoubleGradImpl(const Context& dev_ctx, + const DenseTensor* X, + const DenseTensor* Out, + const DenseTensor* ddX, + DenseTensor* dX, + DenseTensor* dOut, + DenseTensor* ddOut, + const Functor& functor) { + if (static_cast(Functor::FwdDeps()) & + static_cast(funcs::ActBwdOpFwdDeps::kDepX)) { + PADDLE_ENFORCE_NOT_NULL( + X, errors::NotFound("The input DenseTensor X can not be nullptr")); + } else { + VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name(); + X = ddX; + } + if (static_cast(Functor::FwdDeps()) & + static_cast(funcs::ActBwdOpFwdDeps::kDepOut)) { + PADDLE_ENFORCE_NOT_NULL( + Out, errors::NotFound("The input DenseTensor Out can not be nullptr")); + } else { + VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name(); + Out = ddX; + } + + if (ddOut) { + dev_ctx.template Alloc(ddOut); + } + if (dOut) { + dev_ctx.template Alloc(dOut); + } + if (dX) { + dX->Resize(Out->dims()); + dev_ctx.template Alloc(dX); + } + + functor(dev_ctx, X, Out, ddX, ddOut, dOut, dX); +} + +template +void ReluDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& out, + const DenseTensor& ddx, + DenseTensor* ddout) { + funcs::ReluGradGradFunctor relu_double_grad_functor; + ActivationDoubleGradImpl>( + dev_ctx, + nullptr, + &out, + &ddx, + nullptr, + nullptr, + ddout, + relu_double_grad_functor); +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/activation_impl.h b/paddle/phi/kernels/impl/activation_impl.h index 95acf12c19e10..ca3debd394a1e 100644 --- a/paddle/phi/kernels/impl/activation_impl.h +++ 
b/paddle/phi/kernels/impl/activation_impl.h @@ -47,113 +47,4 @@ void ActivationImpl(const Context& dev_ctx, } } -template -void ActivationGradImpl(const Context& dev_ctx, - const DenseTensor* X, - const DenseTensor* Out, - const DenseTensor* dOut, - DenseTensor* dX, - const Functor& functor) { - if (static_cast(Functor::FwdDeps()) & - static_cast(funcs::ActBwdOpFwdDeps::kDepOut)) { - PADDLE_ENFORCE_NOT_NULL( - Out, errors::NotFound("The input DenseTensor Out can not be nullptr")); - } - PADDLE_ENFORCE_NOT_NULL( - dOut, errors::NotFound("The input DenseTensor dOut can not be nullptr")); - PADDLE_ENFORCE_NOT_NULL( - dX, errors::NotFound("The output DenseTensor dX can not be nullptr")); - if (!Out) { - Out = dOut; // fake out - } - if (static_cast(Functor::FwdDeps()) & - static_cast(funcs::ActBwdOpFwdDeps::kDepX)) { - PADDLE_ENFORCE_NOT_NULL( - X, errors::NotFound("The input DenseTensor X can not be nullptr")); - } else { - VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name(); - X = dX; - } - - dev_ctx.template Alloc(dX); - auto dout = phi::EigenVector::Flatten( - GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "ActivationGrad")); - auto out = phi::EigenVector::Flatten( - GET_DATA_SAFELY(Out, "Input", "Out", "ActivationGrad")); - auto dx = phi::EigenVector::Flatten( - GET_DATA_SAFELY(dX, "Input", "X@GRAD", "ActivationGrad")); - auto x = phi::EigenVector::Flatten( - GET_DATA_SAFELY(X, "Input", "X", "ActivationGrad")); - auto* place = dev_ctx.eigen_device(); - // use 32bit index to speed up computation - bool use_32bit_index = out.size() < Eigen::NumTraits::highest(); - bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace()); - if (use_32bit_index && is_gpu_place) { - functor(*place, - To32BitIndex(x), - To32BitIndex(out), - To32BitIndex(dout), - To32BitIndex(dx)); - } else { - functor(*place, x, out, dout, dx); - } -} - -template -void ActivationDoubleGradImpl(const Context& dev_ctx, - const DenseTensor* X, - const DenseTensor* Out, - const DenseTensor* ddX, - DenseTensor* dX, - DenseTensor* dOut, - DenseTensor* ddOut, - const Functor& functor) { - if (static_cast(Functor::FwdDeps()) & - static_cast(funcs::ActBwdOpFwdDeps::kDepX)) { - PADDLE_ENFORCE_NOT_NULL( - X, errors::NotFound("The input DenseTensor X can not be nullptr")); - } else { - VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name(); - X = ddX; - } - if (static_cast(Functor::FwdDeps()) & - static_cast(funcs::ActBwdOpFwdDeps::kDepOut)) { - PADDLE_ENFORCE_NOT_NULL( - Out, errors::NotFound("The input DenseTensor Out can not be nullptr")); - } else { - VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name(); - Out = ddX; - } - - if (ddOut) { - dev_ctx.template Alloc(ddOut); - } - if (dOut) { - dev_ctx.template Alloc(dOut); - } - if (dX) { - dX->Resize(Out->dims()); - dev_ctx.template Alloc(dX); - } - - functor(dev_ctx, X, Out, ddX, ddOut, dOut, dX); -} - -template -void ReluDoubleGradKernel(const Context& dev_ctx, - const DenseTensor& out, - const DenseTensor& ddx, - DenseTensor* ddout) { - funcs::ReluGradGradFunctor relu_double_grad_functor; - ActivationDoubleGradImpl>( - dev_ctx, - nullptr, - &out, - &ddx, - nullptr, - nullptr, - ddout, - relu_double_grad_functor); -} - } // namespace phi
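
Note: the DepX/DepOut split that runs through the new declare/define/arg-map macros and through ActivationGradImpl above comes from each functor's FwdDeps() flag: kDepX functors (cos, tan, asin, ...) need the forward input X to compute dx, while kDepOut functors (relu) only need the forward output Out. Below is a minimal standalone sketch, not part of the patch, of that dependency check; the enum values and the *GradLike functor names are illustrative assumptions, the real definitions live in paddle/phi/kernels/funcs/activation_functor.h.

#include <cstdio>

// Illustrative stand-in for phi::funcs::ActBwdOpFwdDeps (values assumed).
enum class ActBwdOpFwdDeps : int {
  kNoDeps = 0x00,  // backward needs neither X nor Out
  kDepX   = 0x01,  // backward needs the forward input X (cos_grad, tan_grad, ...)
  kDepOut = 0x02,  // backward needs the forward output Out (relu_grad)
};

// Hypothetical functors mirroring funcs::CosGradFunctor / funcs::ReluGradFunctor.
struct CosGradLike {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
struct ReluGradLike {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

// Same bit test that ActivationGradImpl/ActivationDoubleGradImpl perform in
// paddle/phi/kernels/impl/activation_grad_impl.h.
template <typename Functor>
bool NeedsForwardOutput() {
  return static_cast<int>(Functor::FwdDeps()) &
         static_cast<int>(ActBwdOpFwdDeps::kDepOut);
}

int main() {
  std::printf("cos_grad needs Out:  %d\n", NeedsForwardOutput<CosGradLike>());   // prints 0
  std::printf("relu_grad needs Out: %d\n", NeedsForwardOutput<ReluGradLike>());  // prints 1
  return 0;
}

This is also why relu goes through the DepOut variants in this series (DECLARE_ACTIVATION_GRAD_KERNEL_DepOut, DEFINE_*_ACTIVATION_GRAD_KERNEL_DepOut, and DefineActGradDepOutOpArgMap, which maps {"Out", GradVarName("Out")} to {GradVarName("X")}), while the trigonometric ops go through the DepX variants that map {"X", GradVarName("Out")} instead.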