diff --git a/backends/metax_gpu/kernels/cuda_kernels/activation_grad_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/activation_grad_kernel_register.cu index 6c46ef10c0f..d49e74dea73 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/activation_grad_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/activation_grad_kernel_register.cu @@ -15,8 +15,6 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" -#include "paddle/phi/common/bfloat16.h" -#include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/activation_grad_kernel.h" #include "paddle/phi/kernels/full_kernel.h" @@ -119,6 +117,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx, ActivationGradGPUImpl>( \ dev_ctx, &x, nullptr, &dout, dx, functor); \ } + #define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_DOUBLE_ATTRS_DEPX( \ name, functor_class, attr1, attr2) \ template \ @@ -135,6 +134,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx, ActivationGradGPUImpl>( \ dev_ctx, &x, nullptr, &dout, dx, functor); \ } + #define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(name, functor_class) \ template \ void name##GradKernel(const Context& dev_ctx, \ @@ -161,6 +161,21 @@ void ActivationGradGPUImpl(const Context& dev_ctx, dev_ctx, nullptr, &out, &dout, dx, functor); \ } +#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_DOUBLE_ATTRS_DEPOUT( \ + name, functor_class, attr) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& out, \ + const DenseTensor& dout, \ + double attr, \ + DenseTensor* dx) { \ + funcs::functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr; \ + ActivationGradGPUImpl>( \ + dev_ctx, nullptr, &out, &dout, dx, functor); \ + } + #define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT( \ name, functor_class, attr1, attr2) \ template \ @@ -240,9 +255,9 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish, DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, CudaCELUGradFunctor, alpha); -DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(LogitCUDA, - CudaLogitGradFunctor, - eps); +DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_DOUBLE_ATTRS_DEPOUT(LogitCUDA, + CudaLogitGradFunctor, + eps); DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(HardTanh, CudaHardTanhGradFunctor, @@ -266,6 +281,7 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(ThresholdedRelu, CudaThresholdedReluGradFunctor, threshold, value); + template void SiluGradKernel(const Context& dev_ctx, const DenseTensor& x, @@ -390,14 +406,14 @@ PD_CUSTOM_KERNEL_REGISTER(relu_grad, phi::ReluGradKernel, float, double, - phi::dtype::float16) {} + phi::float16) {} PD_CUSTOM_KERNEL_REGISTER(relu_double_grad, metax_gpu, ALL_LAYOUT, phi::ReluDoubleGradKernel, float, double, - phi::dtype::float16) {} + phi::float16) {} #else PD_CUSTOM_KERNEL_REGISTER(relu_grad, metax_gpu, @@ -405,16 +421,16 @@ PD_CUSTOM_KERNEL_REGISTER(relu_grad, phi::ReluGradKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} PD_CUSTOM_KERNEL_REGISTER(relu_double_grad, metax_gpu, ALL_LAYOUT, phi::ReluDoubleGradKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} #endif #define PD_REGISTER_ACTIVATION_GRAD_KERNEL(name, func) \ @@ -424,8 +440,8 @@ PD_CUSTOM_KERNEL_REGISTER(relu_double_grad, phi::func, \ float, \ double, \ - phi::dtype::float16, \ - phi::dtype::bfloat16) {} + phi::float16, \ + phi::bfloat16) {} #define PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(name, func) \ PD_CUSTOM_KERNEL_REGISTER(name, \ @@ -434,10 +450,10 @@ PD_CUSTOM_KERNEL_REGISTER(relu_double_grad, phi::func, \ float, \ double, \ - phi::dtype::float16, \ - phi::dtype::bfloat16, \ - phi::dtype::complex, \ - phi::dtype::complex) {} + phi::float16, \ + phi::bfloat16, \ + phi::complex64, \ + phi::complex128) {} PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sin_grad, SinGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(cos_grad, CosGradKernel) @@ -483,10 +499,10 @@ PD_CUSTOM_KERNEL_REGISTER(exp_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_REGISTER_ACTIVATION_GRAD_KERNEL(softshrink_grad, SoftShrinkGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel) @@ -502,10 +518,10 @@ PD_CUSTOM_KERNEL_REGISTER(expm1_grad, phi::Expm1GradKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(square_grad, metax_gpu, @@ -515,10 +531,10 @@ PD_CUSTOM_KERNEL_REGISTER(square_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(square_double_grad, metax_gpu, ALL_LAYOUT, @@ -527,10 +543,10 @@ PD_CUSTOM_KERNEL_REGISTER(square_double_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(sin_double_grad, metax_gpu, @@ -540,10 +556,10 @@ PD_CUSTOM_KERNEL_REGISTER(sin_double_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(sin_triple_grad, metax_gpu, @@ -553,10 +569,10 @@ PD_CUSTOM_KERNEL_REGISTER(sin_triple_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(cos_double_grad, metax_gpu, @@ -566,10 +582,10 @@ PD_CUSTOM_KERNEL_REGISTER(cos_double_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(cos_triple_grad, metax_gpu, @@ -579,10 +595,10 @@ PD_CUSTOM_KERNEL_REGISTER(cos_triple_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softsign_grad, SoftsignGradKernel) @@ -604,10 +620,10 @@ PD_CUSTOM_KERNEL_REGISTER(log_double_grad, phi::LogDoubleGradKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad, HardSwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel) @@ -622,8 +638,8 @@ PD_CUSTOM_KERNEL_REGISTER(rint_grad, int64_t, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} PD_CUSTOM_KERNEL_REGISTER(round_grad, metax_gpu, ALL_LAYOUT, @@ -632,10 +648,10 @@ PD_CUSTOM_KERNEL_REGISTER(round_grad, int64_t, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(pow_grad, metax_gpu, ALL_LAYOUT, @@ -644,10 +660,10 @@ PD_CUSTOM_KERNEL_REGISTER(pow_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(pow_double_grad, metax_gpu, ALL_LAYOUT, @@ -656,10 +672,10 @@ PD_CUSTOM_KERNEL_REGISTER(pow_double_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(pow_triple_grad, metax_gpu, ALL_LAYOUT, @@ -668,10 +684,10 @@ PD_CUSTOM_KERNEL_REGISTER(pow_triple_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(ceil_grad, metax_gpu, ALL_LAYOUT, @@ -683,8 +699,8 @@ PD_CUSTOM_KERNEL_REGISTER(ceil_grad, int16_t, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} PD_CUSTOM_KERNEL_REGISTER(floor_grad, metax_gpu, ALL_LAYOUT, @@ -696,5 +712,5 @@ PD_CUSTOM_KERNEL_REGISTER(floor_grad, int16_t, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {}