diff --git a/backends/metax_gpu/patch/paddle.patch b/backends/metax_gpu/patch/paddle.patch
index 4c844e5cc82..6578029129e 100755
--- a/backends/metax_gpu/patch/paddle.patch
+++ b/backends/metax_gpu/patch/paddle.patch
@@ -440,7 +440,163 @@ index 024a7de73e..66b373d698 100644
 } \
 } while (0)
 #elif defined(__HIPCC__)
-
+diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h
+index ae7b67de6d..fbe9f67737 100644
+--- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h
++++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h
+@@ -368,7 +368,7 @@ struct CUBlas {
+                       cudaDataType_t Ctype,
+                       int ldc,
+                       int batchCount,
+-                      cudaDataType_t computeType) {
++                      cublasComputeType_t computeType) {
+ #if CUDA_VERSION >= 8000
+     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+ #if CUDA_VERSION >= 9000
+@@ -476,7 +476,7 @@ struct CUBlas {
+                       void *C,
+                       cudaDataType_t Ctype,
+                       int ldc,
+-                      cudaDataType_t computeType) {
++                      cublasComputeType_t computeType) {
+ #if CUDA_VERSION >= 8000
+     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+ #if CUDA_VERSION >= 9000
+@@ -532,7 +532,7 @@ struct CUBlas {
+                       void *C,
+                       cudaDataType_t Ctype,
+                       int64_t ldc,
+-                      cudaDataType_t computeType) {
++                      cublasComputeType_t computeType) {
+ #if CUDA_VERSION >= 12030 && defined(__linux__)
+     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+     bool use_tensor_op_math = dev_ctx->tensor_core_available();
+@@ -759,7 +759,7 @@ struct CUBlas {
+                       void *C,
+                       cudaDataType_t Ctype,
+                       int ldc,
+-                      cudaDataType_t computeType) {
++                      cublasComputeType_t computeType) {
+ #if CUDA_VERSION >= 8000
+     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+ #if CUDA_VERSION >= 9000
+@@ -815,7 +815,7 @@ struct CUBlas {
+                       void *C,
+                       cudaDataType_t Ctype,
+                       int64_t ldc,
+-                      cudaDataType_t computeType) {
++                      cublasComputeType_t computeType) {
+ #if CUDA_VERSION >= 12030 && defined(__linux__)
+     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+     bool use_tensor_op_math = dev_ctx->tensor_core_available();
+@@ -1154,7 +1154,7 @@ struct CUBlas {
+                       void *C,
+                       cudaDataType_t Ctype,
+                       int ldc,
+-                      cudaDataType_t computeType) {
++                      cublasComputeType_t computeType) {
+ #if CUDA_VERSION >= 8000
+     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+ #if CUDA_VERSION >= 9000
+@@ -1210,7 +1210,7 @@ struct CUBlas {
+                       void *C,
+                       cudaDataType_t Ctype,
+                       int64_t ldc,
+-                      cudaDataType_t computeType) {
++                      cublasComputeType_t computeType) {
+ #if CUDA_VERSION >= 12030 && defined(__linux__)
+     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+     bool use_tensor_op_math = dev_ctx->tensor_core_available();
+@@ -1484,7 +1484,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA,
+                              C,
+                              CUDA_R_16F,
+                              N,
+-                             CUDA_R_32F);
++                             CUBLAS_COMPUTE_32F);
+ #else
+   PADDLE_THROW(common::errors::Unimplemented(
+       "GEMM_EX_64 is not supported on cuda < 12.3"));
+@@ -1508,7 +1508,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA,
+                              C,
+                              CUDA_R_16F,
+                              static_cast(N),
+-                             CUDA_R_32F);
++                             CUBLAS_COMPUTE_32F);
+   }
+ #else
+   // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
+@@ -1694,7 +1694,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA,
+                              C,
+                              CUDA_R_16F,
+                              N,
+-                             CUDA_R_32F);
++                             CUBLAS_COMPUTE_32F);
+ #else
+   PADDLE_THROW(common::errors::Unimplemented(
+       "GEMM_EX_64 is not supported on cuda < 12.3"));
+@@ -1719,7 +1719,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA,
+                              C,
+                              CUDA_R_16F,
+                              static_cast(N),
+-                             CUDA_R_32F);
++                             CUBLAS_COMPUTE_32F);
+ #else
+   // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
+   dev_ctx_.CublasCall([&](cublasHandle_t handle) {
+@@ -1831,7 +1831,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA,
+                              C,
+                              CUDA_R_16BF,
+                              static_cast(N),
+-                             CUDA_R_32F,
++                             CUBLAS_COMPUTE_32F,
+                              algo));
+     });
+   }
+@@ -1932,7 +1932,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA,
+                              C,
+                              CUDA_R_16BF,
+                              static_cast(N),
+-                             CUDA_R_32F,
++                             CUBLAS_COMPUTE_32F,
+                              algo));
+     });
+   }
+@@ -2026,7 +2026,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA,
+                              C,
+                              CUDA_C_32F,
+                              static_cast(N),
+-                             CUDA_C_32F);
++                             CUBLAS_COMPUTE_32F);
+ 
+ #else
+   dev_ctx_.CublasCall([&](cublasHandle_t handle) {
+@@ -2111,7 +2111,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA,
+                              C,
+                              CUDA_C_64F,
+                              N,
+-                             CUDA_C_64F);
++                             CUBLAS_COMPUTE_64F);
+ #else
+   PADDLE_THROW(common::errors::Unimplemented(
+       "GEMM_EX_64 is not supported on cuda < 12.3"));
+@@ -2136,7 +2136,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA,
+                              C,
+                              CUDA_C_64F,
+                              static_cast(N),
+-                             CUDA_C_64F);
++                             CUBLAS_COMPUTE_64F);
+ #else // CUDA_VERSION >= 8000
+   // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
+   dev_ctx_.CublasCall([&](cublasHandle_t handle) {
+@@ -3129,7 +3129,7 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA,
+                              CUDA_R_16F,
+                              ldc,
+                              batchCount,
+-                             CUDA_R_32F);
++                             CUBLAS_COMPUTE_32F);
+ }
+ 
+ template <>
 diff --git a/paddle/phi/kernels/funcs/blas/blaslt_gemm_search.h b/paddle/phi/kernels/funcs/blas/blaslt_gemm_search.h
 index e63b3d2f6e..95d7e6f204 100644
 --- a/paddle/phi/kernels/funcs/blas/blaslt_gemm_search.h