diff --git a/backends/metax_gpu/common/flags_declare.cc b/backends/metax_gpu/common/flags_declare.cc
index fb656878033..0b65d635510 100644
--- a/backends/metax_gpu/common/flags_declare.cc
+++ b/backends/metax_gpu/common/flags_declare.cc
@@ -101,6 +101,18 @@ PHI_DEFINE_EXPORTED_bool(
     "faster but it may loss precision in most case. If true, the compute "
     "type will be set to fp16. Default is false.");
 
+/**
+ * Torch Compatible related FLAG
+ * Name: FLAGS_torch_compatible_kernel
+ * Since Version: 3.2.2
+ * Value Range: bool, default=false
+ * Example:
+ * Note: Whether to use the torch-compatible kernel version.
+ */
+PHI_DEFINE_EXPORTED_bool(torch_compatible_kernel,
+                         false,
+                         "Whether to use the torch-compatible kernel version.");
+
 PHI_DEFINE_EXPORTED_string(
     selected_gpus,
     "",
diff --git a/backends/metax_gpu/kernels/cuda_kernels/gammaln_grad_kernel.cu b/backends/metax_gpu/kernels/cuda_kernels/gammaln_grad_kernel.cu
new file mode 100644
index 00000000000..850f0d68bac
--- /dev/null
+++ b/backends/metax_gpu/kernels/cuda_kernels/gammaln_grad_kernel.cu
@@ -0,0 +1,28 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/gammaln_grad_kernel.h"
+#include "paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h"
+
+PD_CUSTOM_KERNEL_REGISTER(gammaln_grad,
+                          metax_gpu,
+                          ALL_LAYOUT,
+                          phi::GammalnGradKernel,
+                          float,
+                          double,
+                          phi::float16,
+                          phi::bfloat16) {}
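Reviewer context for the registration above: `phi::GammalnGradKernel` comes from the stock `gammaln_grad_kernel_impl.h` (patched further below), whose backward rule is `dx = dout * digamma(x)`, since d/dx ln Γ(x) = ψ(x). A tiny host-side sanity check of that rule, approximating ψ by central-differencing `std::lgamma`; this sketch and its names are illustrative only, not part of the patch:

```cpp
#include <cmath>
#include <cstdio>

// Reference backward rule for gammaln: dx = dout * digamma(x).
// digamma is approximated here by numerically differentiating std::lgamma;
// the real kernel uses the series/asymptotic digamma in gammaln_grad_kernel_impl.h.
double digamma_fd(double x) {
  const double h = 1e-6;
  return (std::lgamma(x + h) - std::lgamma(x - h)) / (2.0 * h);
}

int main() {
  double x = 3.5, dout = 1.0;
  double dx = dout * digamma_fd(x);  // digamma(3.5) is approximately 1.103
  std::printf("dx = %f\n", dx);
}
```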
diff --git a/backends/metax_gpu/kernels/funcs/softmax.cu b/backends/metax_gpu/kernels/funcs/softmax.cu
index 44bfd02a308..a587f9ed016 100644
--- a/backends/metax_gpu/kernels/funcs/softmax.cu
+++ b/backends/metax_gpu/kernels/funcs/softmax.cu
@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <vector>
+#include "glog/logging.h"
 #include "kernels/metax_kernel/metax_context.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_dnn.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/softmax.h"
 #include "paddle/phi/kernels/funcs/softmax_impl.h"
-
 namespace phi {
 namespace funcs {
@@ -38,6 +38,7 @@ void SoftmaxCUDNNFunctor<T, DeviceContext>::operator()(
   ScopedTensorDescriptor yDesc;
   std::vector<int> cudnn_tensor_dims = common::vectorize<int>(X->dims());
   DataLayout layout = DataLayout::kNCHW;
+  VLOG(0) << "Enter softmax Kernel22.";
   if (cudnn_tensor_dims.size() == 5) {
     layout = DataLayout::kNCDHW;
   }
diff --git a/backends/metax_gpu/kernels/gpudnn/softmax_kernel_dnn.cu b/backends/metax_gpu/kernels/gpudnn/softmax_kernel_dnn.cu
new file mode 100644
index 00000000000..b51f92c96a4
--- /dev/null
+++ b/backends/metax_gpu/kernels/gpudnn/softmax_kernel_dnn.cu
@@ -0,0 +1,70 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "kernels/gpudnn/softmax_gpudnn.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/softmax_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SoftmaxGPUDNNKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         int axis,
+                         DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+  if (x.numel() == 0) return;
+
+  const int rank = x.dims().size();
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, out, static_cast<T>(1.0));
+    return;
+  }
+
+  SoftmaxForwardCUDAKernelDriver<T>(dev_ctx, x, axis, out);
+}
+
+}  // namespace phi
+
+#ifdef PADDLE_WITH_HIP
+PD_REGISTER_PLUGIN_KERNEL(softmax,
+                          metax_gpu,
+                          ALL_LAYOUT,
+                          phi::SoftmaxGPUDNNKernel,
+                          float,
+                          phi::float16,
+                          phi::bfloat16) {}
+#else
+#if CUDNN_VERSION_MIN(8, 1, 0)
+PD_REGISTER_PLUGIN_KERNEL(softmax,
+                          metax_gpu,
+                          ALL_LAYOUT,
+                          phi::SoftmaxGPUDNNKernel,
+                          float,
+                          double,
+                          phi::float16,
+                          phi::bfloat16) {}
+#else
+PD_REGISTER_PLUGIN_KERNEL(softmax,
+                          metax_gpu,
+                          ALL_LAYOUT,
+                          phi::SoftmaxGPUDNNKernel,
+                          float,
+                          double,
+                          phi::float16) {}
+#endif
+#endif
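A note on the 0-D branch in `SoftmaxGPUDNNKernel` above: softmax normalizes exponentials along one axis, so a rank-0 tensor (a single element) always normalizes to exactly 1, which is why the kernel short-circuits with `set_constant`. A CPU reference of the same computation, with the max-subtraction stabilization the GPU drivers also rely on; this is an illustrative sketch, not code from the patch:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Reference softmax along a contiguous axis.
// Subtracting the running max keeps std::exp from overflowing, the same
// stabilization cuDNN's accurate softmax mode applies internally.
std::vector<float> softmax(const std::vector<float>& x) {
  float m = x[0];
  for (float v : x) m = std::max(m, v);
  float sum = 0.f;
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = std::exp(x[i] - m);
    sum += y[i];
  }
  for (float& v : y) v /= sum;  // a single element yields exactly 1.0
  return y;
}
```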
diff --git a/backends/metax_gpu/kernels/metax_kernel/softmax_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/softmax_kernel_register.cu
index 0344a81dc19..523a2e4d76b 100644
--- a/backends/metax_gpu/kernels/metax_kernel/softmax_kernel_register.cu
+++ b/backends/metax_gpu/kernels/metax_kernel/softmax_kernel_register.cu
@@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
+#if 0
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
@@ -27,3 +27,5 @@ PD_REGISTER_PLUGIN_KERNEL(softmax,
                           double,
                           phi::dtype::float16,
                           phi::dtype::bfloat16) {}
+
+#endif
diff --git a/backends/metax_gpu/kernels/metax_kernel/svd_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/svd_kernel_register.cu
index 5f9d6cc20e0..c8ece09bbae 100644
--- a/backends/metax_gpu/kernels/metax_kernel/svd_kernel_register.cu
+++ b/backends/metax_gpu/kernels/metax_kernel/svd_kernel_register.cu
@@ -15,7 +15,7 @@
 
 #ifndef PADDLE_WITH_HIP
 // HIP not support cusolver
-#include "kernels/impl/values_vectors_functor.h"
+#include "kernels/metax_kernel/metax_context.h"
 #include "paddle/phi/backends/dynload/cusolver.h"
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -60,7 +60,6 @@ void GesvdjBatched(const phi::GPUContext& dev_ctx,
   int ldu = m;
   int ldt = n;
   int lwork = 0;
-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(
       phi::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params));
@@ -142,7 +141,6 @@ void GesvdjBatched(const phi::GPUContext& dev_ctx,
   int ldu = m;
   int ldt = n;
   int lwork = 0;
-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(
       phi::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params));
@@ -205,17 +203,17 @@
 }
 
 template <>
-void GesvdjBatched<phi::dtype::complex<float>>(const phi::GPUContext& dev_ctx,
-                                               int batchSize,
-                                               int m,
-                                               int n,
-                                               int k,
-                                               phi::dtype::complex<float>* A,
-                                               phi::dtype::complex<float>* U,
-                                               phi::dtype::complex<float>* V,
-                                               float* S,
-                                               int* info,
-                                               int thin_UV) {
+void GesvdjBatched(const phi::GPUContext& dev_ctx,
+                   int batchSize,
+                   int m,
+                   int n,
+                   int k,
+                   phi::complex64* A,
+                   phi::complex64* U,
+                   phi::complex64* V,
+                   float* S,
+                   int* info,
+                   int thin_UV) {
   /* compute singular vectors */
   const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
   /* compute singular vectors */
@@ -224,7 +222,6 @@ void GesvdjBatched<phi::dtype::complex<float>>(const phi::GPUContext& dev_ctx,
   int ldu = m;
   int ldt = n;
   int lwork = 0;
-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(
       phi::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params));
@@ -245,10 +242,10 @@
       gesvdj_params));
   auto workspace = phi::memory_utils::Alloc(
       dev_ctx.GetPlace(),
-      lwork * sizeof(phi::dtype::complex<float>),
+      lwork * sizeof(phi::complex64),
       phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
-  phi::dtype::complex<float>* workspace_ptr =
-      reinterpret_cast<phi::dtype::complex<float>*>(workspace->ptr());
+  phi::complex64* workspace_ptr =
+      reinterpret_cast<phi::complex64*>(workspace->ptr());
   int stride_A = lda * n;
   int stride_U = ldu * (thin_UV ? k : m);
   int stride_V = ldt * (thin_UV ? k : n);
@@ -289,17 +286,17 @@ void GesvdjBatched<phi::dtype::complex<float>>(const phi::GPUContext& dev_ctx,
 }
 
 template <>
-void GesvdjBatched<phi::dtype::complex<double>>(const phi::GPUContext& dev_ctx,
-                                                int batchSize,
-                                                int m,
-                                                int n,
-                                                int k,
-                                                phi::dtype::complex<double>* A,
-                                                phi::dtype::complex<double>* U,
-                                                phi::dtype::complex<double>* V,
-                                                double* S,
-                                                int* info,
-                                                int thin_UV) {
+void GesvdjBatched(const phi::GPUContext& dev_ctx,
+                   int batchSize,
+                   int m,
+                   int n,
+                   int k,
+                   phi::complex128* A,
+                   phi::complex128* U,
+                   phi::complex128* V,
+                   double* S,
+                   int* info,
+                   int thin_UV) {
   /* compute singular vectors */
   const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
   /* compute singular vectors */
@@ -308,7 +305,6 @@ void GesvdjBatched<phi::dtype::complex<double>>(const phi::GPUContext& dev_ctx,
   int ldu = m;
   int ldt = n;
   int lwork = 0;
-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(
       phi::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params));
@@ -329,10 +325,10 @@
       gesvdj_params));
   auto workspace = phi::memory_utils::Alloc(
       dev_ctx.GetPlace(),
-      lwork * sizeof(phi::dtype::complex<double>),
+      lwork * sizeof(phi::complex128),
      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
-  phi::dtype::complex<double>* workspace_ptr =
-      reinterpret_cast<phi::dtype::complex<double>*>(workspace->ptr());
+  phi::complex128* workspace_ptr =
+      reinterpret_cast<phi::complex128*>(workspace->ptr());
   int stride_A = lda * n;
   int stride_U = ldu * (thin_UV ? k : m);
   int stride_V = ldt * (thin_UV ? k : n);
@@ -432,7 +428,7 @@ PD_REGISTER_PLUGIN_KERNEL(svd,  // cuda_only
                           phi::SvdKernel,
                           float,
                           double,
-                          phi::dtype::complex<float>,
-                          phi::dtype::complex<double>) {}
+                          phi::complex64,
+                          phi::complex128) {}
 
 #endif  // not PADDLE_WITH_HIP
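All the `GesvdjBatched` variants touched above follow the same cusolver sequence: create a `gesvdjInfo_t`, query the workspace size, allocate it (via `phi::memory_utils::Alloc` in the patch), then run the solver per batch entry. A condensed single-matrix sketch of that sequence against the raw CUDA APIs; it assumes a valid handle and device buffers, elides error checks, and is not code from the patch:

```cpp
#include <cuda_runtime.h>
#include <cusolverDn.h>

// Condensed gesvdj sequence for one column-major m x n float matrix.
// d_A/d_S/d_U/d_V are preallocated device buffers.
void gesvdj_once(cusolverDnHandle_t handle, int m, int n,
                 float* d_A, float* d_S, float* d_U, float* d_V) {
  gesvdjInfo_t params = nullptr;
  cusolverDnCreateGesvdjInfo(&params);

  // 1) Query the workspace size for the economy-size SVD.
  int lwork = 0;
  cusolverDnSgesvdj_bufferSize(handle, CUSOLVER_EIG_MODE_VECTOR, /*econ=*/1,
                               m, n, d_A, m, d_S, d_U, m, d_V, n,
                               &lwork, params);

  // 2) Allocate the workspace and a per-matrix convergence flag.
  float* d_work = nullptr;
  int* d_info = nullptr;
  cudaMalloc(&d_work, sizeof(float) * lwork);
  cudaMalloc(&d_info, sizeof(int));

  // 3) Compute; the batched functions in the patch repeat this per matrix.
  cusolverDnSgesvdj(handle, CUSOLVER_EIG_MODE_VECTOR, /*econ=*/1,
                    m, n, d_A, m, d_S, d_U, m, d_V, n,
                    d_work, lwork, d_info, params);

  cudaFree(d_work);
  cudaFree(d_info);
  cusolverDnDestroyGesvdjInfo(params);
}
```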
diff --git a/backends/metax_gpu/patch/paddle.patch b/backends/metax_gpu/patch/paddle.patch
index 6578029129e..fe0d9e104a5 100755
--- a/backends/metax_gpu/patch/paddle.patch
+++ b/backends/metax_gpu/patch/paddle.patch
@@ -18,6 +18,22 @@ index cfada544d4..a690e97d74 100644
     endif()
   set(EIGEN_INCLUDE_DIR ${SOURCE_DIR})
 
+diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt
+index 99a0116d92..2566e7c41a 100755
+--- a/paddle/fluid/operators/fused/CMakeLists.txt
++++ b/paddle/fluid/operators/fused/CMakeLists.txt
+@@ -43,6 +43,11 @@ if(WITH_GPU OR WITH_ROCM)
+     op_library(fused_multi_transformer_int8_op)
+   endif()
+
++  if(1)
++    op_library(fused_gemm_epilogue_op)
++  endif()
++
++
+   if(CUDA_VERSION GREATER_EQUAL 11.6)
+     op_library(fused_gemm_epilogue_op)
+   endif()
 diff --git a/paddle/fluid/platform/profiler/cupti_data_process.cc b/paddle/fluid/platform/profiler/cupti_data_process.cc
 index bff0f2bf70..9376b5781f 100644
 --- a/paddle/fluid/platform/profiler/cupti_data_process.cc
@@ -441,10 +457,38 @@ index 024a7de73e..66b373d698 100644
 } while (0)
 #elif defined(__HIPCC__)
 diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h
-index ae7b67de6d..fbe9f67737 100644
+index ae7b67de6d..9ac725314f 100644
 --- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h
 +++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h
-@@ -368,7 +368,7 @@ struct CUBlas {
+@@ -218,11 +218,27 @@ struct CUBlas {
+   }
+ };
+ 
++template <typename... Args>
++void print_args(Args... args) {
++  std::cout << "Arguments (" << sizeof...(args) << "): [";
++  bool first = true;
++  auto printer = [&first](const auto& arg) {
++    if (!first) std::cout << ", ";
++    std::cout << arg;
++    first = false;
++  };
++  (printer(args), ...);
++  std::cout << "]" << std::endl;
++}
++
+ template <>
+ struct CUBlas<double> {
+   template <typename... ARGS>
+   static void GEMM(ARGS... args) {
++    // print_args(args...);
+     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemm(args...));
++
++
+   }
+ 
+   template <typename... ARGS>
+@@ -368,7 +384,7 @@ struct CUBlas {
                         cudaDataType_t Ctype,
                         int ldc,
                         int batchCount,
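Reviewer note on the `print_args` helper added to paddle.patch above: it is a C++17 fold over the comma operator, so the lambda runs once per pack element, in order. A standalone illustration with hypothetical argument values (the call itself stays commented out in `CUBlas<double>::GEMM`):

```cpp
#include <iostream>

// Same shape as the patched-in helper: expand a parameter pack with a
// comma-operator fold, routing every argument through one lambda.
template <typename... Args>
void print_args(Args... args) {
  std::cout << "Arguments (" << sizeof...(args) << "): [";
  bool first = true;
  auto printer = [&first](const auto& arg) {
    if (!first) std::cout << ", ";
    std::cout << arg;
    first = false;
  };
  (printer(args), ...);  // expands to printer(a1), printer(a2), ...
  std::cout << "]" << std::endl;
}

int main() {
  print_args(64, 128, 2.5, "N");  // prints: Arguments (4): [64, 128, 2.5, N]
}
```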
 #if CUDA_VERSION >= 8000
     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
 #if CUDA_VERSION >= 9000
-@@ -476,7 +476,7 @@ struct CUBlas {
+@@ -476,7 +492,7 @@ struct CUBlas {
                      void *C,
                      cudaDataType_t Ctype,
                      int ldc,
 #if CUDA_VERSION >= 8000
     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
 #if CUDA_VERSION >= 9000
-@@ -532,7 +532,7 @@ struct CUBlas {
+@@ -532,7 +548,7 @@ struct CUBlas {
                      void *C,
                      cudaDataType_t Ctype,
                      int64_t ldc,
 #if CUDA_VERSION >= 12030 && defined(__linux__)
     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
     bool use_tensor_op_math = dev_ctx->tensor_core_available();
-@@ -759,7 +759,7 @@ struct CUBlas {
+@@ -759,7 +775,7 @@ struct CUBlas {
                      void *C,
                      cudaDataType_t Ctype,
                      int ldc,
 #if CUDA_VERSION >= 8000
     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
 #if CUDA_VERSION >= 9000
-@@ -815,7 +815,7 @@ struct CUBlas {
+@@ -815,7 +831,7 @@ struct CUBlas {
                      void *C,
                      cudaDataType_t Ctype,
                      int64_t ldc,
 #if CUDA_VERSION >= 12030 && defined(__linux__)
     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
     bool use_tensor_op_math = dev_ctx->tensor_core_available();
-@@ -1154,7 +1154,7 @@ struct CUBlas {
+@@ -1154,7 +1170,7 @@ struct CUBlas {
                      void *C,
                      cudaDataType_t Ctype,
                      int ldc,
 #if CUDA_VERSION >= 8000
     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
 #if CUDA_VERSION >= 9000
-@@ -1210,7 +1210,7 @@ struct CUBlas {
+@@ -1210,7 +1226,7 @@ struct CUBlas {
                      void *C,
                      cudaDataType_t Ctype,
                      int64_t ldc,
 #if CUDA_VERSION >= 12030 && defined(__linux__)
     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
     bool use_tensor_op_math = dev_ctx->tensor_core_available();
-@@ -1484,7 +1484,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
+@@ -1484,7 +1500,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
         C,
         CUDA_R_16F,
         N,
 #else
     PADDLE_THROW(common::errors::Unimplemented(
         "GEMM_EX_64 is not supported on cuda < 12.3"));
-@@ -1508,7 +1508,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
+@@ -1508,7 +1524,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
           C,
           CUDA_R_16F,
           static_cast<int>(N),
 }
 #else
   // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
-@@ -1694,7 +1694,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
+@@ -1694,7 +1710,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
         C,
         CUDA_R_16F,
         N,
 #else
     PADDLE_THROW(common::errors::Unimplemented(
         "GEMM_EX_64 is not supported on cuda < 12.3"));
-@@ -1719,7 +1719,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
+@@ -1719,7 +1735,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
           C,
           CUDA_R_16F,
           static_cast<int>(N),
 #else
   // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
   dev_ctx_.CublasCall([&](cublasHandle_t handle) {
-@@ -1831,7 +1831,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
+@@ -1831,7 +1847,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
           C,
           CUDA_R_16BF,
           static_cast<int>(N),
           algo));
     });
   }
-@@ -1932,7 +1932,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
+@@ -1932,7 +1948,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
           C,
           CUDA_R_16BF,
           static_cast<int>(N),
           algo));
     });
   }
-@@ -2026,7 +2026,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
+@@ -2026,7 +2042,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
           C,
           CUDA_C_32F,
           static_cast<int>(N),
 #else
 
   dev_ctx_.CublasCall([&](cublasHandle_t handle) {
-@@ -2111,7 +2111,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
+@@ -2111,7 +2127,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
         C,
         CUDA_C_64F,
         N,
 #else
     PADDLE_THROW(common::errors::Unimplemented(
         "GEMM_EX_64 is not supported on cuda < 12.3"));
-@@ -2136,7 +2136,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
+@@ -2136,7 +2152,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
           C,
           CUDA_C_64F,
           static_cast<int>(N),
 #else  // CUDA_VERSION >= 8000
   // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
   dev_ctx_.CublasCall([&](cublasHandle_t handle) {
+@@ -2272,7 +2288,7 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
+           C,
+           CUDA_R_16F,
+           ldc,
+-          CUDA_R_32F,
++          CUBLAS_COMPUTE_32F,
+           algo));
+     });
+   }
+@@ -2334,7 +2350,7 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
+           C,
+           CUDA_R_16BF,
+           ldc,
+-          CUDA_R_32F,
++          CUBLAS_COMPUTE_32F,
+           algo));
+     });
+ #else
-@@ -3129,7 +3129,7 @@ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
+@@ -3129,7 +3145,7 @@ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                     CUDA_R_16F,
                                     ldc,
                                     batchCount,
 }
 
 template <>
+@@ -3197,7 +3213,7 @@ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
+                                    CUDA_R_16BF,
+                                    ldc,
+                                    batchCount,
+-                                   CUDA_R_32F,
++                                   CUBLAS_COMPUTE_32F,
+                                    algo));
+     });
+ #else
 diff --git a/paddle/phi/kernels/funcs/blas/blaslt_gemm_search.h b/paddle/phi/kernels/funcs/blas/blaslt_gemm_search.h
 index e63b3d2f6e..95d7e6f204 100644
 --- a/paddle/phi/kernels/funcs/blas/blaslt_gemm_search.h
 +++ b/paddle/phi/kernels/funcs/blas/blaslt_gemm_search.h
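The recurring `CUDA_R_32F` to `CUBLAS_COMPUTE_32F` substitution throughout this patch tracks the cuBLAS 11+ `cublasGemmEx` signature, where the compute-type parameter is a `cublasComputeType_t` rather than the legacy `cudaDataType_t`. A hedged sketch of the call shape the patched code paths target (FP16 inputs accumulated in FP32; buffer setup omitted, handle assumed valid, not code from the patch):

```cpp
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Minimal cublasGemmEx call using the cuBLAS 11+ compute-type enum.
// d_A/d_B are __half device buffers, d_C is float; column-major, no error checks.
void gemm_f16_accum_f32(cublasHandle_t handle, int m, int n, int k,
                        const __half* d_A, const __half* d_B, float* d_C) {
  const float alpha = 1.f, beta = 0.f;
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
               &alpha,
               d_A, CUDA_R_16F, m,   // A: fp16, lda = m
               d_B, CUDA_R_16F, k,   // B: fp16, ldb = k
               &beta,
               d_C, CUDA_R_32F, m,   // C: fp32, ldc = m
               CUBLAS_COMPUTE_32F,   // was CUDA_R_32F in the cuBLAS 10 API
               CUBLAS_GEMM_DEFAULT);
}
```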
@@ -1129,3 +1200,27 @@ index e6b3960f6d..564125f1f6 100644
     if ((x <= T{0}) || (a <= T{0})) return (T{1.0});
 
+diff --git a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h
+index 410fb3c560..7d173d46f5 100644
+--- a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h
++++ b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h
+@@ -20,8 +20,8 @@
+ namespace phi {
+ template <typename T>
+ HOSTDEVICE T digamma_positive_domain(T x) {
+-  static T c = T{8.5};
+-  static T euler_mascheroni = T{0.57721566490153286060};
++  const static T c = T{8.5};
++  const static T euler_mascheroni = T{0.57721566490153286060};
+   T r;
+   T value;
+   T x2;
+@@ -54,7 +54,7 @@ HOSTDEVICE T digamma_positive_domain(T x) {
+ 
+ template <typename T>
+ HOSTDEVICE T digamma(T x) {
+-  static T pi = T{3.14159265358979323846};
++  const static T pi = T{3.14159265358979323846};
+ 
+   if (x == T{0.0}) {
+     T inf = std::numeric_limits<T>::infinity();
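Context for the `const static` hunks above: `digamma_positive_domain` only converges for x > 0 (recurrence up past the `c = 8.5` threshold, then an asymptotic series), and `digamma` extends it to negative arguments with the reflection formula psi(x) = psi(1 - x) - pi * cot(pi * x), where `pi` is the constant the patch marks `const static`. A host-only sketch of that structure; the real kernel versions are HOSTDEVICE and handle poles and infinities explicitly:

```cpp
#include <cmath>

// Positive-domain digamma via recurrence + asymptotic series, mirroring the
// layout of digamma_positive_domain() in gammaln_grad_kernel_impl.h
// (simplified; poles not handled).
double digamma_positive_domain(double x) {
  double value = 0.0;
  while (x < 8.5) {            // recurrence: psi(x) = psi(x + 1) - 1/x
    value -= 1.0 / x;
    x += 1.0;
  }
  double inv = 1.0 / (x * x);  // asymptotic expansion for large x
  value += std::log(x) - 0.5 / x
           - inv * (1.0 / 12 - inv * (1.0 / 120 - inv / 252));
  return value;
}

double digamma(double x) {
  const double pi = 3.14159265358979323846;
  if (x > 0.0) return digamma_positive_domain(x);
  // reflection: psi(x) = psi(1 - x) - pi * cot(pi * x)
  return digamma_positive_domain(1.0 - x) - pi / std::tan(pi * x);
}
```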