diff --git a/Paddle b/Paddle index 5dbecdcb0e4..feaccbbbc8b 160000 --- a/Paddle +++ b/Paddle @@ -1 +1 @@ -Subproject commit 5dbecdcb0e4ddd3488927f49082dfb66c794f9e7 +Subproject commit feaccbbbc8b91d62d551f8b5509de4bbbfedc82d diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt index 6aecdc1f833..9e257e9507d 100755 --- a/backends/metax_gpu/CMakeLists.txt +++ b/backends/metax_gpu/CMakeLists.txt @@ -109,6 +109,10 @@ file( CUDA_SRCS # backends ${PADDLE_SOURCE_DIR}/paddle/phi/backends/gpu/cuda/cuda_info.cc + ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/dynamic_loader.cc + ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cublas.cc + ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cublasLt.cc + ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cudnn.cc ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cuda_driver.cc ${PADDLE_SOURCE_DIR}/paddle/phi/backends/gpu/cuda/cuda_graph.cc # Core @@ -698,7 +702,6 @@ file( kernels/gpudnn/*.cu kernels/cuda_kernels/*.cc kernels/cuda_kernels/*.cu - kernels/funcs/blas/*.cc kernels/ernie_core/*.cu) set(CUSTOM_DEVICE_SRCS ${CUDA_SRCS} ${CC_SRCS} ${ERNIE_CORE_SRCS}) @@ -746,11 +749,28 @@ target_compile_definitions( PUBLIC PADDLE_WITH_CUDA=1 PADDLE_WITH_CUSTOM_DEVICE=1 mcblasContext=cublasContext + cublasLtContext=mcblasLtContext GPUContext=CustomContext KPSContext=CustomContext STREAM_TYPE=cudaStream_t EVENT_TYPE=cudaEvent_t - EIGEN_USE_GPU=1) + EIGEN_USE_GPU=1 + CUDA_LIB_NAME="libmcruntime.so" + BLAS_LIB_NAME="libmcblas.so" + BLASLT_LIB_NAME="libmcblasLt.so" + DNN_LIB_NAME="libmcdnn.so" + PTI_LIB_NAME="libmcpti.so" + RAND_LIB_NAME="libcurand.so" + JPEG_LIB_NAME="libnvjpeg.so" + SOLVER_LIB_NAME="libmcsolver.so" + SPARSE_LIB_NAME="libmcsparse.so" + RTC_LIB_NAME="libmcruntime.so" + FLASHATTN_LIB_NAME="libmcFlashAttn.so" + FLASHATTNV3_LIB_NAME="libflashattnv3.so" + CCL_LIB_NAME="libmccl.so" + FFT_LIB_NAME="libcufft.so" + SPARSELT_LIB_NAME="libcusparseLt.so" + CUPTI_LIB_PATH="/root/cu-bridge/CUDA_DIR/extras/CUPTI/lib64") # packing wheel package configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in diff --git a/backends/metax_gpu/common/flags_declare.cc b/backends/metax_gpu/common/flags_declare.cc index 6b497cf9fdf..fb656878033 100644 --- a/backends/metax_gpu/common/flags_declare.cc +++ b/backends/metax_gpu/common/flags_declare.cc @@ -37,6 +37,27 @@ */ static constexpr int kDefaultConvWorkspaceSizeLimitMB = 512; +/** + * CUDA related FLAG + * Name: FLAGS_cublaslt_exhaustive_search_times + * Since Version: 2.3.0 + * Value Range: int64_t, default=0 + * Example: + * Note: Represents times of exhaustive search to evaluate performance of + * cuBlasLt matmul algorithm (with/without epilogue). Set this flag + * with value > 0 to enable exhaustive search. Default is 0, means + * getting algorithms via heuristic search. There are two search methods + * in cuBlasLt, heuristic search and exhaustive search. Exhaustive search + * attempts all cuBlasLt algorithms to select the fastest, which is very + * time-consuming, and the selected algorithm will be cached for a given + * layer specification Once you change the layer specifications + * (such as M, N and K), it will re-search again. + */ +PHI_DEFINE_EXPORTED_int64( + cublaslt_exhaustive_search_times, + 0, + "The times of exhaustive search for cuBlasLt matmul with/without " + " epilogue algorithms, default is 0, means disabling exhaustive search."); PHI_DEFINE_EXPORTED_bool( cudnn_exhaustive_search, diff --git a/backends/metax_gpu/kernels/cuda_kernels/activation_grad_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/activation_grad_kernel_register.cu index 6c46ef10c0f..f5ee4ec25f8 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/activation_grad_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/activation_grad_kernel_register.cu @@ -15,8 +15,6 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" -#include "paddle/phi/common/bfloat16.h" -#include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/activation_grad_kernel.h" #include "paddle/phi/kernels/full_kernel.h" @@ -103,6 +101,21 @@ void ActivationGradGPUImpl(const Context& dev_ctx, dev_ctx, &x, nullptr, &dout, dx, functor); \ } +#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_DOUBLE_ATTRS_DEPX( \ + name, functor_class, attr) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& dout, \ + double attr, \ + DenseTensor* dx) { \ + funcs::functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr; \ + ActivationGradGPUImpl>( \ + dev_ctx, &x, nullptr, &dout, dx, functor); \ + } + #define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX( \ name, functor_class, attr1, attr2) \ template \ @@ -119,6 +132,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx, ActivationGradGPUImpl>( \ dev_ctx, &x, nullptr, &dout, dx, functor); \ } + #define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_DOUBLE_ATTRS_DEPX( \ name, functor_class, attr1, attr2) \ template \ @@ -135,6 +149,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx, ActivationGradGPUImpl>( \ dev_ctx, &x, nullptr, &dout, dx, functor); \ } + #define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(name, functor_class) \ template \ void name##GradKernel(const Context& dev_ctx, \ @@ -161,6 +176,21 @@ void ActivationGradGPUImpl(const Context& dev_ctx, dev_ctx, nullptr, &out, &dout, dx, functor); \ } +#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_DOUBLE_ATTRS_DEPOUT( \ + name, functor_class, attr) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& out, \ + const DenseTensor& dout, \ + double attr, \ + DenseTensor* dx) { \ + funcs::functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr; \ + ActivationGradGPUImpl>( \ + dev_ctx, nullptr, &out, &dout, dx, functor); \ + } + #define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT( \ name, functor_class, attr1, attr2) \ template \ @@ -224,9 +254,9 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Log10, CudaLog10GradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Log1p, CudaLog1pGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Swish, CudaSwishGradFunctor); -DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, - CudaLeakyReluGradFunctor, - alpha); +DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_DOUBLE_ATTRS_DEPX(LeakyRelu, + CudaLeakyReluGradFunctor, + alpha); DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, CudaSoftShrinkGradFunctor, lambda); @@ -240,9 +270,9 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish, DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, CudaCELUGradFunctor, alpha); -DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(LogitCUDA, - CudaLogitGradFunctor, - eps); +DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_DOUBLE_ATTRS_DEPOUT(LogitCUDA, + CudaLogitGradFunctor, + eps); DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(HardTanh, CudaHardTanhGradFunctor, @@ -266,6 +296,7 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(ThresholdedRelu, CudaThresholdedReluGradFunctor, threshold, value); + template void SiluGradKernel(const Context& dev_ctx, const DenseTensor& x, @@ -390,14 +421,14 @@ PD_CUSTOM_KERNEL_REGISTER(relu_grad, phi::ReluGradKernel, float, double, - phi::dtype::float16) {} + phi::float16) {} PD_CUSTOM_KERNEL_REGISTER(relu_double_grad, metax_gpu, ALL_LAYOUT, phi::ReluDoubleGradKernel, float, double, - phi::dtype::float16) {} + phi::float16) {} #else PD_CUSTOM_KERNEL_REGISTER(relu_grad, metax_gpu, @@ -405,16 +436,16 @@ PD_CUSTOM_KERNEL_REGISTER(relu_grad, phi::ReluGradKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} PD_CUSTOM_KERNEL_REGISTER(relu_double_grad, metax_gpu, ALL_LAYOUT, phi::ReluDoubleGradKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} #endif #define PD_REGISTER_ACTIVATION_GRAD_KERNEL(name, func) \ @@ -424,8 +455,8 @@ PD_CUSTOM_KERNEL_REGISTER(relu_double_grad, phi::func, \ float, \ double, \ - phi::dtype::float16, \ - phi::dtype::bfloat16) {} + phi::float16, \ + phi::bfloat16) {} #define PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(name, func) \ PD_CUSTOM_KERNEL_REGISTER(name, \ @@ -434,10 +465,10 @@ PD_CUSTOM_KERNEL_REGISTER(relu_double_grad, phi::func, \ float, \ double, \ - phi::dtype::float16, \ - phi::dtype::bfloat16, \ - phi::dtype::complex, \ - phi::dtype::complex) {} + phi::float16, \ + phi::bfloat16, \ + phi::complex64, \ + phi::complex128) {} PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sin_grad, SinGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(cos_grad, CosGradKernel) @@ -483,10 +514,10 @@ PD_CUSTOM_KERNEL_REGISTER(exp_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_REGISTER_ACTIVATION_GRAD_KERNEL(softshrink_grad, SoftShrinkGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel) @@ -502,10 +533,10 @@ PD_CUSTOM_KERNEL_REGISTER(expm1_grad, phi::Expm1GradKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(square_grad, metax_gpu, @@ -515,10 +546,10 @@ PD_CUSTOM_KERNEL_REGISTER(square_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(square_double_grad, metax_gpu, ALL_LAYOUT, @@ -527,10 +558,10 @@ PD_CUSTOM_KERNEL_REGISTER(square_double_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(sin_double_grad, metax_gpu, @@ -540,10 +571,10 @@ PD_CUSTOM_KERNEL_REGISTER(sin_double_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(sin_triple_grad, metax_gpu, @@ -553,10 +584,10 @@ PD_CUSTOM_KERNEL_REGISTER(sin_triple_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(cos_double_grad, metax_gpu, @@ -566,10 +597,10 @@ PD_CUSTOM_KERNEL_REGISTER(cos_double_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(cos_triple_grad, metax_gpu, @@ -579,10 +610,10 @@ PD_CUSTOM_KERNEL_REGISTER(cos_triple_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softsign_grad, SoftsignGradKernel) @@ -604,10 +635,10 @@ PD_CUSTOM_KERNEL_REGISTER(log_double_grad, phi::LogDoubleGradKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad, HardSwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel) @@ -622,8 +653,8 @@ PD_CUSTOM_KERNEL_REGISTER(rint_grad, int64_t, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} PD_CUSTOM_KERNEL_REGISTER(round_grad, metax_gpu, ALL_LAYOUT, @@ -632,10 +663,10 @@ PD_CUSTOM_KERNEL_REGISTER(round_grad, int64_t, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(pow_grad, metax_gpu, ALL_LAYOUT, @@ -644,10 +675,10 @@ PD_CUSTOM_KERNEL_REGISTER(pow_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(pow_double_grad, metax_gpu, ALL_LAYOUT, @@ -656,10 +687,10 @@ PD_CUSTOM_KERNEL_REGISTER(pow_double_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(pow_triple_grad, metax_gpu, ALL_LAYOUT, @@ -668,10 +699,10 @@ PD_CUSTOM_KERNEL_REGISTER(pow_triple_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(ceil_grad, metax_gpu, ALL_LAYOUT, @@ -683,8 +714,8 @@ PD_CUSTOM_KERNEL_REGISTER(ceil_grad, int16_t, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} PD_CUSTOM_KERNEL_REGISTER(floor_grad, metax_gpu, ALL_LAYOUT, @@ -696,5 +727,5 @@ PD_CUSTOM_KERNEL_REGISTER(floor_grad, int16_t, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} diff --git a/backends/metax_gpu/kernels/cuda_kernels/activation_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/activation_kernel_register.cu index 363932cfc28..d91e4afd25e 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/activation_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/activation_kernel_register.cu @@ -14,8 +14,6 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" -#include "paddle/phi/common/bfloat16.h" -#include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/activation_kernel.h" #include "paddle/phi/kernels/full_kernel.h" @@ -75,6 +73,19 @@ void ActivationGPUImpl(const Context& dev_ctx, dev_ctx, x, out, functor); \ } +#define DEFINE_GPU_ACT_KERNEL_WITH_ONE_DOUBLE_ATTRS(name, functor_class, attr) \ + template \ + void name##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + double attr, \ + DenseTensor* out) { \ + funcs::functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr; \ + ActivationGPUImpl>( \ + dev_ctx, x, out, functor); \ + } + #define DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS( \ name, functor_class, attr1, attr2) \ template \ @@ -90,6 +101,7 @@ void ActivationGPUImpl(const Context& dev_ctx, ActivationGPUImpl>( \ dev_ctx, x, out, functor); \ } + #define DEFINE_GPU_ACT_KERNEL_WITH_TWO_DOUBLE_ATTRS( \ name, functor_class, attr1, attr2) \ template \ @@ -105,6 +117,7 @@ void ActivationGPUImpl(const Context& dev_ctx, ActivationGPUImpl>( \ dev_ctx, x, out, functor); \ } + DEFINE_GPU_ACTIVATION_KERNEL(Cos, CudaCosFunctor) DEFINE_GPU_ACTIVATION_KERNEL(Tan, CudaTanFunctor) DEFINE_GPU_ACTIVATION_KERNEL(Acos, CudaAcosFunctor) @@ -138,8 +151,10 @@ DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log1p, CudaLog1pFunctor) DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Exp, CudaExpFunctor) DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Expm1, CudaExpm1Functor) -DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha) -DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LogitCUDA, CudaLogitFunctor, eps) +DEFINE_GPU_ACT_KERNEL_WITH_ONE_DOUBLE_ATTRS(LeakyRelu, + CudaLeakyReluFunctor, + alpha) +DEFINE_GPU_ACT_KERNEL_WITH_ONE_DOUBLE_ATTRS(LogitCUDA, CudaLogitFunctor, eps) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, CudaHardShrinkFunctor, threshold) @@ -286,13 +301,9 @@ void PowKernel(const Context& dev_ctx, } // namespace phi #ifdef PADDLE_WITH_HIP -PD_CUSTOM_KERNEL_REGISTER(relu, - metax_gpu, - ALL_LAYOUT, - phi::ReluKernel, - float, - double, - phi::dtype::float16) {} +PD_CUSTOM_KERNEL_REGISTER( + relu, metax_gpu, ALL_LAYOUT, phi::ReluKernel, float, double, phi::float16) { +} #else PD_CUSTOM_KERNEL_REGISTER(relu, metax_gpu, @@ -300,8 +311,8 @@ PD_CUSTOM_KERNEL_REGISTER(relu, phi::ReluKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} #endif #define PD_REGISTER_ACTIVATION_KERNEL(name, func) \ @@ -311,8 +322,8 @@ PD_CUSTOM_KERNEL_REGISTER(relu, phi::func, \ float, \ double, \ - phi::dtype::float16, \ - phi::dtype::bfloat16) {} + phi::float16, \ + phi::bfloat16) {} #define PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(name, func) \ PD_CUSTOM_KERNEL_REGISTER(name, \ @@ -321,10 +332,10 @@ PD_CUSTOM_KERNEL_REGISTER(relu, phi::func, \ float, \ double, \ - phi::dtype::float16, \ - phi::dtype::bfloat16, \ - phi::dtype::complex, \ - phi::dtype::complex) {} + phi::float16, \ + phi::bfloat16, \ + phi::complex64, \ + phi::complex128) {} PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sin, SinKernel) PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(cos, CosKernel) @@ -357,10 +368,10 @@ PD_CUSTOM_KERNEL_REGISTER(exp, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(expm1, metax_gpu, ALL_LAYOUT, @@ -369,10 +380,10 @@ PD_CUSTOM_KERNEL_REGISTER(expm1, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(square, metax_gpu, ALL_LAYOUT, @@ -381,10 +392,10 @@ PD_CUSTOM_KERNEL_REGISTER(square, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel) PD_REGISTER_ACTIVATION_KERNEL(softshrink, SoftShrinkKernel) @@ -409,8 +420,8 @@ PD_CUSTOM_KERNEL_REGISTER(rint, int64_t, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} PD_CUSTOM_KERNEL_REGISTER(round, metax_gpu, ALL_LAYOUT, @@ -419,10 +430,10 @@ PD_CUSTOM_KERNEL_REGISTER(round, int64_t, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(log, metax_gpu, ALL_LAYOUT, @@ -431,10 +442,10 @@ PD_CUSTOM_KERNEL_REGISTER(log, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(log2, metax_gpu, ALL_LAYOUT, @@ -443,10 +454,10 @@ PD_CUSTOM_KERNEL_REGISTER(log2, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(log10, metax_gpu, ALL_LAYOUT, @@ -455,10 +466,10 @@ PD_CUSTOM_KERNEL_REGISTER(log10, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(log1p, metax_gpu, ALL_LAYOUT, @@ -467,10 +478,10 @@ PD_CUSTOM_KERNEL_REGISTER(log1p, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(pow, metax_gpu, ALL_LAYOUT, @@ -479,10 +490,10 @@ PD_CUSTOM_KERNEL_REGISTER(pow, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(ceil, metax_gpu, ALL_LAYOUT, @@ -494,8 +505,8 @@ PD_CUSTOM_KERNEL_REGISTER(ceil, int16_t, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} PD_CUSTOM_KERNEL_REGISTER(floor, metax_gpu, ALL_LAYOUT, @@ -507,5 +518,5 @@ PD_CUSTOM_KERNEL_REGISTER(floor, int16_t, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} diff --git a/backends/metax_gpu/kernels/cuda_kernels/argsort_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/argsort_kernel_register.cu index 8fb331eeedd..20ea33834e6 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/argsort_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/argsort_kernel_register.cu @@ -26,11 +26,11 @@ namespace cub = hipcub; #endif -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/primitive/functor_primitives.h" #include "paddle/phi/kernels/transpose_kernel.h" diff --git a/backends/metax_gpu/kernels/cuda_kernels/batch_fc_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/batch_fc_kernel_register.cu index caccb01f71d..0e82304d31d 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/batch_fc_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/batch_fc_kernel_register.cu @@ -14,10 +14,10 @@ #include -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_fc_elementwise_layernorm_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/fused_fc_elementwise_layernorm_kernel_register.cu new file mode 100644 index 00000000000..f52b0cc4b78 --- /dev/null +++ b/backends/metax_gpu/kernels/cuda_kernels/fused_fc_elementwise_layernorm_kernel_register.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/fusion/gpu/fused_fc_elementwise_layernorm_kernel.cu" // NOLINT + +PD_CUSTOM_KERNEL_REGISTER(fused_fc_elementwise_layernorm, + metax_gpu, + ALL_LAYOUT, + phi::fusion::FusedFCElementwiseLayerNormKernel, + float, + double, + phi::float16) {} diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_gemm_epilogue_grad_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/fused_gemm_epilogue_grad_kernel_register.cu new file mode 100644 index 00000000000..2e8d33b964c --- /dev/null +++ b/backends/metax_gpu/kernels/cuda_kernels/fused_gemm_epilogue_grad_kernel_register.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" +#include "paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_grad_kernel.cu" //NOLINT + +PD_CUSTOM_KERNEL_REGISTER(fused_gemm_epilogue_grad, + metax_gpu, + ALL_LAYOUT, + phi::fusion::FusedGemmEpilogueGradKernel, + float, + double, + phi::float16, + phi::bfloat16) {} diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_gemm_epilogue_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/fused_gemm_epilogue_kernel_register.cu new file mode 100644 index 00000000000..9be5794c54f --- /dev/null +++ b/backends/metax_gpu/kernels/cuda_kernels/fused_gemm_epilogue_kernel_register.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" +#include "paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_kernel.cu" //NOLINT + +PD_CUSTOM_KERNEL_REGISTER(fused_gemm_epilogue, + metax_gpu, + ALL_LAYOUT, + phi::fusion::FusedGemmEpilogueKernel, + float, + double, + phi::float16, + phi::bfloat16) {} diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_linear_param_grad_add_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/fused_linear_param_grad_add_kernel_register.cu new file mode 100644 index 00000000000..c88f94625b7 --- /dev/null +++ b/backends/metax_gpu/kernels/cuda_kernels/fused_linear_param_grad_add_kernel_register.cu @@ -0,0 +1,24 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/fusion/gpu/fused_linear_param_grad_add_kernel.cu" //NOLINT +PD_CUSTOM_KERNEL_REGISTER(fused_linear_param_grad_add, + metax_gpu, + ALL_LAYOUT, + phi::fusion::FusedLinearParamGradAdd, + float, + double, + phi::float16, + phi::bfloat16) {} diff --git a/backends/metax_gpu/kernels/cuda_kernels/gammaln_grad_kernel.cu b/backends/metax_gpu/kernels/cuda_kernels/gammaln_grad_kernel.cu new file mode 100644 index 00000000000..850f0d68bac --- /dev/null +++ b/backends/metax_gpu/kernels/cuda_kernels/gammaln_grad_kernel.cu @@ -0,0 +1,28 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gammaln_grad_kernel.h" +#include "paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h" + +PD_CUSTOM_KERNEL_REGISTER(gammaln_grad, + metax_gpu, + ALL_LAYOUT, + phi::GammalnGradKernel, + float, + double, + phi::float16, + phi::bfloat16) {} diff --git a/backends/metax_gpu/kernels/cuda_kernels/matmul_grad_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/matmul_grad_kernel_register.cu index f9eef9908ab..bb3b07d24d0 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/matmul_grad_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/matmul_grad_kernel_register.cu @@ -13,9 +13,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "../impl/matmul_grad_kernel_impl.h" #include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h" #include "paddle/phi/kernels/matmul_grad_kernel.h" PD_CUSTOM_KERNEL_REGISTER(matmul_grad, diff --git a/backends/metax_gpu/kernels/cuda_kernels/matmul_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/matmul_kernel_register.cu index 57c3a85b1ea..750cf2a9f36 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/matmul_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/matmul_kernel_register.cu @@ -18,7 +18,7 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" -#include "kernels/impl/matmul_kernel_impl.h" +#include "paddle/phi/kernels/impl/matmul_kernel_impl.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) diff --git a/backends/metax_gpu/kernels/cuda_kernels/multihead_matmul_kernel.cu b/backends/metax_gpu/kernels/cuda_kernels/multihead_matmul_kernel.cu index 151c929e41c..998854140fc 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/multihead_matmul_kernel.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/multihead_matmul_kernel.cu @@ -15,11 +15,11 @@ #include #include -#include "kernels/funcs/blas/blas.h" #include "paddle/common/errors.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/multihead_matmul_functor.h" namespace phi { diff --git a/backends/metax_gpu/kernels/cuda_kernels/pad_grad_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/pad_grad_kernel_register.cu index 38b89fce698..f87f589a424 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/pad_grad_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/pad_grad_kernel_register.cu @@ -20,6 +20,8 @@ PD_CUSTOM_KERNEL_REGISTER(pad_grad, ALL_LAYOUT, phi::PadGradKernel, float, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex) {} + double, + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} diff --git a/backends/metax_gpu/kernels/dynload/cupti_lib_path.h b/backends/metax_gpu/kernels/dynload/cupti_lib_path.h deleted file mode 100644 index 6082fffd60e..00000000000 --- a/backends/metax_gpu/kernels/dynload/cupti_lib_path.h +++ /dev/null @@ -1,19 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#define CUPTI_LIB_PATH "/root/cu-bridge/CUDA_DIR/extras/CUPTI/lib64" diff --git a/backends/metax_gpu/kernels/dynload/dynamic_loader.cc b/backends/metax_gpu/kernels/dynload/dynamic_loader.cc deleted file mode 100644 index a23b7fa2aff..00000000000 --- a/backends/metax_gpu/kernels/dynload/dynamic_loader.cc +++ /dev/null @@ -1,938 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -// #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "kernels/dynload/dynamic_loader.h" - -#include - -#include -#include -#include -#include -// #include "paddle/phi/backends/dynload/cupti_lib_path.h" -#include "./dynload/cupti_lib_path.h" -#include "paddle/phi/common/port.h" -#include "paddle/phi/core/enforce.h" - -#if defined(_WIN32) -#include -#endif - -// TODO(wilber): The phi computing library requires a component to manage flags -// (maybe not use gflags). -#include "glog/logging.h" -#include "paddle/common/flags.h" - -COMMON_DECLARE_string(cudnn_dir); -COMMON_DECLARE_string(cuda_dir); -COMMON_DECLARE_string(cublas_dir); -COMMON_DECLARE_string(nccl_dir); -COMMON_DECLARE_string(cupti_dir); -COMMON_DECLARE_string(tensorrt_dir); -COMMON_DECLARE_string(mklml_dir); -COMMON_DECLARE_string(lapack_dir); -COMMON_DECLARE_string(mkl_dir); -COMMON_DECLARE_string(op_dir); -COMMON_DECLARE_string(cusparselt_dir); -COMMON_DECLARE_string(curand_dir); -COMMON_DECLARE_string(cusolver_dir); -COMMON_DECLARE_string(cusparse_dir); -COMMON_DECLARE_string(win_cuda_bin_dir); -#ifdef PADDLE_WITH_HIP - -PHI_DEFINE_string(miopen_dir, - "", - "Specify path for loading libMIOpen.so. For instance, " - "/opt/rocm/miopen/lib. If empty [default], dlopen " - "will search miopen from LD_LIBRARY_PATH"); - -PHI_DEFINE_string(rocm_dir, - "", - "Specify path for loading rocm library, such as librocblas, " - "libmiopen, libhipsparse. For instance, /opt/rocm/lib. " - "If default, dlopen will search rocm from LD_LIBRARY_PATH"); - -PHI_DEFINE_string(rccl_dir, - "", - "Specify path for loading rccl library, such as librccl.so. " - "For instance, /opt/rocm/rccl/lib. If default, " - "dlopen will search rccl from LD_LIBRARY_PATH"); -#endif - -// #ifdef PADDLE_WITH_FLAGCX -// COMMON_DECLARE_string(flagcx_dir); -// #endif - -// PHI_DEFINE_EXPORTED_string( -// flagcx_dir, // NOLINT -// "", -// "Specify path for loading libflagcx.so. For instance, " -// "For instance, /usr/local/flagcx/lib. If default, " -// "dlopen will search flagcx from LD_LIBRARY_PATH"); - -#ifdef PADDLE_WITH_XPU -PD_DEFINE_string(xpti_dir, "", "Specify path for loading libxpti.so."); -#endif - -namespace phi::dynload { - -struct PathNode { - PathNode() = default; - std::string path = ""; -}; - -static constexpr char cupti_lib_path[] = CUPTI_LIB_PATH; // NOLINT - -// NOTE: In order to adapt to the default installation path of cuda -#if defined(_WIN32) && defined(PADDLE_WITH_CUDA) -static constexpr char cuda_lib_path[] = CUDA_TOOLKIT_ROOT_DIR "/bin"; -#else -static constexpr char cuda_lib_path[] = "/usr/local/cuda/lib64"; // NOLINT -#endif - -static PathNode s_py_site_pkg_path; - -#if defined(_WIN32) && defined(PADDLE_WITH_CUDA) -static constexpr char* win_cudnn_lib = "cudnn64_" CUDNN_MAJOR_VERSION ".dll"; -static constexpr char* win_cublas_lib = - "cublas64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;cublas64_" CUDA_VERSION_MAJOR ".dll"; -#if CUDA_VERSION >= 11000 -static constexpr char* win_curand_lib = - "curand64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;curand64_" CUDA_VERSION_MAJOR ".dll;curand64_10.dll"; -static constexpr char* win_nvjpeg_lib = - "nvjpeg64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;nvjpeg64_" CUDA_VERSION_MAJOR ".dll;nvjpeg64_10.dll"; -static constexpr char* win_cusolver_lib = - "cusolver64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;cusolver64_" CUDA_VERSION_MAJOR - ".dll;cusolver64_11.dll;cusolver64_10.dll"; -static constexpr char* win_cusparse_lib = - "cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll;cusparse64_10.dll"; -static constexpr char* win_cufft_lib = - "cufft64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;cufft64_" CUDA_VERSION_MAJOR ".dll;cufft64_11.dll;cufft64_10.dll"; -#else -static constexpr char* win_curand_lib = - "curand64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;curand64_" CUDA_VERSION_MAJOR ".dll"; -static constexpr char* win_nvjpeg_lib = - "nvjpeg64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;nvjpeg64_" CUDA_VERSION_MAJOR ".dll"; -static constexpr char* win_cusolver_lib = - "cusolver64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;cusolver64_" CUDA_VERSION_MAJOR ".dll"; -static constexpr char* win_cusparse_lib = - "cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll"; -static constexpr char* win_cufft_lib = - "cufft64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;cufft64_" CUDA_VERSION_MAJOR ".dll"; -#endif // CUDA_VERSION -#endif - -static inline std::string join(const std::string& part1, - const std::string& part2) { -// directory separator -#if defined(_WIN32) - const char sep = '\\'; -#else - const char sep = '/'; -#endif - if (!part2.empty() && part2.front() == sep) { - return part2; - } - std::string ret; - ret.reserve(part1.size() + part2.size() + 1); - ret = part1; - if (!ret.empty() && ret.back() != sep) { - ret += sep; - } - ret += part2; - return ret; -} - -static inline std::vector split( - const std::string& str, const std::string separator = " ") { - std::vector str_list; - std::string::size_type firstPos = 0; - firstPos = str.find_first_not_of(separator, 0); - std::string::size_type lastPos = 0; - lastPos = str.find_first_of(separator, firstPos); - while (std::string::npos != firstPos && std::string::npos != lastPos) { - str_list.push_back(str.substr(firstPos, lastPos - firstPos)); - firstPos = str.find_first_not_of(separator, lastPos); - lastPos = str.find_first_of(separator, firstPos); - } - if (std::string::npos == lastPos) { - str_list.push_back(str.substr(firstPos, lastPos - firstPos)); - } - return str_list; -} - -void SetPaddleLibPath(const std::string& py_site_pkg_path) { - s_py_site_pkg_path.path = py_site_pkg_path; - VLOG(3) << "Set paddle lib path : " << py_site_pkg_path; -} - -static inline void* GetDsoHandleFromSpecificPath(const std::string& spec_path, - const std::string& dso_name, - int dynload_flags) { - void* dso_handle = nullptr; - if (!spec_path.empty()) { - // search xxx.so from custom path - VLOG(3) << "Try to find library: " << dso_name - << " from specific path: " << spec_path; - std::string dso_path = join(spec_path, dso_name); - dso_handle = dlopen(dso_path.c_str(), dynload_flags); - } - return dso_handle; -} - -static inline std::string FindLibAbsolutePath(const std::string& directory, - const std::string& filename) { - DIR* dir = opendir(directory.c_str()); - struct dirent* ent; - - if (dir != nullptr) { - while ((ent = readdir(dir)) != nullptr) { - if (ent->d_type == DT_REG || ent->d_type == DT_LNK) { - if (filename == std::string(ent->d_name)) { - closedir(dir); - return join(directory, ent->d_name); - } - } else if (ent->d_type == DT_DIR) { - if (strcmp(ent->d_name, ".") != 0 && strcmp(ent->d_name, "..") != 0) { - std::string res = - FindLibAbsolutePath(join(directory, ent->d_name) + "/", filename); - if (!res.empty()) { - closedir(dir); - return res; - } - } - } - } - closedir(dir); - } - return ""; -} - -static inline void* GetDsoHandleFromDefaultPath(const std::string& dso_path, - int dynload_flags) { - // default search from LD_LIBRARY_PATH/DYLD_LIBRARY_PATH - // and /usr/local/lib path - void* dso_handle = dlopen(dso_path.c_str(), dynload_flags); - VLOG(3) << "Try to find library: " << dso_path - << " from default system path."; - -// TODO(chenweihang): This path is used to search which libs? -// DYLD_LIBRARY_PATH is disabled after Mac OS 10.11 to -// bring System Integrity Projection (SIP), if dso_handle -// is null, search from default package path in Mac OS. -#if defined(__APPLE__) || defined(__OSX__) -#if defined(__arm__) || defined(__aarch64__) - if (nullptr == dso_handle) { - dso_handle = - dlopen(FindLibAbsolutePath("/opt/homebrew/Cellar/", dso_path).c_str(), - dynload_flags); - } -#else - if (nullptr == dso_handle) { - dso_handle = - dlopen(FindLibAbsolutePath("/usr/local/cuda/lib/", dso_path).c_str(), - dynload_flags); - } -#endif -#endif - - return dso_handle; -} - -/* - * We define three priorities for dynamic library search: - * - * First: Search for path specified by the user - * Second: Search the stheystem default path - * Third: Search for a special path corresponding to - * a specific library to adapt to changes and easy to expand. - */ - -static inline void* GetDsoHandleFromSearchPath( - const std::string& config_path, - const std::string& dso_name, - bool throw_on_error = true, - const std::vector& extra_paths = std::vector(), - const std::string& warning_msg = std::string()) { -#if !defined(_WIN32) - int dynload_flags = RTLD_LAZY | RTLD_LOCAL; -#else - int dynload_flags = 0; -#endif // !_WIN32 -#if defined(_WIN32) - std::vector cuda_bin_search_path = { - L"cublas", - L"cuda_nvrtc", - L"cuda_runtime", - L"cudnn", - L"cufft", - L"curand", - L"cusolver", - L"cusparse", - L"nvjitlink", - }; - for (auto search_path : cuda_bin_search_path) { - std::wstring_convert> converter; - std::wstring win_path_wstring = - converter.from_bytes(FLAGS_win_cuda_bin_dir); - search_path = win_path_wstring + L"\\" + search_path + L"\\bin"; - AddDllDirectory(search_path.c_str()); - } -#endif - std::vector dso_names = split(dso_name, ";"); - void* dso_handle = nullptr; - for (auto const& dso : dso_names) { - // 1. search in user config path by FLAGS - dso_handle = GetDsoHandleFromSpecificPath(config_path, dso, dynload_flags); - // 2. search in system default path - if (nullptr == dso_handle) { - dso_handle = GetDsoHandleFromDefaultPath(dso, dynload_flags); - } - // 3. search in extra paths - if (nullptr == dso_handle) { - for (auto const& path : extra_paths) { - VLOG(3) << "extra_paths: " << path; - dso_handle = GetDsoHandleFromSpecificPath(path, dso, dynload_flags); - } - } - if (nullptr != dso_handle) break; - } - - // 4. [If Failed for All dso_names] logging warning if exists - if (nullptr == dso_handle && !warning_msg.empty()) { - LOG(WARNING) << warning_msg; - } - - // 5. [If Failed for All dso_names] logging or throw error info - if (nullptr == dso_handle) { - auto error_msg = - "The third-party dynamic library (%s) that Paddle depends on is not " - "configured correctly. (error code is %s)\n" - " Suggestions:\n" - " 1. Check if the third-party dynamic library (e.g. CUDA, CUDNN) " - "is installed correctly and its version is matched with paddlepaddle " - "you installed.\n" - " 2. Configure third-party dynamic library environment variables as " - "follows:\n" - " - Linux: set LD_LIBRARY_PATH by `export LD_LIBRARY_PATH=...`\n" - " - Windows: set PATH by `set PATH=XXX;%%PATH%%`\n" - " - Mac: set DYLD_LIBRARY_PATH by `export DYLD_LIBRARY_PATH=...` " - "[Note: After Mac OS 10.11, using the DYLD_LIBRARY_PATH is " - "impossible unless System Integrity Protection (SIP) is disabled.]"; -#if !defined(_WIN32) - auto errorno = dlerror(); -#else - auto errorno = GetLastError(); -#endif // !_WIN32 - if (throw_on_error) { - // NOTE: Special error report case, no need to change its format - PADDLE_THROW( - common::errors::PreconditionNotMet(error_msg, dso_name, errorno)); - } else { - LOG(WARNING) << paddle::string::Sprintf(error_msg, dso_name, errorno); - } - } - - return dso_handle; -} - -void* GetCublasDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.dylib"); -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "cublas64_11.dll"); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cublas_lib, true, {cuda_lib_path}); -#endif - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "cublas64_12.dll"); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cublas_lib, true, {cuda_lib_path}); -#endif - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 13, paddle " - "temporarily no longer supports"); - return nullptr; - } - -#elif defined(__linux__) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublas.so.11"); -#else - // VLOG(0) << "dynload:libmcblas.so: CUDA_VERSION=11000-12000 start" ; - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcblas.so"); - // VLOG(0) << "dynload:libmcblas.so: CUDA_VERSION=11000-12000 end" ; -#endif - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublas.so.12"); -#else - // VLOG(0) << "dynload:libmcblas.so: CUDA_VERSION=12000-13000 start" ; - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcblas.so"); - // VLOG(0) << "dynload:libmcblas.so: CUDA_VERSION=12000-13000 end" ; -#endif - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 13, paddle " - "temporarily no longer supports"); - return nullptr; - } -#elif defined(PADDLE_WITH_HIP) - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "librocblas.so"); -#else - // VLOG(0) << "dynload:libmcblas.so: CUDA_VERSION=else start" ; - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcblas.so"); - // VLOG(0) << "dynload:libmcblas.so: CUDA_VERSION=else end" ; -// return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcblas.so"); -#endif -} - -void* GetCublasLtDsoHandle() { -// APIs available after CUDA 10.1 -#if defined(__linux__) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublasLt.so.11"); -#else - // return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublasLt.so"); - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcblasLt.so"); -#endif - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublasLt.so.12"); -#else - // return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublasLt.so"); - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcblasLt.so"); -#endif - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 13, paddle " - "temporarily no longer supports"); - return nullptr; - } -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "cublasLt64_11.dll"); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cublas_lib, true, {cuda_lib_path}); -#endif - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "cublasLt64_12.dll"); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cublas_lib, true, {cuda_lib_path}); -#endif - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 12, paddle " - "temporarily no longer supports"); - return nullptr; - } -#elif !defined(__linux__) && defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10010 - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublasLt.so"); -#elif defined(PADDLE_WITH_HIP) - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhipblaslt.so"); -#else - std::string warning_msg( - "Your CUDA_VERSION less 10.1, not support CublasLt. " - "If you want to use CublasLt, please upgrade CUDA and rebuild " - "PaddlePaddle."); - return nullptr; -#endif -} - -void* GetCUDNNDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - std::string mac_warn_meg( - "Note: [Recommend] copy cudnn into /usr/local/cuda/ \n " - "For instance, sudo tar -xzf " - "cudnn-7.5-osx-x64-v5.0-ga.tgz -C /usr/local \n sudo " - "chmod a+r /usr/local/cuda/include/cudnn.h " - "/usr/local/cuda/lib/libcudnn*"); - return GetDsoHandleFromSearchPath( - FLAGS_cudnn_dir, "libcudnn.dylib", false, {}, mac_warn_meg); -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) - std::string win_warn_meg( - "Note: [Recommend] copy cudnn into CUDA installation directory. \n " - "For instance, download cudnn-10.0-windows10-x64-v7.6.5.32.zip from " - "NVIDIA's official website, \n" - "then, unzip it and copy it into C:\\Program Files\\NVIDIA GPU Computing " - "Toolkit\\CUDA\\v10.0\n" - "You should do this according to your CUDA installation directory and " - "CUDNN version."); - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12030) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, "cudnn64_8.dll", true, {cuda_lib_path}, win_warn_meg); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cudnn_lib, true, {cuda_lib_path}, win_warn_meg); -#endif - } else if (CUDA_VERSION >= 12030) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, "cudnn64_9.dll", true, {cuda_lib_path}, win_warn_meg); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cudnn_lib, true, {cuda_lib_path}, win_warn_meg); -#endif - } -#elif defined(PADDLE_WITH_HIP) - return GetDsoHandleFromSearchPath(FLAGS_miopen_dir, "libMIOpen.so", false); -#else -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - if (CUDA_VERSION >= 12030) { - return GetDsoHandleFromSearchPath( - FLAGS_cudnn_dir, "libcudnn.so.9", false, {cuda_lib_path}); - } else { - return GetDsoHandleFromSearchPath( - FLAGS_cudnn_dir, "libcudnn.so.8", false, {cuda_lib_path}); - } -#else - return GetDsoHandleFromSearchPath( - FLAGS_cudnn_dir, "libmcdnn.so", false, {cuda_lib_path}); -#endif -#endif -} - -void* GetCUPTIDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath( - FLAGS_cupti_dir, "libcupti.dylib", false, {cupti_lib_path}); -#elif defined(__linux__) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath( - FLAGS_cupti_dir, "libcupti.so.11.8", false, {cupti_lib_path}); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cupti_dir, "libmcpti.so", false, {cupti_lib_path}); -#endif - - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath( - FLAGS_cupti_dir, "libcupti.so.12", false, {cupti_lib_path}); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cupti_dir, "libmcpti.so", false, {cupti_lib_path}); -#endif - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 13, paddle " - "temporarily no longer supports"); - return nullptr; - } -#else - return GetDsoHandleFromSearchPath( - FLAGS_cupti_dir, "libmcpti.so", false, {cupti_lib_path}); -#endif -} - -void* GetCurandDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.dylib"); -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, "curand64_10.dll", true, {cuda_lib_path}); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_curand_lib, true, {cuda_lib_path}); -#endif -#elif defined(PADDLE_WITH_HIP) - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhiprand.so"); -#else -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_curand_dir, "libcurand.so.10"); -#else - return GetDsoHandleFromSearchPath(FLAGS_curand_dir, "libcurand.so"); -#endif - -#endif -} - -#ifdef PADDLE_WITH_HIP -void* GetROCFFTDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "librocfft.dylib"); -#else - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhipfft.so"); -#endif -} -#endif - -void* GetNvjpegDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvjpeg.dylib"); -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_nvjpeg_lib, true, {cuda_lib_path}); -#else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvjpeg.so"); -#endif -} - -void* GetCusolverDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcusolver.dylib"); -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, "cusolver64_11.dll", true, {cuda_lib_path}); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cusolver_lib, true, {cuda_lib_path}); -#endif -#elif defined(PADDLE_WITH_HIP) - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "librocsolver.so"); -#else -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcusolver.so.11"); -#else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcsolver.so"); -#endif -#endif -} - -void* GetCusparseDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcusparse.dylib"); -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "cusparse64_11.dll"); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cusparse_lib, true, {cuda_lib_path}); -#endif - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "cusparse64_12.dll"); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cusparse_lib, true, {cuda_lib_path}); -#endif - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 13, paddle " - "temporarily no longer supports"); - return nullptr; - } -#elif defined(__linux__) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cusparse_dir, "libcusparse.so.11"); -#else - return GetDsoHandleFromSearchPath(FLAGS_cusparse_dir, "libmcsparse.so"); -#endif - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cusparse_dir, "libcusparse.so.12"); -#else - return GetDsoHandleFromSearchPath(FLAGS_cusparse_dir, "libmcsparse.so"); -#endif - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 12, paddle " - "temporarily no longer."); - return nullptr; - } -#elif defined(PADDLE_WITH_HIP) - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "librocsparse.so"); -#else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcsparse.so"); -#endif -} - -void* GetNVRTCDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvrtc.dylib", false); -#elif defined(PADDLE_WITH_HIP) - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libamdhip64.so", false); -#else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcruntime.so", false); -#endif -} - -void* GetCUDADsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.dylib", false); -#elif defined(PADDLE_WITH_HIP) - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libamdhip64.so", false); -#elif defined(_WIN32) - char system32_dir[MAX_PATH]; - GetSystemDirectory(system32_dir, MAX_PATH); - return GetDsoHandleFromSearchPath(system32_dir, "nvcuda.dll"); -#else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcruntime.so", false); -#endif -} - -void* GetWarpCTCDsoHandle() { - std::string warpctc_dir = ""; - if (!s_py_site_pkg_path.path.empty()) { - warpctc_dir = s_py_site_pkg_path.path; - } -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(warpctc_dir, "libwarpctc.dylib"); -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(warpctc_dir, "warpctc.dll"); -#else - return GetDsoHandleFromSearchPath(warpctc_dir, "libwarpctc.so"); -#endif -} - -void* GetWarpRNNTDsoHandle() { - std::string warprnnt_dir = ""; - if (!s_py_site_pkg_path.path.empty()) { - warprnnt_dir = s_py_site_pkg_path.path; - } -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(warprnnt_dir, "libwarprnnt.dylib"); -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(warprnnt_dir, "warprnnt.dll"); -#else - return GetDsoHandleFromSearchPath(warprnnt_dir, "libwarprnnt.so"); -#endif -} - -void* GetFlashAttnDsoHandle() { - std::string flashattn_dir = ""; - if (!s_py_site_pkg_path.path.empty()) { - flashattn_dir = s_py_site_pkg_path.path; - } -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattn.dylib"); -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(flashattn_dir, "flashattn.dll"); -#else - return GetDsoHandleFromSearchPath(flashattn_dir, "libmcFlashAttn.so"); -#endif -} - -void* GetFlashAttnV3DsoHandle() { - std::string flashattn_dir = ""; - if (!s_py_site_pkg_path.path.empty()) { - flashattn_dir = s_py_site_pkg_path.path; - } -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattnv3.dylib"); -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(flashattn_dir, "flashattnv3.dll"); -#else - return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattnv3.so"); -#endif -} - -void* GetAfsApiDsoHandle() { - std::string afsapi_dir = ""; - if (!s_py_site_pkg_path.path.empty()) { - afsapi_dir = s_py_site_pkg_path.path; - } -#if defined(__APPLE__) || defined(__OSX__) || defined(_WIN32) - return NULL; -#else - return GetDsoHandleFromSearchPath(afsapi_dir, "libafs-api-so.so"); -#endif -} - -void* GetNCCLDsoHandle() { -#ifdef PADDLE_WITH_HIP - std::string warning_msg( - "You may need to install 'rccl' from ROCM official website: " - "https://rocmdocs.amd.com/en/latest/Installation_Guide/" - "Installation-Guide.html before install PaddlePaddle."); -#else - std::string warning_msg( - "You may need to install 'nccl2' from NVIDIA official website: " - "https://developer.nvidia.com/nccl/nccl-download " - "before install PaddlePaddle."); -#endif - -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath( - FLAGS_nccl_dir, "libnccl.dylib", true, {}, warning_msg); -#elif defined(PADDLE_WITH_HIP) && defined(PADDLE_WITH_RCCL) - return GetDsoHandleFromSearchPath( - FLAGS_rccl_dir, "librccl.so", true, {}, warning_msg); -#else -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath( - FLAGS_nccl_dir, "libnccl.so;libnccl.so.2", true, {}, warning_msg); -#else - return GetDsoHandleFromSearchPath( - FLAGS_nccl_dir, "libmccl.so", true, {}, warning_msg); -#endif - -#endif -} - -// void* GetFLAGCXDsoHandle() { -// #ifdef PADDLE_WITH_FLAGCX -// return GetDsoHandleFromSearchPath(FLAGS_flagcx_dir, "libflagcx.so"); -// #else -// return nullptr; -// #endif -// } - -void* GetTensorRtDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_tensorrt_dir, "libnvinfer.dylib"); -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(FLAGS_mklml_dir, "nvinfer.dll"); -#else - return GetDsoHandleFromSearchPath(FLAGS_tensorrt_dir, "libnvinfer.so"); -#endif -} - -void* GetMKLMLDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_mklml_dir, "libmklml_intel.dylib"); -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(FLAGS_mklml_dir, "mklml.dll"); -#else - return GetDsoHandleFromSearchPath(FLAGS_mklml_dir, "libmklml_intel.so"); -#endif -} - -void* GetLAPACKDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) -#if defined(__arm__) || defined(__aarch64__) - return GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapack.dylib"); -#else - return GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapack.3.dylib"); -#endif -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapack.dll"); -#else - return GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapack.so.3"); -#endif -} - -void* GetOpDsoHandle(const std::string& dso_name) { - return GetDsoHandleFromSearchPath(FLAGS_op_dir, dso_name); -} - -void* GetNvtxDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - PADDLE_THROW(common::errors::Unimplemented("Nvtx do not support Apple.")); -#elif defined(_WIN32) - PADDLE_THROW(common::errors::Unimplemented("Nvtx do not support Windows.")); -#elif !defined(PADDLE_WITH_CUDA) - PADDLE_THROW( - common::errors::Unimplemented("Nvtx do not support without CUDA.")); -#else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvToolsExt.so"); -#endif -} - -void* GetCUFFTDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.dylib"); -#elif defined(__linux__) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.so.10"); -#else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.so"); -#endif - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.so.11"); - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 13, paddle " - "temporarily no longer."); - return nullptr; - } -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "cufft64_10.dll"); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cufft_lib, true, {cuda_lib_path}); -#endif - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "cufft64_11.dll"); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cufft_lib, true, {cuda_lib_path}); -#endif - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 13, paddle " - "temporarily no longer supports"); - return nullptr; - } -#else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.so"); -#endif -} - -void* GetMKLRTDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "libmkl_rt.dylib"); -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "mkl_rt.dll"); -#else - return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "libmkl_rt.so"); -#endif -} - -void* GetCusparseLtDsoHandle() { -// APIs available after CUDA 11.2 -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11020 && 0 - return GetDsoHandleFromSearchPath(FLAGS_cusparselt_dir, "libcusparseLt.so"); -#else - std::string warning_msg( - "Your CUDA_VERSION less 11.2, not support cusparseLt. " - "If you want to use cusparseLt, please upgrade CUDA and rebuild " - "PaddlePaddle."); - return nullptr; -#endif -} - -void* GetXPTIDsoHandle() { -#ifdef PADDLE_WITH_XPTI - return GetDsoHandleFromSearchPath(FLAGS_xpti_dir, "libxpti.so"); -#else - return nullptr; -#endif -} -} // namespace phi::dynload diff --git a/backends/metax_gpu/kernels/dynload/dynamic_loader.h b/backends/metax_gpu/kernels/dynload/dynamic_loader.h deleted file mode 100644 index a5d3d0ff76c..00000000000 --- a/backends/metax_gpu/kernels/dynload/dynamic_loader.h +++ /dev/null @@ -1,61 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include - -#include "paddle/utils/test_macros.h" -namespace phi { -namespace dynload { - -#ifndef _WIN32 -#define DECLARE_TYPE(__name, ...) decltype(__name(__VA_ARGS__)) -#else -#define DECLARE_TYPE(__name, ...) decltype(auto) -#endif - -void* GetCublasDsoHandle(); -void* GetCublasLtDsoHandle(); -TEST_API void* GetCUDNNDsoHandle(); -void* GetCUPTIDsoHandle(); -void* GetCurandDsoHandle(); -void* GetNvjpegDsoHandle(); -void* GetCusolverDsoHandle(); -void* GetCusparseDsoHandle(); -void* GetNVRTCDsoHandle(); -void* GetCUDADsoHandle(); -void* GetWarpCTCDsoHandle(); -void* GetWarpRNNTDsoHandle(); -void* GetFlashAttnDsoHandle(); -void* GetFlashAttnV3DsoHandle(); -void* GetNCCLDsoHandle(); -// void* GetFLAGCXDsoHandle(); -void* GetTensorRtDsoHandle(); -void* GetMKLMLDsoHandle(); -void* GetLAPACKDsoHandle(); -void* GetOpDsoHandle(const std::string& dso_name); -void* GetNvtxDsoHandle(); -void* GetCUFFTDsoHandle(); -void* GetMKLRTDsoHandle(); -void* GetROCFFTDsoHandle(); -void* GetCusparseLtDsoHandle(); -void* GetXPTIDsoHandle(); -void* GetAfsApiDsoHandle(); - -void SetPaddleLibPath(const std::string&); - -} // namespace dynload -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/affine_grid_utils.h b/backends/metax_gpu/kernels/funcs/affine_grid_utils.h index c137d9ad468..b973d75a9be 100644 --- a/backends/metax_gpu/kernels/funcs/affine_grid_utils.h +++ b/backends/metax_gpu/kernels/funcs/affine_grid_utils.h @@ -14,8 +14,8 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/backends/metax_gpu/kernels/funcs/blas/blas.cc b/backends/metax_gpu/kernels/funcs/blas/blas.cc deleted file mode 100644 index 098a0400552..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/blas.cc +++ /dev/null @@ -1,59 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// clang-format off -#include "funcs/blas/blas.h" // NOLINT -#include "paddle/phi/core/enforce.h" -// clang-format on -namespace phi { -namespace funcs { -MatDescriptor CreateMatrixDescriptor(const DDim &tensor_dim, - int num_flatten_cols, - bool trans) { - PADDLE_ENFORCE_GT( - tensor_dim.size(), - 1, - phi::errors::InvalidArgument("The tensor dim size should be greater " - "than 1, but reveived dim size is %d", - tensor_dim.size())); - MatDescriptor retv; - if (num_flatten_cols > 1) { - auto flatten_dim = common::flatten_to_2d(tensor_dim, num_flatten_cols); - retv.height_ = flatten_dim[0]; - retv.width_ = flatten_dim[1]; - } else { - if (tensor_dim.size() == 2) { - retv.height_ = tensor_dim[0]; - retv.width_ = tensor_dim[1]; - } else { - auto dim_vec = common::vectorize(tensor_dim); - retv.batch_size_ = 1; - for (size_t i = 0; i < dim_vec.size() - 2; ++i) { - retv.batch_size_ *= dim_vec[i]; - } - retv.height_ = dim_vec[dim_vec.size() - 2]; - retv.width_ = dim_vec[dim_vec.size() - 1]; - retv.stride_ = retv.height_ * retv.width_; - } - } - if (trans) { - std::swap(retv.width_, retv.height_); - } - retv.trans_ = trans; - return retv; -} -} // namespace funcs -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/blas.h b/backends/metax_gpu/kernels/funcs/blas/blas.h deleted file mode 100644 index 75ea8c921e2..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/blas.h +++ /dev/null @@ -1,631 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" - -#ifdef PADDLE_WITH_MKLML -#include "paddle/phi/backends/dynload/mklml.h" -#endif - -#ifdef PADDLE_WITH_LIBXSMM -#include -#endif - -#if defined(PADDLE_USE_OPENBLAS) || defined(PADDLE_USE_REFERENCE_CBLAS) -#include -#endif -// #include "paddle/phi/core/enforce_metax.h" -namespace phi { -namespace funcs { - -/** - * Matrix Descriptor of a memory buffer. - * - * It is used for Blas::MatMul. MatMul operator can be batched. - * if Mat A is [BatchSize, H, W], Mat B is [BatchSize, H, W]. It will be a - * `batch_size` times of GEMM. The batched GEMM could be faster base on the - * implementation of the blas library. The batch size could be zero. If any - * matrix of `matmul` has a batch size, there will be a batched GEMM, too. e.g., - * Mat A is [BatchSize, H1, W2], and Mat B [H2, W2], The result matrix wil be - * [BatchSize, H1, W2] - * - * The boolean flag, `trans`, describe the memory is the transpose of matrix or - * not. If the trans is true, the last two dims of matrix are transposed. The - * memory layout of the matrix is [Width, Height] or [BatchSize, Width, Height]. - * - * The MatDescriptor is not only the dimension or shape of a matrix, it also - * contains the layout, stride of matrix. It is clearer to have a structure than - * reuse `DDim`. - */ -struct MatDescriptor { - int64_t height_; - int64_t width_; - int64_t stride_{0}; - int64_t batch_size_{0}; - bool trans_; -}; - -/** - * Create Matrix Descriptor from a tensor dim, num_flatten_cols, and transpose - * flag - * - * @param tensor_dim: The dimension of the tensor. The rank of this dimension - * must larger than 1. - * - * @param num_flatten_cols: Reshape a tensor to a matrix. The matrix's first - * dimension(column length) will be the product of tensor's first `num_col_dims` - * dimensions. If num_flatten_cols is zero, the first N-2 dimension will be the - * batch_size of descriptor. - * - * @param trans: True if the matrix is transposed. - */ -extern MatDescriptor CreateMatrixDescriptor(const DDim& tensor_dim, - int num_flatten_cols, - bool trans); - -template -class Blas { - public: - explicit Blas(const DeviceContext& context) : dev_ctx_(context) {} - - template - void GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - T alpha, - const T* A, - const T* B, - T beta, - T* C) const; - - template - void GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - U alpha, - const T* A, - const T* B, - U beta, - T* C) const; - - template - void GEMM(bool transA, - bool transB, - int M, - int N, - int K, - T alpha, - const T* A, - int lda, - const T* B, - int ldb, - T beta, - T* C, - int ldc) const; - - template - void GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - T alpha, - const T* A, - int lda, - const T* B, - int ldb, - T beta, - T* C, - int ldc) const; - -#ifdef PADDLE_WITH_MKLML // @{ Group MKLML: class Blas - template - T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, - const int M, - const int N, - const int K) const; - - template - void GEMM_PACK(const CBLAS_IDENTIFIER id, - const CBLAS_TRANSPOSE trans, - int M, - int N, - int K, - const T alpha, - const T* src, - const int ld, - T* dst) const; - - template - void GEMM_COMPUTE(int transA, - int transB, - int M, - int N, - int K, - const T* A, - const int lda, - const T* B, - const int ldb, - T beta, - T* C, - const int ldc) const; - - template - void GEMM_FREE(T* data) const; - - template - void CSRMM(const char* transa, - const int* m, - const int* n, - const int* k, - const T* alpha, - const char* matdescra, - const T* val, - const int* indx, - const int* pntrb, - const int* pntre, - const T* b, - const int* ldb, - const T* beta, - T* c, - const int* ldc) const; - -#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) - template - void MatMulWithHead(const phi::DenseTensor& mat_a, - const MatDescriptor& dim_a, - const phi::DenseTensor& mat_b, - const MatDescriptor& dim_b, - T alpha, - int head_number, - phi::DenseTensor* mat_out, - T beta, - bool mat_y_split_vertical) const; -#endif -#endif // @} End Group MKLML: class Blas - - template - void MatMul(const int M, - const int N, - const int K, - const T* A, - const T* B, - T* C) const; - - template - void MatMul(const phi::DenseTensor& mat_a, - bool trans_a, - const phi::DenseTensor& mat_b, - bool trans_b, - T alpha, - phi::DenseTensor* mat_out, - T beta) const; - - template - void MatMul(const phi::DenseTensor& mat_a, - bool trans_a, - const phi::DenseTensor& mat_b, - bool trans_b, - phi::DenseTensor* mat_out) const { - MatMul(mat_a, - trans_a, - mat_b, - trans_b, - static_cast(1.0), - mat_out, - static_cast(0.0)); - } - - template - void MatMul(const phi::DenseTensor& mat_a, - const phi::DenseTensor& mat_b, - phi::DenseTensor* mat_out) const { - this->template MatMul(mat_a, false, mat_b, false, mat_out); - } - - template - void AXPY(int n, T alpha, const T* x, T* y) const; - - template - void VADD(int n, const T* x, const T* y, T* z) const; - - template - void VSUB(int n, const T* x, const T* y, T* z) const; - - template - void VMUL(int n, const T* x, const T* y, T* z) const; - - template - void VDIV(int n, const T* x, const T* y, T* z) const; - - template - void VCOPY(int n, const T* x, T* y) const; - - template - void VEXP(int n, const T* x, T* y) const; - - template - void VSQUARE(int n, const T* x, T* y) const; - - template - void VPOW(int n, const T* x, T alpha, T* y) const; - - template - void GEMV(bool trans_a, - int M, - int N, - T alpha, - const T* A, - const T* B, - T beta, - T* C) const; - - template - T DOT(int n, const T* x, const T* y) const; - - template - void CUDOT( - int n, const T* x, int incx, const T* y, int incy, T* result) const; - template - void SCAL(int n, const T a, T* x) const; - - template - T ASUM(int n, T* x, int inc) const; - - template - void BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - T alpha, - const T* A, - const T* B, - T beta, - T* C, - int64_t batchCount, - int64_t strideA, - int64_t strideB) const; - - template - void BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - U alpha, - const T* A, - const T* B, - U beta, - T* C, - int64_t batchCount, - int64_t strideA, - int64_t strideB) const; - - template - void BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - T alpha, - const T** A, - const T** B, - T beta, - T** C, - int batchCount) const; - -#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ - !defined(PADDLE_WITH_HIP) - template - void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int W1, - int H1, - int W2, - int H2, - T alpha, - const T* A, - const T* B, - T beta, - T* C, - int batchCount, - int64_t strideA, - int64_t strideB, - int64_t head_number, - bool split_b_vertical) const; -#endif - - template - void MatMul(const phi::DenseTensor& mat_a, - const MatDescriptor& dim_a, - const phi::DenseTensor& mat_b, - const MatDescriptor& dim_b, - T alpha, - phi::DenseTensor* mat_out, - T beta) const; - - template - void MatMul(const T* mat_a, - const MatDescriptor& dim_a, - const T* mat_b, - const MatDescriptor& dim_b, - T alpha, - T* mat_out, - T beta) const; - - template - void VINV(int n, const T* a, T* y) const; - - template - void VMERF(int n, const T* a, T* y, int64_t mode) const; - - template - void TRSM(CBLAS_SIDE side, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, - int M, - int N, - T alpha, - const T* A, - int lda, - T* B, - int ldb) const; - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - template - void BatchedGETRF(int n, T** a, int* ipiv, int* info, int batch_size) const; - - template - void BatchedGETRI(int n, - const T** a, - const int* ipiv, - T** a_inv, - int* info, - int batch_size) const; - - template - void BatchedMatInv( - int n, const T** a, T** a_inv, int* info, int batch_size) const; - - // cuBlas solve - template - void BatchedGETRS(CBLAS_TRANSPOSE trans, - int n, - int nrhs, - const T** a, - int lda, - int* ipiv, - T** b, - int ldb, - int* info, - int batch_size) const; - - // cuBlas triangular_solve - template - void BatchedTRSM(CBLAS_SIDE side, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, - int M, - int N, - T alpha, - const T** a, - int lda, - T** b, - int ldb, - int batch_size) const; -#endif - - private: - const DeviceContext& dev_ctx_; -}; - -template -class BlasT : private Blas { - public: - using Blas::Blas; - - template - void GEMM(ARGS... args) const { - Base()->template GEMM(args...); - } - -#ifdef PADDLE_WITH_MKLML // @{ Group MKLML: class BlasT - template - T* GEMM_ALLOC(ARGS... args) const { - return Base()->template GEMM_ALLOC(args...); - } - - template - void GEMM_PACK(ARGS... args) const { - Base()->template GEMM_PACK(args...); - } - - template - void GEMM_COMPUTE(ARGS... args) const { - Base()->template GEMM_COMPUTE(args...); - } - - template - void GEMM_FREE(ARGS... args) const { - Base()->template GEMM_FREE(args...); - } - - template - void CSRMM(ARGS... args) const { - Base()->template CSRMM(args...); - } - -#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) - template - void MatMulWithHead(ARGS... args) const { - Base()->template MatMulWithHead(args...); - } -#endif -#endif // @} End Group MKLML: class BlasT - - template - void MatMul(ARGS... args) const { - Base()->template MatMul(args...); - } - - template - void AXPY(ARGS... args) const { - Base()->template AXPY(args...); - } - - template - void VADD(ARGS... args) const { - Base()->template VADD(args...); - } - - template - void VSUB(ARGS... args) const { - Base()->template VSUB(args...); - } - - template - void VMUL(ARGS... args) const { - Base()->template VMUL(args...); - } - - template - void VDIV(ARGS... args) const { - Base()->template VDIV(args...); - } - - template - void VCOPY(ARGS... args) const { - Base()->template VCOPY(args...); - } - - template - void VEXP(ARGS... args) const { - Base()->template VEXP(args...); - } - - template - void VSQUARE(ARGS... args) const { - Base()->template VSQUARE(args...); - } - - template - void VPOW(ARGS... args) const { - Base()->template VPOW(args...); - } - - template - void GEMV(ARGS... args) const { - Base()->template GEMV(args...); - } - - template - T DOT(ARGS... args) const { - return Base()->template DOT(args...); - } - template - void CUDOT(ARGS... args) const { - Base()->template CUDOT(args...); - } - template - void SCAL(ARGS... args) const { - Base()->template SCAL(args...); - } - - template - T ASUM(ARGS... args) const { - return Base()->template ASUM(args...); - } - - template - void BatchedGEMM(ARGS... args) const { - Base()->template BatchedGEMM(args...); - } - - template - void VINV(ARGS... args) const { - Base()->template VINV(args...); - } - - template - void VMERF(ARGS... args) const { - Base()->template VMERF(args...); - } - - template - void TRSM(ARGS... args) const { - Base()->template TRSM(args...); - } - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - template - void BatchedGETRF(ARGS... args) const { - Base()->template BatchedGETRF(args...); - } - - template - void BatchedGETRI(ARGS... args) const { - Base()->template BatchedGETRI(args...); - } - - template - void BatchedMatInv(ARGS... args) const { - Base()->template BatchedMatInv(args...); - } - - // solve - template - void BatchedGETRS(ARGS... args) const { - Base()->template BatchedGETRS(args...); - } - - // triangular_solve - template - void BatchedTRSM(ARGS... args) const { - Base()->template BatchedTRSM(args...); - } -#endif - - private: - const Blas* Base() const { - return static_cast*>(this); - } -}; - -template -inline BlasT GetBlas(const DeviceContext& dev_ctx) { - return BlasT(dev_ctx); -} - -} // namespace funcs -} // namespace phi -// clang-format off -#include "./blas_impl.h" -#ifdef PADDLE_WITH_CUDA -#include "./blas_impl.cu.h" -#endif -#ifdef PADDLE_WITH_HIP -#include "paddle/phi/kernels/funcs/blas/blas_impl.hip.h" -#endif -// clang-format on diff --git a/backends/metax_gpu/kernels/funcs/blas/blas_impl.cu.h b/backends/metax_gpu/kernels/funcs/blas/blas_impl.cu.h deleted file mode 100644 index ae4baa52613..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/blas_impl.cu.h +++ /dev/null @@ -1,3027 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#if defined(__NVCC__) -#include -#endif -#include "./cublas.h" -#include "glog/logging.h" -#include "paddle/common/flags.h" -#include "paddle/phi/backends/gpu/gpu_context.h" -// #include "paddle/phi/core/flags.h" -#include "paddle/phi/core/enforce.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -#define INT_MAX_VALUE 2147483647 - -PHI_DECLARE_bool(enable_cublas_tensor_op_math); -PHI_DECLARE_bool(gemm_use_half_precision_compute_type); - -namespace phi { -namespace funcs { -template -struct CUBlas; - -template <> -struct CUBlas { - template - static void GEMM(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemm(args...)); - } - - template - static void AXPY(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSaxpy(args...)); - } - - template - static void SCAL(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSscal(args...)); - } - - template - static void VCOPY(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasScopy(args...)); - } - - template - static void GEMV(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemv(args...)); - } - - template - static void GEMM_BATCH(ARGS... args) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemmBatched(args...)); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "SgemmBatched is not supported on cuda <= 7.5")); -#endif - } - - template - static void GEMM_STRIDED_BATCH(ARGS... args) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasSgemmStridedBatched(args...)); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "SgemmStridedBatched is not supported on cuda <= 7.5")); -#endif - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. - // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(phi::GPUContext *dev_ctx, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, - const void *A, - cudaDataType_t Atype, - int lda, - const void *B, - cudaDataType_t Btype, - int ldb, - const float *beta, - void *C, - cudaDataType_t Ctype, - int ldc) { -// Because the gcc 4.8 doesn't expand template parameter pack that -// appears in a lambda-expression, I can not use template parameter pack -// here. -#if CUDA_VERSION >= 8000 - VLOG(5) << "use_tensor_op_math: " - << (dev_ctx->tensor_core_available() ? "True" : "False"); - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemmEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A, - Atype, - lda, - B, - Btype, - ldb, - beta, - C, - Ctype, - ldc)); - }); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "cublasSgemmEx is not supported on cuda <= 7.5")); -#endif - } - - template - static void TRSM(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasStrsm(args...)); - } - - template - static void GETRF_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetrfBatched(args...)); - } - - template - static void GETRI_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetriBatched(args...)); - } - - template - static void MATINV_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSmatinvBatched(args...)); - } - - template - static void GETRS_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetrsBatched(args...)); - } - - template - static void TRSM_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasStrsmBatched(args...)); - } -}; - -template <> -struct CUBlas { - template - static void GEMM(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemm(args...)); - } - - template - static void AXPY(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDaxpy(args...)); - } - - template - static void SCAL(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDscal(args...)); - } - - template - static void VCOPY(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDcopy(args...)); - } - - template - static void GEMV(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemv(args...)); - } - - template - static void GEMM_BATCH(ARGS... args) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemmBatched(args...)); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "DgemmBatched is not supported on cuda <= 7.5")); -#endif - } - - template - static void GEMM_STRIDED_BATCH(ARGS... args) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasDgemmStridedBatched(args...)); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "DgemmStridedBatched is not supported on cuda <= 7.5")); -#endif - } - - template - static void GEMM_EX(ARGS... args UNUSED) { - PADDLE_THROW( - phi::errors::Unimplemented("Currently there are not cublasDgemmEx.")); - } - - template - static void TRSM(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDtrsm(args...)); - } - - template - static void GETRF_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetrfBatched(args...)); - } - - template - static void GETRI_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetriBatched(args...)); - } - - template - static void MATINV_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDmatinvBatched(args...)); - } - - template - static void GETRS_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetrsBatched(args...)); - } - - template - static void TRSM_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDtrsmBatched(args...)); - } -}; - -template <> -struct CUBlas { - using float16 = phi::dtype::float16; - - static void GEMM(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float16 *alpha, - const float16 *A, - int lda, - const float16 *B, - int ldb, - const float16 *beta, - float16 *C, - int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasHgemm(handle, - transa, - transb, - m, - n, - k, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb, - reinterpret_cast(beta), - reinterpret_cast(C), - ldc)); - } - -#if defined(__NVCC__) - static void GEMM_BATCH(phi::GPUContext *dev_ctx, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, - const float16 **A, - cudaDataType_t Atype, - int lda, - const float16 **B, - cudaDataType_t Btype, - int ldb, - const float *beta, - float16 **C, - cudaDataType_t Ctype, - int ldc, - int batchCount, - cublasComputeType_t computeType) { -#if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -#if CUDA_VERSION >= 9000 - bool use_tensor_op_math = dev_ctx->tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); -#endif // CUDA_VERSION >= 9000 - thrust::device_vector A_ptr(A, A + batchCount); - thrust::device_vector B_ptr(B, B + batchCount); - thrust::device_vector C_ptr(C, C + batchCount); - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmBatchedEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A_ptr.data().get(), - Atype, - lda, - B_ptr.data().get(), - Btype, - ldb, - beta, - C_ptr.data().get(), - Ctype, - ldc, - batchCount, - computeType, - algo)); - }); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "cublasGemmBatchedEx is not supported on cuda <= 7.5")); -#endif - } -#endif - - static void GEMM_STRIDED_BATCH(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float16 *alpha, - const float16 *A, - int lda, - long long int strideA, // NOLINT - const float16 *B, // NOLINT - int ldb, - long long int strideB, // NOLINT - const float16 *beta, - float16 *C, - int ldc, - long long int strideC, // NOLINT - int batchCount) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasHgemmStridedBatched( - handle, - transa, - transb, - m, - n, - k, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - strideA, - reinterpret_cast(B), - ldb, - strideB, - reinterpret_cast(beta), - reinterpret_cast(C), - ldc, - strideC, - batchCount)); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "HgemmStridedBatched is not supported on cuda <= 7.5")); -#endif - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. - // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(phi::GPUContext *dev_ctx, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const void *alpha, - const void *A, - cudaDataType_t Atype, - int lda, - const void *B, - cudaDataType_t Btype, - int ldb, - const void *beta, - void *C, - cudaDataType_t Ctype, - int ldc, - cublasComputeType_t computeType) { -#if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -#if CUDA_VERSION >= 9000 - bool use_tensor_op_math = dev_ctx->tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); -#endif // CUDA_VERSION >= 9000 - - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A, - Atype, - lda, - B, - Btype, - ldb, - beta, - C, - Ctype, - ldc, - computeType, - algo)); - }); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "cublasGemmEx is not supported on cuda <= 7.5")); -#endif - } -}; - -template <> -struct CUBlas> { - static void GEMV(cublasHandle_t handle, - cublasOperation_t transa, - int m, - int n, - const phi::dtype::complex *alpha, - const phi::dtype::complex *A, - int lda, - const phi::dtype::complex *B, - int ldb, - const phi::dtype::complex *beta, - phi::dtype::complex *C, - int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemv( - handle, - transa, - m, - n, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb, - reinterpret_cast(beta), - reinterpret_cast(C), - ldc)); - } - - static void AXPY(cublasHandle_t handle, - int n, - const phi::dtype::complex *alpha, - const phi::dtype::complex *X, - const int incX, - phi::dtype::complex *Y, - const int incY) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCaxpy( - handle, - n, - reinterpret_cast(alpha), - reinterpret_cast(X), - incX, - reinterpret_cast(Y), - incY)); - } - - static void GEMM_STRIDED_BATCH(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const phi::dtype::complex *alpha, - const phi::dtype::complex *A, - int lda, - long long int strideA, // NOLINT - const phi::dtype::complex *B, // NOLINT - int ldb, - long long int strideB, // NOLINT - const phi::dtype::complex *beta, - phi::dtype::complex *C, - int ldc, - long long int strideC, // NOLINT - int batchCount) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemmStridedBatched( - handle, - transa, - transb, - m, - n, - k, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - strideA, - reinterpret_cast(B), - ldb, - strideB, - reinterpret_cast(beta), - reinterpret_cast(C), - ldc, - strideC, - batchCount)); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "CgemmStridedBatched is not supported on cuda <= 7.5")); -#endif - } - - static void GEMM(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const phi::dtype::complex *alpha, - const phi::dtype::complex *A, - int lda, - const phi::dtype::complex *B, - int ldb, - const phi::dtype::complex *beta, - phi::dtype::complex *C, - int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemm( - handle, - transa, - transb, - m, - n, - k, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb, - reinterpret_cast(beta), - reinterpret_cast(C), - ldc)); - } - - static void TRSM(cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t transa, - cublasDiagType_t diag, - int m, - int n, - const phi::dtype::complex *alpha, - const phi::dtype::complex *A, - int lda, - phi::dtype::complex *B, - int ldb) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCtrsm( - handle, - side, - uplo, - transa, - diag, - m, - n, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb)); - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. - // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(phi::GPUContext *dev_ctx, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const void *alpha, - const void *A, - cudaDataType_t Atype, - int lda, - const void *B, - cudaDataType_t Btype, - int ldb, - const void *beta, - void *C, - cudaDataType_t Ctype, - int ldc, - cublasComputeType_t computeType) { -#if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -#if CUDA_VERSION >= 9000 - bool use_tensor_op_math = dev_ctx->tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); -#endif // CUDA_VERSION >= 9000 - - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A, - Atype, - lda, - B, - Btype, - ldb, - beta, - C, - Ctype, - ldc, - computeType, - algo)); - }); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "cublasGemmEx is not supported on cuda <= 7.5")); -#endif - } - - static void TRSM_BATCH(cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t transa, - cublasDiagType_t diag, - int m, - int n, - const phi::dtype::complex *alpha, - const phi::dtype::complex **A, - int lda, - phi::dtype::complex **B, - int ldb, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCtrsmBatched( - handle, - side, - uplo, - transa, - diag, - m, - n, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb, - batch_size)); - } - // ****************************************************************新增模版定义********************* - - static void GETRF_BATCH(cublasHandle_t handle, - int n, - phi::dtype::complex **A, - int lda, - int *ipiv, - int *info, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetrfBatched( - handle, - n, - reinterpret_cast(A), - lda, - ipiv, - info, - batch_size)); - } - - static void GETRI_BATCH(cublasHandle_t handle, - int n, - const phi::dtype::complex **A, - int lda, - const int *ipiv, - phi::dtype::complex **Ainv, - int ldc, - int *info, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetriBatched( - handle, - n, - reinterpret_cast(A), - lda, - ipiv, - reinterpret_cast(Ainv), - ldc, - info, - batch_size)); - } - - static void MATINV_BATCH(cublasHandle_t handle, - int n, - const phi::dtype::complex **A, - int lda, - phi::dtype::complex **Ainv, - int lda_inv, - int *info, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCmatinvBatched( - handle, - n, - reinterpret_cast(A), - lda, - reinterpret_cast(Ainv), - lda_inv, - info, - batch_size)); - } - // ****************************************************************新增模版定义********************* -}; - -template <> -struct CUBlas> { - static void GEMV(cublasHandle_t handle, - cublasOperation_t transa, - int m, - int n, - const phi::dtype::complex *alpha, - const phi::dtype::complex *A, - int lda, - const phi::dtype::complex *B, - int ldb, - const phi::dtype::complex *beta, - phi::dtype::complex *C, - int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemv( - handle, - transa, - m, - n, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb, - reinterpret_cast(beta), - reinterpret_cast(C), - ldc)); - } - - static void AXPY(cublasHandle_t handle, - int n, - const phi::dtype::complex *alpha, - const phi::dtype::complex *X, - const int incX, - phi::dtype::complex *Y, - const int incY) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZaxpy( - handle, - n, - reinterpret_cast(alpha), - reinterpret_cast(X), - incX, - reinterpret_cast(Y), - incY)); - } - - static void GEMM_STRIDED_BATCH( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const phi::dtype::complex *alpha, - const phi::dtype::complex *A, - int lda, - long long int strideA, // NOLINT - const phi::dtype::complex *B, // NOLINT - int ldb, - long long int strideB, // NOLINT - const phi::dtype::complex *beta, - phi::dtype::complex *C, - int ldc, - long long int strideC, // NOLINT - int batchCount) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemmStridedBatched( - handle, - transa, - transb, - m, - n, - k, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - strideA, - reinterpret_cast(B), - ldb, - strideB, - reinterpret_cast(beta), - reinterpret_cast(C), - ldc, - strideC, - batchCount)); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "CgemmStridedBatched is not supported on cuda <= 7.5")); -#endif - } - - static void GEMM(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const phi::dtype::complex *alpha, - const phi::dtype::complex *A, - int lda, - const phi::dtype::complex *B, - int ldb, - const phi::dtype::complex *beta, - phi::dtype::complex *C, - int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemm( - handle, - transa, - transb, - m, - n, - k, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb, - reinterpret_cast(beta), - reinterpret_cast(C), - ldc)); - } - - static void TRSM(cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t transa, - cublasDiagType_t diag, - int m, - int n, - const phi::dtype::complex *alpha, - const phi::dtype::complex *A, - int lda, - phi::dtype::complex *B, - int ldb) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZtrsm( - handle, - side, - uplo, - transa, - diag, - m, - n, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb)); - } - - static void TRSM_BATCH(cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t transa, - cublasDiagType_t diag, - int m, - int n, - const phi::dtype::complex *alpha, - const phi::dtype::complex **A, - int lda, - phi::dtype::complex **B, - int ldb, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZtrsmBatched( - handle, - side, - uplo, - transa, - diag, - m, - n, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb, - batch_size)); - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. - // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(phi::GPUContext *dev_ctx, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const void *alpha, - const void *A, - cudaDataType_t Atype, - int lda, - const void *B, - cudaDataType_t Btype, - int ldb, - const void *beta, - void *C, - cudaDataType_t Ctype, - int ldc, - cublasComputeType_t computeType) { -#if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -#if CUDA_VERSION >= 9000 - bool use_tensor_op_math = dev_ctx->tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); -#endif // CUDA_VERSION >= 9000 - - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A, - Atype, - lda, - B, - Btype, - ldb, - beta, - C, - Ctype, - ldc, - computeType, - algo)); - }); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "cublasGemmEx is not supported on cuda <= 7.5")); -#endif - } - // &*******************************************新增模版定义************************* - static void GETRF_BATCH(cublasHandle_t handle, - int n, - phi::dtype::complex **A, - int lda, - int *ipiv, - int *info, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetrfBatched( - handle, - n, - reinterpret_cast(A), - lda, - ipiv, - info, - batch_size)); - } - - static void GETRI_BATCH(cublasHandle_t handle, - int n, - const phi::dtype::complex **A, - int lda, - const int *ipiv, - phi::dtype::complex **Ainv, - int ldc, - int *info, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetriBatched( - handle, - n, - reinterpret_cast(A), - lda, - ipiv, - reinterpret_cast(Ainv), - ldc, - info, - batch_size)); - } - - static void MATINV_BATCH(cublasHandle_t handle, - int n, - const phi::dtype::complex **A, - int lda, - phi::dtype::complex **Ainv, - int lda_inv, - int *info, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZmatinvBatched( - handle, - n, - reinterpret_cast(A), - lda, - reinterpret_cast(Ainv), - lda_inv, - info, - batch_size)); - } - // &*******************************************新增模版定义************************* -}; - -inline void CheckGEMMNSize(int64_t N) { - constexpr int64_t kMaxN = 1073741823; - if (N > kMaxN) { - PADDLE_THROW(common::errors::Unimplemented( - "cublas GEMM does not support N > %ld. Got N = %ld. ", kMaxN, N)); - } -} - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - T alpha, - const T *A, - const T *B, - T beta, - T *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - -#if CUDA_VERSION >= 8000 - if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { - auto &cuda_ctx = const_cast(dev_ctx_); - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented( - "CUBlas::GEMM_EX_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "GEMM_EX_64 is not supported on cuda < 12.3")); -#endif - } else { - CheckGEMMNSize(N); - CUBlas::GEMM_EX(&cuda_ctx, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - CUDA_R_32F, - ldb, - A, - CUDA_R_32F, - lda, - &beta, - C, - CUDA_R_32F, - N); - } - } else { -#endif // CUDA_VERSION >= 8000 - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { - PADDLE_THROW(common::errors::Unimplemented( - "GEMM_EX_64 is not supported on cuda < 12.3")); - } else { - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - ldb, - A, - lda, - &beta, - C, - N); - }); - } - -#if CUDA_VERSION >= 8000 - } -#endif // CUDA_VERSION >= 8000 -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - phi::dtype::float16 alpha, - const phi::dtype::float16 *A, - const phi::dtype::float16 *B, - phi::dtype::float16 beta, - phi::dtype::float16 *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - dev_ctx_.GetComputeCapability(), - 53, - phi::errors::InvalidArgument( - "cublas fp16 gemm requires GPU compute capability >= 53," - "but received %d", - dev_ctx_.GetComputeCapability())); - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - -#if CUDA_VERSION >= 8000 - // cublasHgemm does true FP16 computation which is slow for non-Volta - // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: - // input/output in fp16, computation in fp32, which can also be accelerated - // using tensor cores in volta GPUs. - auto &cuda_ctx = const_cast(dev_ctx_); - CUBlas::GEMM_EX(&cuda_ctx, - cuTransB, - cuTransA, - N, - M, - K, - &h_alpha, - B, - CUDA_R_16F, - ldb, - A, - CUDA_R_16F, - lda, - &h_beta, - C, - CUDA_R_16F, - N, - CUBLAS_COMPUTE_32F); -#else - // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - N, - M, - K, - &h_alpha, - h_B, - ldb, - h_A, - lda, - &h_beta, - h_C, - N); - }); -#endif // CUDA_VERSION >= 8000 -} - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - U alpha, - const T *A, - const T *B, - U beta, - T *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - T t_alpha = static_cast(alpha); - T t_beta = static_cast(beta); - -#if CUDA_VERSION >= 8000 - if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { - auto &cuda_ctx = const_cast(dev_ctx_); - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented("GEMM_EX_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "GEMM_EX_64 is not supported on cuda < 12.3")); -#endif - } else { - CheckGEMMNSize(N); - CUBlas::GEMM_EX(&cuda_ctx, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &t_alpha, - B, - CUDA_R_32F, - static_cast(ldb), - A, - CUDA_R_32F, - static_cast(lda), - &t_beta, - C, - CUDA_R_32F, - static_cast(N)); - } - } else { -#endif // CUDA_VERSION >= 8000 - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { - PADDLE_THROW(common::errors::Unimplemented( - "GEMM_EX_64 is not supported on cuda < 12.3")); - } else { - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &t_alpha, - B, - static_cast(ldb), - A, - static_cast(lda), - &t_beta, - C, - static_cast(N)); - }); - } - -#if CUDA_VERSION >= 8000 - } -#endif // CUDA_VERSION >= 8000 -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - float alpha, - const phi::dtype::float16 *A, - const phi::dtype::float16 *B, - float beta, - phi::dtype::float16 *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - // PADDLE_ENFORCE_GE( - // dev_ctx_.GetComputeCapability(), - // 53, - // common::errors::InvalidArgument( - // "cublas fp16 gemm requires GPU compute capability >= 53," - // "but received %d", - // dev_ctx_.GetComputeCapability())); - - float h_alpha = alpha; - float h_beta = beta; - -#if CUDA_VERSION >= 8000 - auto &cuda_ctx = const_cast(dev_ctx_); -#endif - // cublasHgemm does true FP16 computation which is slow for non-Volta - // GPUs. So use cublasGemmEx instead which does pseudo FP16 computation: - // input/output in fp16, computation in fp32, which can also be accelerated - // using tensor cores in volta GPUs. - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented("GEMM_EX_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "GEMM_EX_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { -#if CUDA_VERSION >= 8000 - CheckGEMMNSize(N); - CUBlas::GEMM_EX(&cuda_ctx, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - B, - CUDA_R_16F, - static_cast(ldb), - A, - CUDA_R_16F, - static_cast(lda), - &h_beta, - C, - CUDA_R_16F, - static_cast(N), - CUBLAS_COMPUTE_32F); -#else - // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - h_B, - static_cast(ldb), - h_A, - static_cast(lda), - &h_beta, - h_C, - static_cast(N)); - }); -#endif // CUDA_VERSION >= 8000 - } -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - phi::dtype::bfloat16 alpha, - const phi::dtype::bfloat16 *A, - const phi::dtype::bfloat16 *B, - phi::dtype::bfloat16 beta, - phi::dtype::bfloat16 *C) const { -#if CUDA_VERSION >= 11000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - PADDLE_ENFORCE_GE( - dev_ctx_.GetComputeCapability(), - 80, - phi::errors::InvalidArgument( - "cublas bf16 gemm requires GPU compute capability >= 80," - "but received %d", - dev_ctx_.GetComputeCapability())); - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; - bool use_tensor_op_math = dev_ctx_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); - - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW( - common::errors::Unimplemented("cublasGemmEx_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmEx_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { - CheckGEMMNSize(N); - dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, - cuTransB, - cuTransA, - N, - M, - K, - &h_alpha, - B, - CUDA_R_16BF, - ldb, - A, - CUDA_R_16BF, - lda, - &h_beta, - C, - CUDA_R_16BF, - N, - CUBLAS_COMPUTE_32F, - algo)); - }); - } -#else - // raise error - PADDLE_THROW(phi::errors::Unimplemented( - "cublasGemmEx with bfloat16 is not supported on cuda <= 11")); - -#endif // CUDA_VERSION >= 11000 -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - phi::dtype::complex alpha, - const phi::dtype::complex *A, - const phi::dtype::complex *B, - phi::dtype::complex beta, - phi::dtype::complex *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - dev_ctx_.GetComputeCapability(), - 53, - phi::errors::InvalidArgument( - "cublas complex64 gemm requires GPU compute capability >= 53," - "but received %d", - dev_ctx_.GetComputeCapability())); - - thrust::complex c_alpha = - thrust::complex(alpha.real, alpha.imag); - thrust::complex c_beta = thrust::complex(beta.real, beta.imag); - -#if CUDA_VERSION >= 8000 - auto &cuda_ctx = const_cast(dev_ctx_); -#endif - - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented("GEMM_EX_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "GEMM_EX_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { -#if CUDA_VERSION >= 8000 - CheckGEMMNSize(N); - CUBlas>::GEMM_EX(&cuda_ctx, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &c_alpha, - B, - CUDA_C_32F, - static_cast(ldb), - A, - CUDA_C_32F, - static_cast(lda), - &c_beta, - C, - CUDA_C_32F, - static_cast(N), - CUBLAS_COMPUTE_32F); -#else - // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas>::GEMM(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &c_alpha, - h_B, - static_cast(ldb), - h_A, - static_cast(lda), - &c_beta, - h_C, - static_cast(N)); - }); -#endif // CUDA_VERSION >= 8000 - } -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - phi::dtype::complex alpha, - const phi::dtype::complex *A, - const phi::dtype::complex *B, - phi::dtype::complex beta, - phi::dtype::complex *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - dev_ctx_.GetComputeCapability(), - 53, - phi::errors::InvalidArgument( - "cublas complex128 gemm requires GPU compute capability >= 53," - "but received %d", - dev_ctx_.GetComputeCapability())); - - thrust::complex c_alpha = - thrust::complex(alpha.real, alpha.imag); - thrust::complex c_beta = - thrust::complex(beta.real, beta.imag); - -#if CUDA_VERSION >= 8000 - auto &cuda_ctx = const_cast(dev_ctx_); -#endif - - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented("GEMM_EX_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "GEMM_EX_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { -#if CUDA_VERSION >= 8000 - CheckGEMMNSize(N); - CUBlas>::GEMM_EX(&cuda_ctx, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &c_alpha, - B, - CUDA_C_64F, - static_cast(ldb), - A, - CUDA_C_64F, - static_cast(lda), - &c_beta, - C, - CUDA_C_64F, - static_cast(N), - CUBLAS_COMPUTE_64F); -#else - // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas>::GEMM(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &c_alpha, - h_B, - static_cast(ldb), - h_A, - static_cast(lda), - &c_beta, - h_C, - static_cast(N)); - }); -#endif // CUDA_VERSION >= 8000 - } -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - float alpha, - const phi::dtype::bfloat16 *A, - const phi::dtype::bfloat16 *B, - float beta, - phi::dtype::bfloat16 *C) const { -#if CUDA_VERSION >= 11000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // PADDLE_ENFORCE_GE( - // dev_ctx_.GetComputeCapability(), - // 80, - // common::errors::InvalidArgument( - // "cublas bf16 gemm requires GPU compute capability >= 80," - // "but received %d", - // dev_ctx_.GetComputeCapability())); - - float h_alpha = alpha; - float h_beta = beta; - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; - bool use_tensor_op_math = dev_ctx_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW( - common::errors::Unimplemented("cublasGemmEx_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmEx_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { - CheckGEMMNSize(N); - dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmEx(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - B, - CUDA_R_16BF, - static_cast(ldb), - A, - CUDA_R_16BF, - static_cast(lda), - &h_beta, - C, - CUDA_R_16BF, - static_cast(N), - CUDA_R_32F, - algo)); - }); - } -#else - // raise error - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmEx with bfloat16 is not supported on cuda <= 11")); - -#endif // CUDA_VERSION >= 11000 -} - -template <> -template -void Blas::GEMM(bool transA, - bool transB, - int M, - int N, - int K, - T alpha, - const T *A, - int lda, - const T *B, - int ldb, - T beta, - T *C, - int ldc) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - -#if CUDA_VERSION >= 8000 - if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { - auto &cuda_ctx = const_cast(dev_ctx_); - CUBlas::GEMM_EX(&cuda_ctx, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - CUDA_R_32F, - ldb, - A, - CUDA_R_32F, - lda, - &beta, - C, - CUDA_R_32F, - ldc); - } else { -#endif // CUDA_VERSION >= 8000 - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - ldb, - A, - lda, - &beta, - C, - ldc); - }); - -#if CUDA_VERSION >= 8000 - } -#endif // CUDA_VERSION >= 8000 -} - -template <> -template <> -inline void Blas::GEMM(bool transA, - bool transB, - int M, - int N, - int K, - phi::dtype::float16 alpha, - const phi::dtype::float16 *A, - int lda, - const phi::dtype::float16 *B, - int ldb, - phi::dtype::float16 beta, - phi::dtype::float16 *C, - int ldc) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - ldb, - A, - lda, - &beta, - C, - ldc); - }); -} - -template <> -template <> -inline void Blas::GEMM(bool transA, - bool transB, - int M, - int N, - int K, - phi::dtype::bfloat16 alpha, - const phi::dtype::bfloat16 *A, - int lda, - const phi::dtype::bfloat16 *B, - int ldb, - phi::dtype::bfloat16 beta, - phi::dtype::bfloat16 *C, - int ldc) const { -#if CUDA_VERSION >= 11000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - - // PADDLE_ENFORCE_GE( - // dev_ctx_.GetComputeCapability(), - // 80, - // phi::errors::InvalidArgument( - // "cublas bf16 gemm requires GPU compute capability >= 80," - // "but received %d", - // dev_ctx_.GetComputeCapability())); - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; - bool use_tensor_op_math = dev_ctx_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); - - dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, - cuTransB, - cuTransA, - N, - M, - K, - &h_alpha, - B, - CUDA_R_16BF, - ldb, - A, - CUDA_R_16BF, - lda, - &h_beta, - C, - CUDA_R_16BF, - ldc, - CUBLAS_COMPUTE_32F, - algo)); - }); -#else - // raise error - PADDLE_THROW(phi::errors::Unimplemented( - "cublasGemmEx with bfloat16 is not supported on cuda <= 11")); - -#endif // CUDA_VERSION >= 11000 -} - -template <> -template -void Blas::AXPY(int n, T alpha, const T *x, T *y) const { - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::AXPY(handle, n, &alpha, x, 1, y, 1); - }); -} - -template <> -template -void Blas::SCAL(int n, const T alpha, T *x) const { - dev_ctx_.CublasCall( - [&](cublasHandle_t handle) { CUBlas::SCAL(handle, n, &alpha, x, 1); }); -} - -template <> -template -void Blas::VCOPY(int n, const T *x, T *y) const { - dev_ctx_.CublasCall( - [&](cublasHandle_t handle) { CUBlas::VCOPY(handle, n, x, 1, y, 1); }); -} - -template <> -template -void Blas::GEMV(bool trans_a, - int M, - int N, - T alpha, - const T *A, - const T *B, - T beta, - T *C) const { - cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1); - }); -} - -template <> -template <> -inline void Blas::GEMV(bool trans_a, - int M, - int N, - phi::dtype::float16 alpha, - const phi::dtype::float16 *A, - const phi::dtype::float16 *B, - phi::dtype::float16 beta, - phi::dtype::float16 *C) const { - // Because cublas doesn't support half gemv, we use cublasHgemm to achieve it. - if (trans_a) { - this->template GEMM( - CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C); - } else { - this->template GEMM( - CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C); - } -} - -template <> -template <> -inline void Blas::GEMV(bool trans_a, - int M, - int N, - phi::dtype::bfloat16 alpha, - const phi::dtype::bfloat16 *A, - const phi::dtype::bfloat16 *B, - phi::dtype::bfloat16 beta, - phi::dtype::bfloat16 *C) const { - // Because cublas doesn't support bfloat gemv, we use cublasHgemm to achieve - // it. - if (trans_a) { - this->template GEMM( - CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C); - } else { - this->template GEMM( - CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C); - } -} - -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - T alpha, - const T *A, - const T *B, - T beta, - T *C, - int64_t batchCount, - int64_t strideA, - int64_t strideB) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - int64_t ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t strideC = M * N; - -#if CUDA_VERSION >= 9010 - if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same::value)) || - std::is_same::value) { - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = dev_ctx_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); - VLOG(4) << "use_half_precision_compute_type: " - << FLAGS_gemm_use_half_precision_compute_type; - - auto fp = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; -#if CUDA_VERSION >= 11000 - auto compute_type = CUBLAS_COMPUTE_32F; -#else - auto compute_type = CUDA_R_32F; -#endif - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - void *a = static_cast(&h_alpha); - void *b = static_cast(&h_beta); - // set ComputeType as CUDA_R_32F for fp16, for better accuracy - if (FLAGS_gemm_use_half_precision_compute_type == true && - std::is_same::value) { - a = static_cast(&alpha); - b = static_cast(&beta); -#if CUDA_VERSION >= 11000 - compute_type = CUBLAS_COMPUTE_16F; -#else - compute_type = CUDA_R_16F; -#endif - } - - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { - dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmStridedBatchedEx(handle, - cuTransB, - cuTransA, - N, - M, - K, - a, - B, - fp, - ldb, - strideB, - A, - fp, - lda, - strideA, - b, - C, - fp, - ldc, - strideC, - batchCount, - compute_type, - algo)); - }); - } - } else { -#endif // CUDA_VERSION >= 9010 - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM_STRIDED_BATCH(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &alpha, - B, - static_cast(ldb), - strideB, - A, - static_cast(lda), - strideA, - &beta, - C, - ldc, - strideC, - static_cast(batchCount)); - }); - -#if CUDA_VERSION >= 9010 - } -#endif // CUDA_VERSION >= 9010 -} - -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - U alpha, - const T *A, - const T *B, - U beta, - T *C, - int64_t batchCount, - int64_t strideA, - int64_t strideB) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - int64_t ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t strideC = M * N; -#if CUDA_VERSION >= 9010 - if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same::value)) || - std::is_same::value) { - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = dev_ctx_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); - VLOG(4) << "use_half_precision_compute_type: " - << FLAGS_gemm_use_half_precision_compute_type; - - auto fp = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; -#if CUDA_VERSION >= 11000 - auto compute_type = CUBLAS_COMPUTE_32F; -#else - auto compute_type = CUDA_R_32F; -#endif - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - void *a = static_cast(&h_alpha); - void *b = static_cast(&h_beta); - // set ComputeType as CUDA_R_32F for fp16, for better accuracy - if (FLAGS_gemm_use_half_precision_compute_type == true && - std::is_same::value) { - a = static_cast(&alpha); - b = static_cast(&beta); -#if CUDA_VERSION >= 11000 - compute_type = CUBLAS_COMPUTE_16F; -#else - compute_type = CUDA_R_16F; -#endif - } - - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE || - batchCount > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { - dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmStridedBatchedEx( - handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - a, - B, - fp, - static_cast(ldb), - strideB, - A, - fp, - static_cast(lda), - strideA, - b, - C, - fp, - static_cast(ldc), - strideC, - static_cast(batchCount), - compute_type, - algo)); - }); - } - } else { -#endif // CUDA_VERSION >= 9010 - T h_alpha = static_cast(alpha); - T h_beta = static_cast(beta); - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM_STRIDED_BATCH(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - B, - static_cast(ldb), - strideB, - A, - static_cast(lda), - strideA, - &h_beta, - C, - static_cast(ldc), - strideC, - static_cast(batchCount)); - }); - -#if CUDA_VERSION >= 9010 - } -#endif // CUDA_VERSION >= 9010 -} - -template <> -template <> -inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - phi::dtype::bfloat16 alpha, - const phi::dtype::bfloat16 *A, - const phi::dtype::bfloat16 *B, - phi::dtype::bfloat16 beta, - phi::dtype::bfloat16 *C, - int64_t batchCount, - int64_t strideA, - int64_t strideB) const { -#if CUDA_VERSION >= 11000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - int64_t ldc = N; - - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t strideC = M * N; - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = dev_ctx_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE || - batchCount > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { - dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmStridedBatchedEx(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - B, - CUDA_R_16BF, - static_cast(ldb), - strideB, - A, - CUDA_R_16BF, - static_cast(lda), - strideA, - &h_beta, - C, - CUDA_R_16BF, - static_cast(ldc), - strideC, - static_cast(batchCount), - CUBLAS_COMPUTE_32F, - algo)); - }); - } -#else - // raise error - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= " - "11")); -#endif // CUDA_VERSION >= 11000 -} - -template <> -template <> -inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - float alpha, - const phi::dtype::bfloat16 *A, - const phi::dtype::bfloat16 *B, - float beta, - phi::dtype::bfloat16 *C, - int64_t batchCount, - int64_t strideA, - int64_t strideB) const { -#if CUDA_VERSION >= 11000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t strideC = M * N; - - float h_alpha = alpha; - float h_beta = beta; - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = dev_ctx_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE || - batchCount > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { - dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmStridedBatchedEx(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - B, - CUDA_R_16BF, - static_cast(ldb), - strideB, - A, - CUDA_R_16BF, - static_cast(lda), - strideA, - &h_beta, - C, - CUDA_R_16BF, - static_cast(ldc), - strideC, - static_cast(batchCount), - CUBLAS_COMPUTE_32F, - algo)); - }); - } -#else - // raise error - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= " - "11")); -#endif // CUDA_VERSION >= 11000 -} - -// /*** -// * Uknow bug, parameters dislocation when calling BatchedGEMM. -// * Reference: paddle github PR #45530 and #55612 -// */ -// template <> -// template <> -// inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, -// CBLAS_TRANSPOSE transB, -// int M, -// int N, -// int K, -// float16 alpha, -// const float16 *A, -// const float16 *B, -// float16 beta, -// float16 *C, -// int batchCount, -// int64_t strideA, -// int64_t strideB) const { -// // Note that cublas follows fortran order, so the order is different from -// // the cblas convention. -// int lda = (transA == CblasNoTrans) ? K : M; -// int ldb = (transB == CblasNoTrans) ? N : K; -// int ldc = N; -// cublasOperation_t cuTransA = -// (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; -// cublasOperation_t cuTransB = -// (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; -// const int64_t strideC = M * N; - -// #if CUDA_VERSION >= 9010 -// if ((FLAGS_enable_cublas_tensor_op_math && -// (std::is_same::value)) || -// std::is_same::value) { -// cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -// bool use_tensor_op_math = dev_ctx_.tensor_core_available(); -// if (use_tensor_op_math) { -// algo = CUBLAS_GEMM_DFALT_TENSOR_OP; -// } -// VLOG(5) << "use_tensor_op_math: " -// << (use_tensor_op_math ? "True" : "False"); -// VLOG(4) << "use_half_precision_compute_type: " -// << FLAGS_gemm_use_half_precision_compute_type; - -// auto fp = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; -// #if CUDA_VERSION >= 11000 -// auto compute_type = CUBLAS_COMPUTE_32F; -// #else -// auto compute_type = CUDA_R_32F; -// #endif - -// float h_alpha = static_cast(alpha); -// float h_beta = static_cast(beta); -// void *a = static_cast(&h_alpha); -// void *b = static_cast(&h_beta); -// // set ComputeType as CUDA_R_32F for fp16, for better accuracy -// if (FLAGS_gemm_use_half_precision_compute_type == true && -// std::is_same::value) { -// a = static_cast(&alpha); -// b = static_cast(&beta); -// #if CUDA_VERSION >= 11000 -// compute_type = CUBLAS_COMPUTE_16F; -// #else -// compute_type = CUDA_R_16F; -// #endif -// } - -// dev_ctx_.TensorCoreCublasCallIfAvailable( -// [&](cublasHandle_t handle) { -// PADDLE_ENFORCE_GPU_SUCCESS( -// phi::dynload::cublasGemmStridedBatchedEx(handle, -// cuTransB, -// cuTransA, -// N, -// M, -// K, -// a, -// B, -// fp, -// ldb, -// strideB, -// A, -// fp, -// lda, -// strideA, -// b, -// C, -// fp, -// ldc, -// strideC, -// batchCount, -// compute_type, -// algo)); -// }); -// } else { -// #endif // CUDA_VERSION >= 9010 - -// dev_ctx_.CublasCall( -// [&](cublasHandle_t handle) { -// CUBlas::GEMM_STRIDED_BATCH(handle, -// cuTransB, -// cuTransA, -// N, -// M, -// K, -// &alpha, -// B, -// ldb, -// strideB, -// A, -// lda, -// strideA, -// &beta, -// C, -// ldc, -// strideC, -// batchCount); -// }, -// dev_ctx_.stream()); - -// #if CUDA_VERSION >= 9010 -// } -// #endif // CUDA_VERSION >= 9010 -// } - -// /*** -// * Uknow bug, parameters dislocation when calling BatchedGEMM. -// * Reference: paddle github PR #45530 and #55612 -// */ -// template <> -// template <> -// inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, -// CBLAS_TRANSPOSE transB, -// int M, -// int N, -// int K, -// double alpha, -// const double *A, -// const double *B, -// double beta, -// double *C, -// int batchCount, -// int64_t strideA, -// int64_t strideB) const { -// // Note that cublas follows fortran order, so the order is different from -// // the cblas convention. -// int lda = (transA == CblasNoTrans) ? K : M; -// int ldb = (transB == CblasNoTrans) ? N : K; -// int ldc = N; -// cublasOperation_t cuTransA = -// (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; -// cublasOperation_t cuTransB = -// (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; -// const int64_t strideC = M * N; -// dev_ctx_.CublasCall( -// [&](cublasHandle_t handle) { -// PADDLE_ENFORCE_GPU_SUCCESS( -// phi::dynload::cublasDgemmStridedBatched(handle, -// cuTransB, -// cuTransA, -// N, -// M, -// K, -// &alpha, -// B, -// ldb, -// strideB, -// A, -// lda, -// strideA, -// &beta, -// C, -// ldc, -// strideC, -// batchCount)); -// }, -// dev_ctx_.stream()); -// } - -// template <> -// template <> -// inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, -// CBLAS_TRANSPOSE transB, -// int M, -// int N, -// int K, -// phi::dtype::bfloat16 alpha, -// const phi::dtype::bfloat16 *A, -// const phi::dtype::bfloat16 *B, -// phi::dtype::bfloat16 beta, -// phi::dtype::bfloat16 *C, -// int batchCount, -// int64_t strideA, -// int64_t strideB) const { -// #if CUDA_VERSION >= 11000 -// // Note that cublas follows fortran order, so the order is different from -// // the cblas convention. -// int lda = (transA == CblasNoTrans) ? K : M; -// int ldb = (transB == CblasNoTrans) ? N : K; -// int ldc = N; -// cublasOperation_t cuTransA = -// (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; -// cublasOperation_t cuTransB = -// (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; -// const int64_t strideC = M * N; - -// float h_alpha = static_cast(alpha); -// float h_beta = static_cast(beta); - -// cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -// bool use_tensor_op_math = dev_ctx->tensor_core_available(); -// if (use_tensor_op_math) { -// algo = CUBLAS_GEMM_DFALT_TENSOR_OP; -// } -// VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : -// "False"); - -// dev_ctx_.TensorCoreCublasCallIfAvailable( -// [&](cublasHandle_t handle) { -// PADDLE_ENFORCE_GPU_SUCCESS( -// phi::dynload::cublasGemmStridedBatchedEx(handle, -// cuTransB, -// cuTransA, -// N, -// M, -// K, -// &h_alpha, -// B, -// CUDA_R_16BF, -// ldb, -// strideB, -// A, -// CUDA_R_16BF, -// lda, -// strideA, -// &h_beta, -// C, -// CUDA_R_16BF, -// ldc, -// strideC, -// batchCount, -// CUBLAS_COMPUTE_32F, -// algo)); -// }); -// #else -// // raise error -// PADDLE_THROW(phi::errors::Unimplemented( -// "cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= " -// "11")); -// #endif // CUDA_VERSION >= 11000 -// } - -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - T alpha, - const T **A, - const T **B, - T beta, - T **C, - int batchCount) const { - for (int k = 0; k < batchCount; ++k) { - this->template GEMM( - transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); - } -} - -#if defined(__NVCC__) -template <> -template <> -inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - double alpha, - const double **A, - const double **B, - double beta, - double **C, - int batchCount) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - thrust::device_vector A_ptr(A, A + batchCount); - thrust::device_vector B_ptr(B, B + batchCount); - thrust::device_vector C_ptr(C, C + batchCount); - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM_BATCH(handle, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B_ptr.data().get(), - ldb, - A_ptr.data().get(), - lda, - &beta, - C_ptr.data().get(), - ldc, - batchCount); - }); -} - -template <> -template <> -inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - float alpha, - const float **A, - const float **B, - float beta, - float **C, - int batchCount) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - thrust::device_vector A_ptr(A, A + batchCount); - thrust::device_vector B_ptr(B, B + batchCount); - thrust::device_vector C_ptr(C, C + batchCount); - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM_BATCH(handle, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B_ptr.data().get(), - ldb, - A_ptr.data().get(), - lda, - &beta, - C_ptr.data().get(), - ldc, - batchCount); - }); -} - -template <> -template <> -inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - phi::dtype::float16 alpha, - const phi::dtype::float16 **A, - const phi::dtype::float16 **B, - phi::dtype::float16 beta, - phi::dtype::float16 **C, - int batchCount) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - PADDLE_ENFORCE_GE( - dev_ctx_.GetComputeCapability(), - 53, - phi::errors::InvalidArgument( - "cublas fp16 gemm requires GPU compute capability >= 53," - "but received %d", - dev_ctx_.GetComputeCapability())); - float f_alpha = static_cast(alpha); - float f_beta = static_cast(beta); - auto &cuda_ctx = const_cast(dev_ctx_); - CUBlas::GEMM_BATCH(&cuda_ctx, - cuTransB, - cuTransA, - N, - M, - K, - &f_alpha, - B, - CUDA_R_16F, - ldb, - A, - CUDA_R_16F, - lda, - &f_beta, - C, - CUDA_R_16F, - ldc, - batchCount, - CUBLAS_COMPUTE_32F); -} - -template <> -template <> -inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - phi::dtype::bfloat16 alpha, - const phi::dtype::bfloat16 **A, - const phi::dtype::bfloat16 **B, - phi::dtype::bfloat16 beta, - phi::dtype::bfloat16 **C, - int batchCount) const { -#if CUDA_VERSION >= 11000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // PADDLE_ENFORCE_GE( - // dev_ctx_.GetComputeCapability(), - // 80, - // phi::errors::InvalidArgument( - // "cublas bf16 gemm requires GPU compute capability >= 80," - // "but received %d", - // dev_ctx_.GetComputeCapability())); - - float f_alpha = static_cast(alpha); - float f_beta = static_cast(beta); - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; - bool use_tensor_op_math = dev_ctx_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); - - thrust::device_vector A_ptr(A, A + batchCount); - thrust::device_vector B_ptr(B, B + batchCount); - thrust::device_vector C_ptr(C, C + batchCount); - dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmBatchedEx(handle, - cuTransB, - cuTransA, - N, - M, - K, - &f_alpha, - B_ptr.data().get(), - CUDA_R_16BF, - ldb, - A_ptr.data().get(), - CUDA_R_16BF, - lda, - &f_beta, - C_ptr.data().get(), - CUDA_R_16BF, - ldc, - batchCount, - CUBLAS_COMPUTE_32F, - algo)); - }); -#else - // raise error - PADDLE_THROW(phi::errors::Unimplemented( - "cublasGemmBatchedEx with bfloat16 is not supported on cuda <= 11")); - -#endif // CUDA_VERSION >= 11000 -} -#endif - -template <> -template -void Blas::TRSM(CBLAS_SIDE side, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, - int M, - int N, - T alpha, - const T *A, - int lda, - T *B, - int ldb) const { - // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` - // where ' stands for transpose - cublasSideMode_t cuSide = - (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT; - cublasFillMode_t cuUplo = - (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; - // use CUBLAS_OP_C (conjugate transpose) for complex - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasDiagType_t cuDiag = - (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::TRSM( - handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A, lda, B, ldb); - }); -} - -template <> -template -void Blas::BatchedGETRF( - int n, T **a, int *ipiv, int *info, int batch_size) const { - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedGETRI(int n, - const T **a, - const int *ipiv, - T **a_inv, - int *info, - int batch_size) const { - PADDLE_ENFORCE_NE( - a_inv, - a, - phi::errors::InvalidArgument( - "cuBLAS fuction 'cublasgetrfBatched' cannot be executed " - "in-place. The memory space of output matrix (address: %p) cannot " - "overlap memory space of input matrix (address: %p).", - a_inv, - a)); - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedMatInv( - int n, const T **a, T **a_inv, int *info, int batch_size) const { - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedGETRS(CBLAS_TRANSPOSE trans, - int n, - int nrhs, - const T **a, - int lda, - int *ipiv, - T **b, - int ldb, - int *info, - int batch_size) const { - // use CUBLAS_OP_C (conjugate transpose) for complex - cublasOperation_t cuTrans = - (trans == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GETRS_BATCH( - handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedTRSM(CBLAS_SIDE side, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, - int M, - int N, - T alpha, - const T **A, - int lda, - T **B, - int ldb, - int batch_size) const { - // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` - // where ' stands for transpose - cublasSideMode_t cuSide = - (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT; - cublasFillMode_t cuUplo = - (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; - // use CUBLAS_OP_C (conjugate transpose) for complex - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasDiagType_t cuDiag = - (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::TRSM_BATCH(handle, - cuSide, - cuUplo, - cuTransA, - cuDiag, - N, - M, - &alpha, - A, - lda, - B, - ldb, - batch_size); - }); -} - -} // namespace funcs -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/blas_impl.h b/backends/metax_gpu/kernels/funcs/blas/blas_impl.h deleted file mode 100644 index cb59d73bef8..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/blas_impl.h +++ /dev/null @@ -1,2003 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include -#include -#include -#include - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/common/bfloat16.h" -#include "paddle/phi/common/complex.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -#define INT_MAX_VALUE 2147483647 - -namespace phi { -namespace funcs { - -namespace detail { -template -static void axpy( - int n, const T alpha, const T *x, const int incx, T *y, const int incy) { - // Y = Y + alpha * X - while (n-- > 0) { - *y += alpha * *x; - y = y + incy; - x = x + incx; - } -} -} // namespace detail - -template -struct CBlas; - -template <> -struct CBlas { - template - static void VCOPY(ARGS... args) { - PADDLE_THROW(phi::errors::Unimplemented( - "Blas VCOPY do not supported on CPU, please check your code")); - } -}; - -template <> -struct CBlas { - template - static void VCOPY(ARGS... args) { - PADDLE_THROW(phi::errors::Unimplemented( - "Blas VCOPY do not supported on CPU, please check your code")); - } -}; - -template <> -struct CBlas { - template - static void AXPY(ARGS... args) { - detail::axpy(args...); - } - - template - static void VCOPY(ARGS... args UNUSED) { - PADDLE_THROW(phi::errors::Unimplemented( - "Blas VCOPY do not supported on CPU with bfloat16," - " please check your code")); - } - - template - static void VADD(int n, - const phi::dtype::bfloat16 *x, - const phi::dtype::bfloat16 *y, - phi::dtype::bfloat16 *z) { - for (int i = 0; i < n; ++i) { - z[i] = x[i] + y[i]; - } - } - - template - static void VMUL(int n, - const phi::dtype::bfloat16 *x, - const phi::dtype::bfloat16 *y, - phi::dtype::bfloat16 *z) { - for (int i = 0; i < n; ++i) { - z[i] = x[i] * y[i]; - } - } - - template - static void VSUB(int n, - const phi::dtype::bfloat16 *x, - const phi::dtype::bfloat16 *y, - phi::dtype::bfloat16 *z) { - for (int i = 0; i < n; ++i) { - z[i] = x[i] - y[i]; - } - } -}; - -#ifdef PADDLE_WITH_MKLML -template <> -struct CBlas { - template - static void GEMM(ARGS... args) { - phi::dynload::cblas_sgemm(args...); - } - - template - static float *GEMM_ALLOC(ARGS... args) { - return phi::dynload::cblas_sgemm_alloc(args...); - } - - template - static void GEMM_PACK(ARGS... args) { - phi::dynload::cblas_sgemm_pack(args...); - } - - template - static void GEMM_COMPUTE(ARGS... args) { - phi::dynload::cblas_sgemm_compute(args...); - } - - template - static void GEMM_FREE(ARGS... args) { - phi::dynload::cblas_sgemm_free(args...); - } - -#ifdef PADDLE_WITH_LIBXSMM - template - static void SMM_GEMM(ARGS... args) { - libxsmm_sgemm(args...); - } -#endif - - template - static void AXPY(ARGS... args) { - phi::dynload::cblas_saxpy(args...); - } - - template - static void VCOPY(ARGS... args) { - phi::dynload::cblas_scopy(args...); - } - - template - static void GEMV(ARGS... args) { - phi::dynload::cblas_sgemv(args...); - } - - template - static float DOT(ARGS... args) { - return phi::dynload::cblas_sdot(args...); - } - - template - static void SCAL(ARGS... args) { - phi::dynload::cblas_sscal(args...); - } - - template - static float ASUM(ARGS... args) { - return phi::dynload::cblas_sasum(args...); - } - - template - static void GEMM_BATCH(ARGS... args) { - phi::dynload::cblas_sgemm_batch(args...); - } - - template - static void VADD(ARGS... args) { - phi::dynload::vsAdd(args...); - } - - template - static void VSUB(ARGS... args) { - phi::dynload::vsSub(args...); - } - - template - static void VMUL(ARGS... args) { - phi::dynload::vsMul(args...); - } - - template - static void VDIV(ARGS... args) { - phi::dynload::vsDiv(args...); - } - - template - static void VEXP(ARGS... args) { - phi::dynload::vsExp(args...); - } - - template - static void VSQUARE(ARGS... args) { - phi::dynload::vsSqr(args...); - } - - template - static void VPOW(ARGS... args) { - phi::dynload::vsPowx(args...); - } - - template - static void VINV(ARGS... args) { - phi::dynload::vsInv(args...); - } - - template - static void VMERF(ARGS... args) { - phi::dynload::vmsErf(args...); - } -#if !defined(_WIN32) - template - static void CSRMM(ARGS... args) { - phi::dynload::mkl_scsrmm(args...); - } -#endif - - template - static void TRSM(ARGS... args) { - phi::dynload::cblas_strsm(args...); - } -}; - -template <> -struct CBlas { - template - static void GEMM(ARGS... args) { - phi::dynload::cblas_dgemm(args...); - } - - template - static double *GEMM_ALLOC(ARGS... args) { - return phi::dynload::cblas_dgemm_alloc(args...); - } - - template - static void GEMM_PACK(ARGS... args) { - phi::dynload::cblas_dgemm_pack(args...); - } - - template - static void GEMM_COMPUTE(ARGS... args) { - phi::dynload::cblas_dgemm_compute(args...); - } - - template - static void GEMM_FREE(ARGS... args) { - phi::dynload::cblas_dgemm_free(args...); - } - -#ifdef PADDLE_WITH_LIBXSMM - template - static void SMM_GEMM(ARGS... args) { - libxsmm_dgemm(args...); - } -#endif - - template - static void AXPY(ARGS... args) { - phi::dynload::cblas_daxpy(args...); - } - - template - static void VCOPY(ARGS... args) { - phi::dynload::cblas_dcopy(args...); - } - - template - static void GEMV(ARGS... args) { - phi::dynload::cblas_dgemv(args...); - } - - template - static double DOT(ARGS... args) { - return phi::dynload::cblas_ddot(args...); - } - - template - static void SCAL(ARGS... args) { - phi::dynload::cblas_dscal(args...); - } - - template - static double ASUM(ARGS... args) { - return phi::dynload::cblas_dasum(args...); - } - - template - static void GEMM_BATCH(ARGS... args) { - phi::dynload::cblas_dgemm_batch(args...); - } - - template - static void VADD(ARGS... args) { - phi::dynload::vdAdd(args...); - } - - template - static void VSUB(ARGS... args) { - phi::dynload::vdSub(args...); - } - - template - static void VMUL(ARGS... args) { - phi::dynload::vdMul(args...); - } - - template - static void VDIV(ARGS... args) { - phi::dynload::vdDiv(args...); - } - - template - static void VEXP(ARGS... args) { - phi::dynload::vdExp(args...); - } - - template - static void VSQUARE(ARGS... args) { - phi::dynload::vdSqr(args...); - } - - template - static void VPOW(ARGS... args) { - phi::dynload::vdPowx(args...); - } - - template - static void VINV(ARGS... args) { - phi::dynload::vdInv(args...); - } - - template - static void VMERF(ARGS... args) { - phi::dynload::vmdErf(args...); - } -#if !defined(_WIN32) - template - static void CSRMM(ARGS... args) { - phi::dynload::mkl_dcsrmm(args...); - } -#endif - - template - static void TRSM(ARGS... args) { - phi::dynload::cblas_dtrsm(args...); - } -}; - -template <> -struct CBlas> { - template - static void AXPY(int n, - const phi::dtype::complex alpha, - const phi::dtype::complex *X, - const int incX, - phi::dtype::complex *Y, - const int incY) { - phi::dynload::cblas_caxpy(n, &alpha, X, incX, Y, incY); - } - - template - static void VCOPY(ARGS... args) { - phi::dynload::cblas_ccopy(args...); - } - - // the libmklml_intel.so paddle used has no vcAdd, vcSub, - // vcMul, vcDiv apis before rebuild from source - // so replace with the raw operator methods - /* - template - static void VADD(ARGS... args) { - phi::dynload::vcAdd(args...); - } - - template - static void VSUB(ARGS... args) { - phi::dynload::vcSub(args...); - } - - template - static void VMUL(ARGS... args) { - phi::dynload::vcMul(args...); - } - - template - static void VDIV(ARGS... args) { - phi::dynload::vcDiv(args...); - } - */ - - template - static void VADD(int n, - const phi::dtype::complex *a, - const phi::dtype::complex *b, - phi::dtype::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] + b[i]; - } - } - - template - static void VSUB(int n, - const phi::dtype::complex *a, - const phi::dtype::complex *b, - phi::dtype::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] - b[i]; - } - } - - template - static void VMUL(int n, - const phi::dtype::complex *a, - const phi::dtype::complex *b, - phi::dtype::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] * b[i]; - } - } - template - static void VDIV(int n, - const phi::dtype::complex *a, - const phi::dtype::complex *b, - phi::dtype::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] / b[i]; - } - } - - template - static void GEMV(CBLAS_LAYOUT layout, - CBLAS_TRANSPOSE trans, - int M, - int N, - phi::dtype::complex alpha, - const phi::dtype::complex *A, - int lda, - const phi::dtype::complex *X, - int incx, - phi::dtype::complex beta, - phi::dtype::complex *Y, - int incy) { - const void *a_ = (const void *)(A); - const void *x_ = (const void *)(X); - void *y_ = static_cast(Y); - phi::dynload::cblas_cgemv( - layout, trans, M, N, &alpha, a_, lda, x_, incx, &beta, y_, incy); - } - - template - static void GEMM(CBLAS_LAYOUT layout, - CBLAS_TRANSPOSE trans_a, - CBLAS_TRANSPOSE trans_b, - int M, - int N, - int K, - phi::dtype::complex alpha, - const phi::dtype::complex *A, - int lda, - const phi::dtype::complex *B, - int ldb, - phi::dtype::complex beta, - phi::dtype::complex *C, - int ldc) { - const void *a_ = (const void *)(A); - const void *b_ = (const void *)(B); - void *c_ = static_cast(C); - phi::dynload::cblas_cgemm(layout, - trans_a, - trans_b, - M, - N, - K, - &alpha, - a_, - lda, - b_, - ldb, - &beta, - c_, - ldc); - } - - static void TRSM(CBLAS_LAYOUT layout, - CBLAS_SIDE side, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans_a, - CBLAS_DIAG diag, - int M, - int N, - phi::dtype::complex alpha, - const phi::dtype::complex *A, - int lda, - phi::dtype::complex *B, - int ldb) { - const void *a_ = (const void *)(A); - void *b_ = static_cast(B); - phi::dynload::cblas_ctrsm( - layout, side, uplo, trans_a, diag, M, N, &alpha, a_, lda, b_, ldb); - } - - template - static void GEMM_BATCH(CBLAS_LAYOUT layout, - CBLAS_TRANSPOSE *trans_a, - CBLAS_TRANSPOSE *trans_b, - int *M, - int *N, - int *K, - phi::dtype::complex *alpha, - const phi::dtype::complex **A, - const int *lda, - const phi::dtype::complex **B, - const int *ldb, - phi::dtype::complex *beta, - phi::dtype::complex **C, - const int *ldc, - int group_count, - int *group_size) { - const void **A_void = (const void **)(&(*A)); - const void **B_void = (const void **)(&(*B)); - void **C_void = reinterpret_cast(C); - - phi::dynload::cblas_cgemm_batch(layout, - trans_a, - trans_b, - M, - N, - K, - alpha, - A_void, - lda, - B_void, - ldb, - beta, - C_void, - ldc, - group_count, - group_size); - } - - template - static void GEMM_EX(ARGS... args) { - phi::dynload::cblas_cgemm_batch(args...); - } -}; - -template <> -struct CBlas> { - template - static void AXPY(int n, - const phi::dtype::complex alpha, - const phi::dtype::complex *X, - const int incX, - phi::dtype::complex *Y, - const int incY) { - phi::dynload::cblas_zaxpy(n, &alpha, X, incX, Y, incY); - } - - template - static void VCOPY(ARGS... args) { - phi::dynload::cblas_zcopy(args...); - } - - // the libmklml_intel.so paddle used has no vzAdd, vzSub, - // vzMul, vzDiv apis before rebuild from source - // so replace with the raw operator methods - /* - template - static void VADD(ARGS... args) { - phi::dynload::vzAdd(args...); - } - - template - static void VSUB(ARGS... args) { - phi::dynload::vzSub(args...); - } - - template - static void VMUL(ARGS... args) { - phi::dynload::vzMul(args...); - } - - template - static void VDIV(ARGS... args) { - phi::dynload::vzDiv(args...); - } - */ - - template - static void VADD(int n, - const phi::dtype::complex *a, - const phi::dtype::complex *b, - phi::dtype::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] + b[i]; - } - } - - template - static void VSUB(int n, - const phi::dtype::complex *a, - const phi::dtype::complex *b, - phi::dtype::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] - b[i]; - } - } - - template - static void VMUL(int n, - const phi::dtype::complex *a, - const phi::dtype::complex *b, - phi::dtype::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] * b[i]; - } - } - template - static void VDIV(int n, - const phi::dtype::complex *a, - const phi::dtype::complex *b, - phi::dtype::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] / b[i]; - } - } - - template - static void GEMV(CBLAS_LAYOUT layout, - CBLAS_TRANSPOSE trans, - int M, - int N, - phi::dtype::complex alpha, - const phi::dtype::complex *A, - int lda, - const phi::dtype::complex *X, - int incx, - phi::dtype::complex beta, - phi::dtype::complex *Y, - int incy) { - const void *a_ = (const void *)(A); - const void *x_ = (const void *)(X); - void *y_ = static_cast(Y); - phi::dynload::cblas_zgemv( - layout, trans, M, N, &alpha, a_, lda, x_, incx, &beta, y_, incy); - } - - template - static void GEMM(CBLAS_LAYOUT layout, - CBLAS_TRANSPOSE trans_a, - CBLAS_TRANSPOSE trans_b, - int M, - int N, - int K, - phi::dtype::complex alpha, - const phi::dtype::complex *A, - int lda, - const phi::dtype::complex *B, - int ldb, - phi::dtype::complex beta, - phi::dtype::complex *C, - int ldc) { - const void *a_ = (const void *)(A); - const void *b_ = (const void *)(B); - void *c_ = static_cast(C); - phi::dynload::cblas_zgemm(layout, - trans_a, - trans_b, - M, - N, - K, - &alpha, - a_, - lda, - b_, - ldb, - &beta, - c_, - ldc); - } - - static void TRSM(CBLAS_LAYOUT layout, - CBLAS_SIDE side, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans_a, - CBLAS_DIAG diag, - int M, - int N, - phi::dtype::complex alpha, - const phi::dtype::complex *A, - int lda, - phi::dtype::complex *B, - int ldb) { - const void *a_ = (const void *)(A); - void *b_ = static_cast(B); - phi::dynload::cblas_ztrsm( - layout, side, uplo, trans_a, diag, M, N, &alpha, a_, lda, b_, ldb); - } - - template - static void GEMM_BATCH(CBLAS_LAYOUT layout, - CBLAS_TRANSPOSE *trans_a, - CBLAS_TRANSPOSE *trans_b, - int *M, - int *N, - int *K, - phi::dtype::complex *alpha, - const phi::dtype::complex **A, - const int *lda, - const phi::dtype::complex **B, - const int *ldb, - phi::dtype::complex *beta, - phi::dtype::complex **C, - const int *ldc, - int group_count, - int *group_size) { - const void **A_void = (const void **)(&(*A)); - const void **B_void = (const void **)(&(*B)); - void **C_void = reinterpret_cast(C); - - phi::dynload::cblas_zgemm_batch(layout, - trans_a, - trans_b, - M, - N, - K, - alpha, - A_void, - lda, - B_void, - ldb, - beta, - C_void, - ldc, - group_count, - group_size); - } - - template - static void GEMM_EX(ARGS... args) { - phi::dynload::cblas_zgemm_batch(args...); - } -}; - -#else - -template <> -struct CBlas { - template - static void GEMM(ARGS... args) { - cblas_sgemm(args...); - } - - template - static void AXPY(ARGS... args) { - cblas_saxpy(args...); - } - - template - static void VCOPY(ARGS... args) { - cblas_scopy(args...); - } - - template - static void GEMV(ARGS... args) { - cblas_sgemv(args...); - } - - template - static void TRSM(ARGS... args) { - cblas_strsm(args...); - } -}; - -template <> -struct CBlas { - template - static void GEMM(ARGS... args) { - cblas_dgemm(args...); - } - - template - static void AXPY(ARGS... args) { - cblas_daxpy(args...); - } - - template - static void VCOPY(ARGS... args) { - cblas_dcopy(args...); - } - - template - static void GEMV(ARGS... args) { - cblas_dgemv(args...); - } - - template - static void TRSM(ARGS... args) { - cblas_dtrsm(args...); - } -}; - -template <> -struct CBlas> { - template - static void VCOPY(ARGS... args) { - cblas_ccopy(args...); - } - - template - static void AXPY(int n, - const phi::dtype::complex alpha, - const phi::dtype::complex *X, - const int incX, - phi::dtype::complex *Y, - const int incY) { - cblas_caxpy(n, &alpha, X, incX, Y, incY); - } - - template - static void GEMV(const CBLAS_LAYOUT layout, - const CBLAS_TRANSPOSE TransA, - const int M, - const int N, - const phi::dtype::complex alpha, - const phi::dtype::complex *A, - const int lda, - const phi::dtype::complex *X, - const int incX, - const phi::dtype::complex beta, - phi::dtype::complex *Y, - const int incY) { - cblas_cgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY); - } - - template - static void GEMM(const CBLAS_LAYOUT layout, - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const phi::dtype::complex alpha, - const phi::dtype::complex *A, - const int lda, - const phi::dtype::complex *B, - const int ldb, - const phi::dtype::complex beta, - phi::dtype::complex *C, - const int ldc) { - cblas_cgemm( - layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc); - } - - static void TRSM(const CBLAS_LAYOUT layout, - const CBLAS_SIDE side, - const CBLAS_UPLO uplo, - const CBLAS_TRANSPOSE transA, - const CBLAS_DIAG diag, - const int M, - const int N, - const phi::dtype::complex alpha, - const phi::dtype::complex *A, - const int lda, - phi::dtype::complex *B, - const int ldb) { - cblas_ctrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb); - } -}; - -template <> -struct CBlas> { - template - static void VCOPY(ARGS... args) { - cblas_zcopy(args...); - } - - template - static void AXPY(int n, - const phi::dtype::complex alpha, - const phi::dtype::complex *X, - const int incX, - phi::dtype::complex *Y, - const int incY) { - cblas_zaxpy(n, &alpha, X, incX, Y, incY); - } - - template - static void GEMV(const CBLAS_LAYOUT layout, - const CBLAS_TRANSPOSE TransA, - const int M, - const int N, - const phi::dtype::complex alpha, - const phi::dtype::complex *A, - const int lda, - const phi::dtype::complex *X, - const int incX, - const phi::dtype::complex beta, - phi::dtype::complex *Y, - const int incY) { - cblas_zgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY); - } - - template - static void GEMM(const CBLAS_LAYOUT layout, - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const phi::dtype::complex alpha, - const phi::dtype::complex *A, - const int lda, - const phi::dtype::complex *B, - const int ldb, - const phi::dtype::complex beta, - phi::dtype::complex *C, - const int ldc) { - cblas_zgemm( - layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc); - } - - static void TRSM(const CBLAS_LAYOUT layout, - const CBLAS_SIDE side, - const CBLAS_UPLO uplo, - const CBLAS_TRANSPOSE transA, - const CBLAS_DIAG diag, - const int M, - const int N, - const phi::dtype::complex alpha, - const phi::dtype::complex *A, - const int lda, - phi::dtype::complex *B, - const int ldb) { - cblas_ztrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb); - } -}; - -#endif - -template <> -struct CBlas { - static void GEMM(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 GEMM not supported on CPU, please check your code")); - } - - static void SMM_GEMM(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 SMM_GEMM not supported on CPU, please check your code")); - } - static void VMUL(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 VMUL not supported on CPU, please check your code")); - } - static void VEXP(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 VEXP not supported on CPU, please check your code")); - } - static void VSQUARE(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 VSQUARE not supported on CPU, please check your code")); - } - static void VPOW(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 VPOW not supported on CPU, please check your code")); - } - static void DOT(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 DOT not supported on CPU, please check your code")); - }; - static void SCAL(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 SCAL not supported on CPU, please check your code")); - }; - static void ASUM(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 ASUM not supported on CPU, please check your code")); - }; -#ifdef PADDLE_WITH_MKLML - static void GEMM_BATCH(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 GEMM_BATCH not supported on CPU, please check your code")); - } -#endif -}; - -#ifdef PADDLE_WITH_MKLML -template <> -template -T *Blas::GEMM_ALLOC(const CBLAS_IDENTIFIER id, - const int M, - const int N, - const int K) const { - return CBlas::GEMM_ALLOC(id, M, N, K); -} - -template <> -template -void Blas::GEMM_PACK(const CBLAS_IDENTIFIER id, - const CBLAS_TRANSPOSE trans, - int M, - int N, - int K, - const T alpha, - const T *src, - const int ld, - T *dst) const { - CBlas::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst); -} - -template <> -template -void Blas::GEMM_COMPUTE(int transA, - int transB, - int M, - int N, - int K, - const T *A, - const int lda, - const T *B, - const int ldb, - T beta, - T *C, - const int ldc) const { - CBlas::GEMM_COMPUTE( - CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb, beta, C, ldc); -} - -template <> -template -void Blas::GEMM_FREE(T *data) const { - CBlas::GEMM_FREE(data); -} -#endif - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - T alpha, - const T *A, - const T *B, - T beta, - T *C) const { - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { - PADDLE_THROW( - common::errors::Unimplemented("GEMM not supported for large tensor " - "size on CPU, please check your code!")); - } - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - CBlas::GEMM(CblasRowMajor, - transA, - transB, - M, - N, - K, - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc); -} - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - U alpha, - const T *A, - const T *B, - U beta, - T *C) const { - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { - PADDLE_THROW( - common::errors::Unimplemented("GEMM not supported for large tensor " - "size on CPU, please check your code!")); - } - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - CBlas::GEMM(CblasRowMajor, - transA, - transB, - static_cast(M), - static_cast(N), - static_cast(K), - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc); -} - -template <> -template -void Blas::GEMM(bool transA, - bool transB, - int M, - int N, - int K, - T alpha, - const T *A, - int lda, - const T *B, - int ldb, - T beta, - T *C, - int ldc) const { - CBlas::GEMM(CblasRowMajor, - transA == false ? CblasNoTrans : CblasTrans, - transB == false ? CblasNoTrans : CblasTrans, - M, - N, - K, - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc); -} - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - T alpha, - const T *A, - int lda, - const T *B, - int ldb, - T beta, - T *C, - int ldc) const { - CBlas::GEMM(CblasRowMajor, - transA, - transB, - M, - N, - K, - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc); -} - -template -template -void Blas::MatMul(const phi::DenseTensor &mat_a, - bool trans_a, - const phi::DenseTensor &mat_b, - bool trans_b, - T alpha, - phi::DenseTensor *mat_out, - T beta) const { - const auto &dim_a = mat_a.dims(); - const auto &dim_b = mat_b.dims(); - const auto &dim_out = mat_out->dims(); - PADDLE_ENFORCE_EQ( - dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, - true, - phi::errors::InvalidArgument( - "The input and output of matmul should be matrix, the dim size must " - "be 2," - "but received dim size input_a:%d, input_b:%d, output:%d", - dim_a.size(), - dim_b.size(), - dim_out.size())); - PADDLE_ENFORCE_EQ( - mat_a.place() == mat_b.place() && mat_a.place() == mat_out->place(), - true, - phi::errors::InvalidArgument("The places of matrices in the matmul " - "should be same, please check your " - "code.")); - - int M = dim_out[0]; - int N = dim_out[1]; - int K = !trans_a ? dim_a[1] : dim_a[0]; - - CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = !trans_b ? CblasNoTrans : CblasTrans; - - this->GEMM(transA, - transB, - M, - N, - K, - alpha, - mat_a.data(), - mat_b.data(), - beta, - mat_out->data()); -} - -template <> -template -void Blas::AXPY(int n, T alpha, const T *x, T *y) const { - CBlas::AXPY(n, alpha, x, 1, y, 1); -} - -template <> -template -void Blas::VCOPY(int n, const T *x, T *y) const { - CBlas::VCOPY(n, x, 1, y, 1); -} - -template <> -template -void Blas::VADD(int n, const T *x, const T *y, T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VADD(n, x, y, z); -#else - if (x == z) { - this->template AXPY(n, (T)(1.), y, z); - } else { - this->template VCOPY(n, y, z); - this->template AXPY(n, (T)(1.), x, z); - } -#endif -} - -template <> -template -void Blas::VSUB(int n, const T *x, const T *y, T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VSUB(n, x, y, z); -#else - // try to find if openblas support vsub - for (int i = 0; i < n; ++i) { - z[i] = x[i] - y[i]; - } -#endif -} - -template <> -template -void Blas::VMUL(int n, const T *x, const T *y, T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VMUL(n, x, y, z); -#else - // try to find if openblas support vmul - for (int i = 0; i < n; ++i) { - z[i] = x[i] * y[i]; - } -#endif -} - -template <> -template -void Blas::VDIV(int n, const T *x, const T *y, T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VDIV(n, x, y, z); -#else - // try to find if openblas support vdiv - for (int i = 0; i < n; ++i) { - z[i] = x[i] / y[i]; - } -#endif -} - -template <> -template -void Blas::VEXP(int n, const T *x, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VEXP(n, x, y); -#else - // try to find if openblas support vexp - for (int i = 0; i < n; ++i) { - y[i] = std::exp(x[i]); - } -#endif -} - -template <> -template -void Blas::VSQUARE(int n, const T *x, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VSQUARE(n, x, y); -#else - for (int i = 0; i < n; ++i) { - y[i] = x[i] * x[i]; - } -#endif -} - -template <> -template -void Blas::VPOW(int n, const T *x, T a, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VPOW(n, x, a, y); -#else - for (int i = 0; i < n; ++i) { - y[i] = std::pow(x[i], a); - } -#endif -} - -template <> -template -T Blas::DOT(int n, const T *x, const T *y) const { -#ifdef PADDLE_WITH_MKLML - return CBlas::DOT(n, x, 1, y, 1); -#else - // try to find if openblas support cblas_dot - T sum = 0; - for (int i = 0; i < n; ++i) { - sum += x[i] * y[i]; - } - return sum; -#endif -} - -template <> -template -void Blas::SCAL(int n, const T a, T *x) const { -#ifdef PADDLE_WITH_MKLML - CBlas::SCAL(n, a, x, 1); -#else - // try to find if openblas support cblas_scal - for (int i = 0; i < n; ++i) { - x[i] = a * x[i]; - } -#endif -} - -template <> -template -T Blas::ASUM(int n, T *x, int inc) const { - auto sum = static_cast(0.0); -#ifdef PADDLE_WITH_MKLML - sum = CBlas::ASUM(n, x, inc); -#else - // TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum - for (int c = 0; c < n; ++c) { - sum += x[c]; - } -#endif - return sum; -} - -template <> -template -void Blas::GEMV(bool trans_a, - int M, - int N, - T alpha, - const T *A, - const T *B, - T beta, - T *C) const { - CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; - CBlas::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); -} - -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - T alpha, - const T *A, - const T *B, - T beta, - T *C, - int64_t batchCount, - int64_t strideA, - int64_t strideB) const { - PADDLE_ENFORCE_NOT_NULL( - A, phi::errors::InvalidArgument("Pointer A should not be null.")); - PADDLE_ENFORCE_NOT_NULL( - B, phi::errors::InvalidArgument("Pointer B should not be null.")); - PADDLE_ENFORCE_NOT_NULL( - C, phi::errors::InvalidArgument("Pointer C should not be null.")); - - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { - PADDLE_THROW( - common::errors::Unimplemented("CPU GEMM not supported for large tensor " - "size.")); - } - -#ifdef PADDLE_WITH_MKLML - if (batchCount > INT_MAX_VALUE) { - PADDLE_THROW(common::errors::Unimplemented( - "CPU GEMM not supported for large batch size in MKLML.")); - } - - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - auto a_array = std::vector(batchCount); - auto b_array = std::vector(batchCount); - auto c_array = std::vector(batchCount); - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA]; - b_array[k] = &B[k * strideB]; - c_array[k] = &C[k * M * N]; - } - - CBlas::GEMM_BATCH(CblasRowMajor, - &transA, - &transB, - reinterpret_cast(&M), - reinterpret_cast(&N), - reinterpret_cast(&K), - &alpha, - a_array.data(), - &lda, - b_array.data(), - &ldb, - &beta, - c_array.data(), - &ldc, - 1 /* group_count */, - reinterpret_cast(&batchCount)); -#else - for (int k = 0; k < batchCount; ++k) { - auto *Ak = &A[k * strideA]; - auto *Bk = &B[k * strideB]; - auto *Ck = &C[k * M * N]; - this->template GEMM(transA, - transB, - reinterpret_cast(M), - reinterpret_cast(N), - reinterpret_cast(K), - alpha, - Ak, - Bk, - beta, - Ck); - } -#endif -} - -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - T alpha, - const T **A, - const T **B, - T beta, - T **C, - int batchCount) const { -#ifdef PADDLE_WITH_MKLML - const int lda = (std::max)((transA == CblasNoTrans) ? K : M, 1); - const int ldb = (std::max)((transB == CblasNoTrans) ? N : K, 1); - const int ldc = (std::max)(N, 1); - CBlas::GEMM_BATCH(CblasRowMajor, - &transA, - &transB, - &M, - &N, - &K, - &alpha, - A, - &lda, - B, - &ldb, - &beta, - C, - &ldc, - 1 /* group_count */, - &batchCount); -#else - for (int k = 0; k < batchCount; ++k) { - this->template GEMM( - transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); - } -#endif -} - -#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ - !defined(PADDLE_WITH_HIP) // @{ Group Blas MKLML: BatchedGEMMWithHead -template <> -template -void Blas::BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int W1, - int H1, - int W2, - int H2, - T alpha, - const T *A, - const T *B, - T beta, - T *C, - int batchCount, - int64_t strideA, - int64_t strideB, - int64_t head_number, - bool split_b_vertical) const { - int lda = (transA == CblasNoTrans) ? W1 : H1; - int ldb = (transB == CblasNoTrans) ? W2 : H2; - auto a_array = std::vector(batchCount); - auto b_array = std::vector(batchCount); - auto c_array = std::vector(batchCount); - - if (split_b_vertical) { - int ldc = W2; - int sub_width = W2 / head_number; - - for (int i = 0; i < head_number; i++) { - int sub_matA_offset = (transA == CblasNoTrans) - ? i * (W1 / head_number) - : i * (W1 / head_number) * H1; - int sub_matB_offset = (transB == CblasNoTrans) - ? i * (W2 / head_number) - : i * (W2 / head_number) * H2; - int sub_matC_offset = i * W2 / head_number; - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA] + sub_matA_offset; - b_array[k] = &B[k * strideB] + sub_matB_offset; - c_array[k] = &C[k * H1 * W2] + sub_matC_offset; - } - - CBlas::GEMM_BATCH(CblasRowMajor, - &transA, - &transB, - &H1, - &sub_width, - &H2, - &alpha, - a_array.data(), - &lda, - b_array.data(), - &ldb, - &beta, - c_array.data(), - &ldc, - 1 /* group_count */, - &batchCount); - } - - } else { - PADDLE_ENFORCE_EQ( - W1, - H2, - phi::errors::InvalidArgument( - "The fisrt matrix width should be same as second matrix height," - "but received fisrt matrix width %d" - ", second matrix height %d", - W1, - H2)); - int ldc = W2 * head_number; - int sub_width = W1 / head_number; - - for (int i = 0; i < head_number; i++) { - int sub_matA_offset = (transA == CblasNoTrans) - ? i * (W1 / head_number) - : i * (W1 / head_number) * H1; - int sub_matB_offset = (transB == CblasNoTrans) - ? i * (W1 / head_number) * W2 - : i * (W1 / head_number); - int sub_matC_offset = i * W2; - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA] + sub_matA_offset; - b_array[k] = &B[k * strideB] + sub_matB_offset; - c_array[k] = &C[k * H1 * head_number * W2] + sub_matC_offset; - } - - CBlas::GEMM_BATCH(CblasRowMajor, - &transA, - &transB, - &H1, - &W2, - &sub_width, - &alpha, - a_array.data(), - &lda, - b_array.data(), - &ldb, - &beta, - c_array.data(), - &ldc, - 1 /* group_count */, - &batchCount); - } - } -} -#endif // @} End Group Blas MKLML: BatchedGEMMWithHead - -template -template -void Blas::MatMul( - const int M, const int N, const int K, const T *A, const T *B, T *C) const { - this->template GEMM(CblasRowMajor, - CblasNoTrans, - CblasNoTrans, - M, - N, - K, - static_cast(1), - A, - K, - B, - N, - static_cast(0), - C, - N); -} - -template <> -template -void Blas::MatMul( - const int M, const int N, const int K, const T *A, const T *B, T *C) const { -#ifdef PADDLE_WITH_LIBXSMM - // Refer to https://github.com/hfp/libxsmm/blob/master/README.md - // But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; - - // Since the matrix is very small, - // so the unit of calculation is already very fast, - // and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead, - // use xsmm directly. - // Note: SMM use ColMajor - const char transa = 'N'; - const char transb = 'N'; - const T alpha = static_cast(1); - const T beta = static_cast(0); - CBlas::SMM_GEMM( - &transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, C, &N); - return; -#endif - - CBlas::GEMM(CblasRowMajor, - CblasNoTrans, - CblasNoTrans, - M, - N, - K, - static_cast(1), - A, - K, - B, - N, - static_cast(0), - C, - N); -} - -template -template -void Blas::MatMul(const phi::DenseTensor &mat_a, - const MatDescriptor &dim_a, - const phi::DenseTensor &mat_b, - const MatDescriptor &dim_b, - T alpha, - phi::DenseTensor *mat_out, - T beta) const { - MatMul(mat_a.data(), - dim_a, - mat_b.data(), - dim_b, - alpha, - mat_out->data(), - beta); -} - -template -template -void Blas::MatMul(const T *mat_a, - const MatDescriptor &dim_a, - const T *mat_b, - const MatDescriptor &dim_b, - T alpha, - T *mat_out, - T beta) const { - PADDLE_ENFORCE_EQ( - dim_a.width_, - dim_b.height_, - phi::errors::InvalidArgument( - "The fisrt matrix width should be same as second matrix height," - "but received fisrt matrix width %d" - ", second matrix height %d", - dim_a.width_, - dim_b.height_)); - - CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans; - if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { - this->template GEMM(transA, - transB, - dim_a.height_, - dim_b.width_, - dim_a.width_, - alpha, - mat_a, - mat_b, - beta, - mat_out); - } else { - PADDLE_ENFORCE_EQ( - dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || - dim_b.batch_size_ == 0, - true, - phi::errors::InvalidArgument( - "dim_a.batch_size should be equal to dim_b.batch_size, or " - "one of dim_a.batch_size and dim_b.batch_size should be 0. " - "But got dim_a.batch_size = %d, dim_b.batch_size = %d.", - dim_a.batch_size_, - dim_b.batch_size_)); - this->template BatchedGEMM( - transA, - transB, - dim_a.height_, - dim_b.width_, - dim_a.width_, - alpha, - mat_a, - mat_b, - beta, - mat_out, - dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, - dim_a.stride_, - dim_b.stride_); - } -} - -#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ - !defined(PADDLE_WITH_HIP) -// @{ Group Blas MKLML: MatMulWithHead -/* - * Multiple two matrixes with multiple heads - * - * A new parameter, i.e head_number is added compared to normal MatMul. - * The head_number describes the number of heads a matrix is vertically - * split. - * - * When user calls this API, the multiplication of two big matrixes is split - * into multiplication of several (head_number_) small matrixes. e.g. if Mat A - * is [3, 24] and Mat B is [24, 4], when multiple A and B with head_number as - * 4, Mat A will be split as 4 matrix of [3, 6] and Mat B will be - * (horizontally) split as 4 matrix of [6, 4]. The result of final matrix - * will be 4 matrix of [3, 4], i.e. [3, 16]. - * Another example is A is [3, 8], B is [2, 16], head_number is 4. In this - * case, A will be split as [3, 2], B will be (vertically) split as - * [2, 4]. The final result will be 4 matrix of 4 matrix of [3,4], i.e. [3, 16] - */ -template -template -void Blas::MatMulWithHead(const phi::DenseTensor &mat_a, - const MatDescriptor &dim_a, - const phi::DenseTensor &mat_b, - const MatDescriptor &dim_b, - T alpha, - int head_number, - phi::DenseTensor *mat_out, - T beta, - bool mat_b_split_vertical) const { - PADDLE_ENFORCE_EQ( - dim_a.width_ % head_number, - 0, - phi::errors::InvalidArgument( - "The first input width must be some times the head number" - "but received first input width %d" - ", head_number %d", - dim_a.width_, - head_number)); - PADDLE_ENFORCE_GE( - head_number, - 1, - phi::errors::InvalidArgument("The head number should be greater equal 1," - "but received head number %d", - head_number)); - PADDLE_ENFORCE_LE( - head_number, - dim_a.width_, - phi::errors::InvalidArgument( - "The head number should be less equal first input width," - "but received first input width %d" - ", head_number %d", - dim_a.width_, - head_number)); - CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans; - - if (mat_b_split_vertical) { - PADDLE_ENFORCE_EQ( - dim_b.height_, - dim_a.width_ / head_number, - phi::errors::InvalidArgument( - "The second input height should be equal than first input width," - "but received second input height %d, first input width %d", - dim_b.height_, - dim_a.width_ / head_number)); - PADDLE_ENFORCE_EQ( - dim_a.width_ % head_number, - 0, - phi::errors::InvalidArgument( - "The second input width should be some times the head number" - "but received second input width %d" - ", head_number %d", - dim_b.width_, - head_number)); - } - - if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { - int lda = !dim_a.trans_ ? dim_a.width_ : dim_a.height_; - int ldb = !dim_b.trans_ ? dim_b.width_ : dim_b.height_; - int sub_matA_offset; - int sub_matB_offset; - int sub_matC_offset; - int sub_mat_M = dim_a.height_; - int sub_mat_N; - int sub_mat_K; - int ldc; - - for (int i = 0; i < head_number; i++) { - sub_matA_offset = dim_a.trans_ - ? i * (dim_a.width_ / head_number) * dim_a.height_ - : i * (dim_a.width_ / head_number); - if (mat_b_split_vertical) { - sub_matB_offset = dim_b.trans_ - ? i * (dim_b.width_ / head_number) * dim_b.height_ - : i * (dim_b.width_ / head_number); - sub_matC_offset = i * dim_b.width_ / head_number; - - sub_mat_N = dim_b.width_ / head_number; - sub_mat_K = dim_b.height_; - - ldc = dim_b.width_; - } else { - sub_matB_offset = - dim_b.trans_ ? i * (dim_b.height_ / head_number) - : i * (dim_b.height_ / head_number) * dim_b.width_; - sub_matC_offset = i * dim_b.width_; - - sub_mat_N = dim_b.width_; - sub_mat_K = dim_a.width_ / head_number; - - ldc = head_number * dim_b.width_; - } - - this->template GEMM(transA, - transB, - sub_mat_M, - sub_mat_N, - sub_mat_K, - alpha, - mat_a.data() + sub_matA_offset, - lda, - mat_b.data() + sub_matB_offset, - ldb, - beta, - mat_out->data() + sub_matC_offset, - ldc); - } - } else { - PADDLE_ENFORCE_EQ( - (dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || - dim_b.batch_size_ == 0), - true, - phi::errors::InvalidArgument( - "The first input batch size should be equal than second input," - "either two input batch size is 0, but received first input batch " - "size" - " %d, second input batch size %d", - dim_a.batch_size_, - dim_b.batch_size_)); - - this->template BatchedGEMMWithHead( - transA, - transB, - dim_a.width_, - dim_a.height_, - dim_b.width_, - dim_b.height_, - alpha, - mat_a.data(), - mat_b.data(), - beta, - mat_out->data(), - dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, - dim_a.stride_, - dim_b.stride_, - head_number, - mat_b_split_vertical); - } -} -#endif // @} End Group Blas MKLML: MatMulWithHead - -template -template -void Blas::VINV(int n, const T *a, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VINV(n, a, y); -#else - for (int i = 0; i < n; ++i) { - y[i] = 1.0 / a[i]; - } -#endif -} - -template <> -template -void Blas::VMERF(int n, const T *a, T *y, int64_t mode) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VMERF(n, a, y, mode); -#else - for (int i = 0; i < n; ++i) { - y[i] = std::erf(a[i]); - } -#endif -} - -#ifdef PADDLE_WITH_MKLML -template <> -template -void Blas::CSRMM(const char *transa, - const int *m, - const int *n, - const int *k, - const T *alpha, - const char *matdescra, - const T *val, - const int *indx, - const int *pntrb, - const int *pntre, - const T *b, - const int *ldb, - const T *beta, - T *c, - const int *ldc) const { - CBlas::CSRMM(transa, - m, - n, - k, - alpha, - matdescra, - val, - indx, - pntrb, - pntre, - b, - ldb, - beta, - c, - ldc); -} -#endif - -template <> -template -void Blas::TRSM(CBLAS_SIDE side, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, - int M, - int N, - T alpha, - const T *A, - int lda, - T *B, - int ldb) const { - CBlas::TRSM( - CblasRowMajor, side, uplo, transA, diag, M, N, alpha, A, lda, B, ldb); -} - -} // namespace funcs -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/blaslt_gemm_search.h b/backends/metax_gpu/kernels/funcs/blas/blaslt_gemm_search.h deleted file mode 100644 index 6dcc56f8569..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/blaslt_gemm_search.h +++ /dev/null @@ -1,794 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include - -#include -#include -#include -#include - -#include "paddle/common/flags.h" -#include "paddle/phi/api/include/context_pool.h" -#include "paddle/phi/backends/dynload/cublasLt.h" -#include "paddle/phi/backends/gpu/gpu_info.h" -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/common/place.h" -#include "paddle/phi/core/allocator.h" -#include "paddle/phi/core/dense_tensor.h" - -COMMON_DECLARE_string(cublaslt_device_best_config); - -namespace phi { -namespace funcs { -namespace cublaslt_internal { - -const std::array split_k_candidates = {2, 3, 4, 5, 6, 8, 12, 16, 32}; - -struct CublasLtAlgoConfig { - int m; - int n; - int k; - int algo_id; - int swizzle; - int custom_option; - int tile; - int split_k_val; - int reduction_scheme; - int stages; -}; - -struct CublasLtAlgoSelectorParam { - float time{0.0}; - cublasLtMatmulAlgo_t algo; - CublasLtAlgoConfig algo_config; -}; - -inline bool compare_algo_time(const CublasLtAlgoSelectorParam& param_a, - const CublasLtAlgoSelectorParam& param_b) { - return (param_a.time < param_b.time); -} - -class CublasLtAlgoCache { - public: - static CublasLtAlgoCache& Instance() { - static CublasLtAlgoCache instance(100 /*search_times*/); - return instance; - } - - template - void RunAndMeasureAlgo(cublasLtHandle_t handle, - cublasLtMatmulDesc_t matmul_desc, - cublasLtMatrixLayout_t a_desc, - cublasLtMatrixLayout_t b_desc, - cublasLtMatrixLayout_t bias_desc, - cublasLtMatrixLayout_t c_desc, - void* alpha, - void* beta, - const InT* a, - const InT* b, - const OutT* bias, - OutT* c, - CublasLtAlgoSelectorParam& param, // NOLINT - cudaEvent_t& start_event, // NOLINT - cudaEvent_t& stop_event, // NOLINT - cudaStream_t stream) { - cublasStatus_t status; - cublasLtMatmulHeuristicResult_t heuristic_result; - status = dynload::cublasLtMatmulAlgoCheck(handle, - matmul_desc, - a_desc, - b_desc, - bias_desc, - c_desc, - ¶m.algo, - &heuristic_result); - PADDLE_ENFORCE_GPU_SUCCESS(status); - if (status != CUBLAS_STATUS_SUCCESS) { - param.time = std::numeric_limits::max(); - return; - } - size_t workspace_size = heuristic_result.workspaceSize; - auto workspace = phi::memory_utils::Alloc( - phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId()), - workspace_size, - phi::Stream(reinterpret_cast(stream))); - - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event, stream)); - int repeats = search_times_; - - for (int loop = 0; loop < repeats; loop++) { - status = dynload::cublasLtMatmul(handle, - matmul_desc, - alpha, - a, - a_desc, - b, - b_desc, - beta, - bias, - bias_desc, - c, - c_desc, - ¶m.algo, - workspace->ptr(), - workspace_size, - stream); - if (status != CUBLAS_STATUS_SUCCESS) { - param.time = std::numeric_limits::max(); - return; - } - } - - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(stop_event, stream)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); - - float time; - PADDLE_ENFORCE_GPU_SUCCESS( - cudaEventElapsedTime(&time, start_event, stop_event)); - - param.time = time / repeats; - } - - template - cublasLtMatmulAlgo_t* CublasLtAlgoSelect(cublasLtHandle_t handle, - int m, - int n, - int k, - int batch_count, - const InT* a, - const InT* b, - const OutT* bias, - OutT* c, - void* alpha, - void* beta, - cublasLtMatmulDesc_t matmul_desc, - cublasLtMatrixLayout_t a_desc, - cublasLtMatrixLayout_t b_desc, - cublasLtMatrixLayout_t bias_desc, - cublasLtMatrixLayout_t c_desc, - cublasComputeType_t compute_type, - cudaDataType_t scale_type, - cudaDataType_t a_type, - cudaDataType_t b_type, - cudaDataType_t bias_type, - cudaDataType_t c_type, - cudaStream_t stream) { - // If we don't have config file and we do not search, here return nullptr - if (!has_config_file_ && search_times_ <= 0) { - return nullptr; - } - - // VLOG(0) << "m n k: " << m << " " << n << " " << k; - - int64_t seed = 0; - std::hash hash_fn; - - HashMatmulDesc(matmul_desc, &seed, hash_fn); - HashMatrixLayoutDesc(a_desc, &seed, hash_fn); - HashMatrixLayoutDesc(b_desc, &seed, hash_fn); - HashMatrixLayoutDesc(bias_desc, &seed, hash_fn); - HashMatrixLayoutDesc(c_desc, &seed, hash_fn); - - { - std::lock_guard lock(cache_mutex_); - if (algo_caches_.count(seed)) { - VLOG(3) << "CublasLtAlgoSelect Found in cache"; - return &algo_caches_[seed]; - } - } - - if (search_configs_.empty()) { - std::ifstream infile; - std::string config_file_path = FLAGS_cublaslt_device_best_config; - infile.open(config_file_path.c_str()); - if (infile.is_open()) { - size_t workspace_size; - float time; - char comma; - while (!infile.eof()) { - CublasLtAlgoConfig search_config; - infile >> search_config.m >> comma >> search_config.k >> comma >> - search_config.n >> comma >> search_config.algo_id >> comma >> - search_config.swizzle >> comma >> search_config.custom_option >> - comma >> search_config.tile >> comma >> - search_config.split_k_val >> comma >> - search_config.reduction_scheme >> comma >> search_config.stages >> - comma >> workspace_size >> comma >> time; - search_configs_.push_back(search_config); - } - infile.close(); - VLOG(3) << "Loaded " << search_configs_.size() << " configs"; - } - } - if (!search_configs_.empty()) { - auto configure_algo = [&](const CublasLtAlgoConfig& search_config) - -> cublasLtMatmulAlgo_t* { - cublasLtMatmulAlgo_t algo; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoInit(handle, - compute_type, - scale_type, - b_type, - a_type, - c_type, - c_type, - search_config.algo_id, - &algo)); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, - &search_config.custom_option, - sizeof(search_config.custom_option))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_TILE_ID, - &search_config.tile, - sizeof(search_config.tile))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_SPLITK_NUM, - &search_config.split_k_val, - sizeof(search_config.split_k_val))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, - &search_config.swizzle, - sizeof(search_config.swizzle))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, - &search_config.reduction_scheme, - sizeof(search_config.reduction_scheme))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_STAGES_ID, - &search_config.stages, - sizeof(search_config.stages))); - std::lock_guard lock(cache_mutex_); - algo_caches_[seed] = algo; - return &algo_caches_[seed]; - }; - const CublasLtAlgoConfig* pre = nullptr; - for (size_t i = 0; i < search_configs_.size(); i++) { - if (search_configs_[i].n == n && search_configs_[i].k == k && - m <= search_configs_[i].m) { - return configure_algo(search_configs_[i]); - } else if (search_configs_[i].n == n && search_configs_[i].k == k && - m > search_configs_[i].m) { - if (pre == nullptr || pre->m < search_configs_[i].m) - pre = &search_configs_[i]; - } - } - if (pre != nullptr) { - // use max m in file - return configure_algo(*pre); - } - } - - // if we have cache but not found algo, and we don't want to search, - // here return nullptr - if (search_times_ <= 0) { - return nullptr; - } - - VLOG(3) << "CublasLtAlgoSelect Not Found in cache"; - - // Get Ids - // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoGetIds - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - int algo_ids[requested_algo_count_]; // NOLINT - - int num_algo_ids; - status = dynload::cublasLtMatmulAlgoGetIds(handle, - compute_type, - scale_type, - a_type, - b_type, - bias_type, - c_type, - requested_algo_count_, - algo_ids, - &num_algo_ids); - PADDLE_ENFORCE_GPU_SUCCESS(status); - - // Traverse all possible algo combinations - int step = 0; - int limit = 20000; - std::vector params; - - for (int idx = 0; idx < num_algo_ids; idx++) { - cublasLtMatmulAlgo_t algo; - - /* Initialize algo structure with given Algp ID */ - // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoInit - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoInit(handle, - compute_type, - scale_type, - a_type, - b_type, - bias_type, - c_type, - algo_ids[idx], - &algo)); - - // Query the tiles enums supported by that algo which is used to alloc - // enough space to store it - // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoCapGetAttribute - size_t attr_size = 0; - - int batch_support; - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, - CUBLASLT_ALGO_CAP_STRIDED_BATCH_SUPPORT, - &batch_support, - sizeof(batch_support), - &attr_size)); - if (batch_count > 1 && batch_support == 0) { - continue; - } - - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, CUBLASLT_ALGO_CAP_TILE_IDS, nullptr, 0, &attr_size)); - - int num_tiles = static_cast(attr_size / sizeof(int)); - std::vector tiles(num_tiles == 0 ? 1 : num_tiles); - if (num_tiles == 0) { - tiles[0] = CUBLASLT_MATMUL_TILE_UNDEFINED; - num_tiles = 1; - } else { - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, - CUBLASLT_ALGO_CAP_TILE_IDS, - tiles.data(), - sizeof(int) * num_tiles, - &attr_size)); - } - - // Query the stages enums supported by that algo (cuda must >= 11.0) - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, CUBLASLT_ALGO_CAP_STAGES_IDS, nullptr, 0, &attr_size)); - int num_stages = static_cast(attr_size / sizeof(int)); - std::vector stages(num_stages == 0 ? 1 : num_stages); - if (num_stages == 0) { - stages[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED; - num_stages = 1; - } else { - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, - CUBLASLT_ALGO_CAP_STAGES_IDS, - stages.data(), - sizeof(int) * num_stages, - &attr_size)); - } - - // Retrieve Other Algo Capabilities attributes - int splitk_support, red_mask, swizzling_max, custom_option_max; - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, - CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, - &splitk_support, - sizeof(splitk_support), - &attr_size)); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, - CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, - &red_mask, - sizeof(red_mask), - &attr_size)); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, - CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, - &swizzling_max, - sizeof(swizzling_max), - &attr_size)); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, - CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, - &custom_option_max, - sizeof(custom_option_max), - &attr_size)); - - /* Loop over the different tiles */ - for (int tile_id = 0; tile_id < num_tiles && step < limit; tile_id++) { - /* Loop over different stages count */ - for (int stage_id = 0; stage_id < num_stages && step < limit; - stage_id++) { - /* Loop over the different custom option if any */ - for (int custom_option = 0; - custom_option <= custom_option_max && step < limit; - custom_option++) { - /* Loop over the CTAs swizzling support */ - for (int k = 0; k <= swizzling_max && step < limit; k++) { - int splir_k_trial = 0; - if (splitk_support) { - splir_k_trial += - sizeof(split_k_candidates) / sizeof(split_k_candidates[0]); - } - - for (int l = 0; (l < (1 + splir_k_trial)) && (step < limit); - l++) { - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_TILE_ID, - &tiles[tile_id], - sizeof(tiles[tile_id]))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_STAGES_ID, - &stages[stage_id], - sizeof(stages[stage_id]))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, - &custom_option, - sizeof(custom_option))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, - &k, - sizeof(k))); - int split_k_val = 1; - int reduction_scheme = CUBLASLT_REDUCTION_SCHEME_NONE; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_SPLITK_NUM, - &split_k_val, - sizeof(split_k_val))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, - &reduction_scheme, - sizeof(int))); - if (l > 0) { // Split-K case - split_k_val = split_k_candidates[l - 1]; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_SPLITK_NUM, - &split_k_candidates[l - 1], - sizeof(split_k_candidates[l - 1]))); - for (reduction_scheme = 1; - reduction_scheme < - static_cast(CUBLASLT_REDUCTION_SCHEME_MASK) && - (step < limit); - reduction_scheme = reduction_scheme << 1) { - if (reduction_scheme & red_mask) { - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, - &reduction_scheme, - sizeof(reduction_scheme))); - - cublasLtMatmulHeuristicResult_t heurResult; - status = dynload::cublasLtMatmulAlgoCheck(handle, - matmul_desc, - a_desc, - b_desc, - bias_desc, - c_desc, - &algo, - &heurResult); - if (status == CUBLAS_STATUS_SUCCESS) { - CublasLtAlgoSelectorParam param; - param.algo = algo; - param.algo_config.m = m; - param.algo_config.n = n; - param.algo_config.k = k; - param.algo_config.algo_id = algo_ids[idx]; - param.algo_config.tile = tiles[tile_id]; - param.algo_config.swizzle = k; - param.algo_config.custom_option = custom_option; - param.algo_config.split_k_val = split_k_val; - param.algo_config.reduction_scheme = reduction_scheme; - param.algo_config.stages = stages[stage_id]; - params.emplace_back(param); - step++; - } - } // end if - } - } else { - // Prepare algos - cublasLtMatmulHeuristicResult_t heurResult; - // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoCheck - status = dynload::cublasLtMatmulAlgoCheck(handle, - matmul_desc, - a_desc, - b_desc, - bias_desc, - c_desc, - &algo, - &heurResult); - if (status == CUBLAS_STATUS_SUCCESS) { - CublasLtAlgoSelectorParam param; - param.algo = algo; - param.algo_config.m = m; - param.algo_config.n = n; - param.algo_config.k = k; - param.algo_config.algo_id = algo_ids[idx]; - param.algo_config.tile = tiles[tile_id]; - param.algo_config.swizzle = k; - param.algo_config.custom_option = custom_option; - param.algo_config.split_k_val = split_k_val; - param.algo_config.reduction_scheme = reduction_scheme; - param.algo_config.stages = stages[stage_id]; - params.emplace_back(param); - step++; - } - } - } - } - } - } - } - } - cudaEvent_t start_event; - cudaEvent_t stop_event; - - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&start_event)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&stop_event)); - - if (step == 0) { - VLOG(3) << "No algo can be used"; - return nullptr; - } - - VLOG(3) << "CublasLtAlgoSelect Start testRun " << step << " " - << params.size(); - - for (int i = 0; i < step; i++) { - RunAndMeasureAlgo(handle, - matmul_desc, - a_desc, - b_desc, - bias_desc, - c_desc, - alpha, - beta, - a, - b, - bias, - c, - params[i], - start_event, - stop_event, - stream); - } - std::sort(params.begin(), params.end(), compare_algo_time); - - size_t res_id = 0; - while (params[res_id].time == 0.0) { - res_id++; - if (res_id >= params.size()) break; - } - - if (res_id >= params.size()) { - VLOG(3) << "No algo can be used"; - return nullptr; - } - - VLOG(3) << "algo selected"; - - std::lock_guard lock(cache_mutex_); - algo_caches_[seed] = params[res_id].algo; - return &algo_caches_[seed]; - } - - ~CublasLtAlgoCache() { SerializeAlgoCachesToFile(); } - - private: - std::string algo_caches_file_{"./cublaslt_algo_caches_from_paddle"}; - std::unordered_map algo_caches_; - std::vector search_configs_; - int search_times_; - static constexpr int requested_algo_count_ = 100; - std::mutex cache_mutex_; - bool has_config_file_; - - explicit CublasLtAlgoCache(int search_times) - : search_times_(search_times), has_config_file_(true) { - // Init algo_caches_ from cache file - std::ifstream infile; - infile.open(algo_caches_file_); - if (!infile.is_open()) { - has_config_file_ = false; - VLOG(3) << "No CublasLtAlgoCache file found"; - return; - } - size_t cublaslt_version = 0, real_cublaslt_version = 0; - int64_t seed = 0; - std::array algo_data; - infile >> cublaslt_version; - VLOG(1) << "cublaslt_version " << cublaslt_version; - - if (dynload::cublasLtGetCudartVersion() != cublaslt_version) { - LOG(INFO) << algo_caches_file_ - << " is not compatible with current cublaslt_version " - << real_cublaslt_version; - return; - } - - while (!infile.eof()) { - infile >> seed >> algo_data[0] >> algo_data[1] >> algo_data[2] >> - algo_data[3] >> algo_data[4] >> algo_data[5] >> algo_data[6] >> - algo_data[7]; - - for (int i = 0; i < 8; ++i) { - algo_caches_[seed].data[i] = algo_data[i]; - } - } - infile.close(); - } - - // Serialize algo_caches_ to cache file - void SerializeAlgoCachesToFile() { - if (search_times_ > 0) { - int dev; - cudaGetDevice(&dev); - if (dev == 0) { - std::ofstream outfile; - outfile.open(algo_caches_file_, std::ios::out | std::ios::trunc); - outfile << dynload::cublasLtGetCudartVersion() << std::endl; - - for (const auto& [seed, algo] : algo_caches_) { - outfile << seed << " "; - for (size_t value : algo.data) { - outfile << value << " "; - } - outfile << std::endl; - } - outfile.close(); - } - } - } - - inline int64_t RoundToNextHighPowOfTwo(int64_t n, int64_t min_val) { - n--; - n |= (n >> 1); - n |= (n >> 2); - n |= (n >> 4); - n |= (n >> 8); - n |= (n >> 16); - return std::max(min_val, (n + 1)); - } - - void HashMatmulDesc(cublasLtMatmulDesc_t desc, - int64_t* seed, - const std::hash& hash_fn) { - size_t size_to_write; - int trans_a, trans_b; - uint32_t epilogue; - // int8_t fast_accum; - - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescGetAttribute(desc, - CUBLASLT_MATMUL_DESC_TRANSA, - &trans_a, - sizeof(trans_a), - &size_to_write)); - HashValue(seed, hash_fn, static_cast(trans_a)); - - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescGetAttribute(desc, - CUBLASLT_MATMUL_DESC_TRANSB, - &trans_b, - sizeof(trans_b), - &size_to_write)); - HashValue(seed, hash_fn, static_cast(trans_b)); - - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescGetAttribute(desc, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &epilogue, - sizeof(epilogue), - &size_to_write)); - HashValue(seed, hash_fn, static_cast(epilogue)); - - // PADDLE_ENFORCE_GPU_SUCCESS( - // dyl::cublasLtMatmulDescGetAttribute(desc, - // CUBLASLT_MATMUL_DESC_FAST_ACCUM, - // &fast_accum, - // sizeof(fast_accum), - // &size_to_write)); - // HashValue(seed, hash_fn, static_cast(fast_accum)); - } - - void HashMatrixLayoutDesc(cublasLtMatrixLayout_t desc, - int64_t* seed, - const std::hash& hash_fn) { - size_t size_to_write; - uint32_t dtype; - int32_t batch; - uint64_t row, col; - int64_t ld, batch_offset; - - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatrixLayoutGetAttribute(desc, - CUBLASLT_MATRIX_LAYOUT_TYPE, - &dtype, - sizeof(dtype), - &size_to_write)); - HashValue(seed, hash_fn, static_cast(dtype)); - - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutGetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batch, - sizeof(batch), - &size_to_write)); - HashValue(seed, hash_fn, static_cast(batch)); - - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutGetAttribute( - desc, CUBLASLT_MATRIX_LAYOUT_ROWS, &row, sizeof(row), &size_to_write)); - HashValue(seed, hash_fn, RoundToNextHighPowOfTwo(row, 32)); - - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutGetAttribute( - desc, CUBLASLT_MATRIX_LAYOUT_COLS, &col, sizeof(col), &size_to_write)); - HashValue(seed, hash_fn, RoundToNextHighPowOfTwo(col, 32)); - - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutGetAttribute( - desc, CUBLASLT_MATRIX_LAYOUT_LD, &ld, sizeof(ld), &size_to_write)); - HashValue(seed, hash_fn, RoundToNextHighPowOfTwo(ld, 32)); - - // PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( - // desc, CUBLASLT_MATRIX_LAYOUT_ROWS, &row, sizeof(row), - // &size_to_write)); - // HashValue(seed, hash_fn, row); - - // PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( - // desc, CUBLASLT_MATRIX_LAYOUT_COLS, &col, sizeof(col), - // &size_to_write)); - // HashValue(seed, hash_fn, col); - - // PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( - // desc, CUBLASLT_MATRIX_LAYOUT_LD, &ld, sizeof(ld), &size_to_write)); - // HashValue(seed, hash_fn, ld); - - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutGetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, - &batch_offset, - sizeof(batch_offset), - &size_to_write)); - HashValue(seed, hash_fn, static_cast(batch_offset)); - } - - void HashValue(int64_t* seed, - const std::hash& hash_fn, - int64_t value) { - *seed ^= hash_fn(value) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); - } -}; - -} // namespace cublaslt_internal -} // namespace funcs -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/blaslt_impl.cu.h b/backends/metax_gpu/kernels/funcs/blas/blaslt_impl.cu.h deleted file mode 100755 index d98182abef3..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/blaslt_impl.cu.h +++ /dev/null @@ -1,1137 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 && 0 - -#include // NOLINT - -#include "cuda.h" // NOLINT -#include "glog/logging.h" -// #include "paddle/phi/backends/dynload/cublasLt.h" -#include "paddle/phi/backends/gpu/cuda/cuda_helper.h" -#include "paddle/phi/common/amp_type_traits.h" -#include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/core/flags.h" -#include "paddle/phi/kernels/autotune/gpu_timer.h" -#include "paddle/phi/kernels/autotune/switch_autotune.h" - -PHI_DECLARE_int64(cublaslt_exhaustive_search_times); -#endif - -namespace phi { -namespace funcs { - -#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 && 0) - -// Set this enum according to -// https://docs.nvidia.com/cuda/cublas/index.html#cublasltepilogue-t -// While kMatmul, kMatmulGrad, kMatmulGradWithoutBias share the same -// enum value, but if all elements for MatmulPlanner->GetKey() is same, -// no matter forward or backward, they could share the same descriptor -// cache, in that the descriptor is for description of matmul operation. -enum MatmulFusedType { - kMatmul = 0, - kMatmulGrad = 1, - kMatmulGradWithoutBias = 2, - kMatmulBias = 3, - kMatmulRelu = 4, - kMatmulBiasRelu = 5, - kMatmulBiasGelu = 6, - kMatmulBiasReluWithReservedData = 7, - kMatmulBiasGeluWithReservedData = 8, - kMatmulReluGrad = 9, - kMatmulGeluGrad = 10, - kMatmulBiasGradToA = 11, - kMatmulBiasGradToB = 12 -}; - -static cublasLtEpilogue_t ConvertFusedType(MatmulFusedType fused_type) { - static std::map fused_type_map = { - {MatmulFusedType::kMatmul, CUBLASLT_EPILOGUE_DEFAULT}, - {MatmulFusedType::kMatmulGrad, CUBLASLT_EPILOGUE_DEFAULT}, - {MatmulFusedType::kMatmulGradWithoutBias, CUBLASLT_EPILOGUE_DEFAULT}, - {MatmulFusedType::kMatmulBias, CUBLASLT_EPILOGUE_BIAS}, - {MatmulFusedType::kMatmulRelu, CUBLASLT_EPILOGUE_RELU}, - {MatmulFusedType::kMatmulBiasRelu, CUBLASLT_EPILOGUE_RELU_BIAS}, - {MatmulFusedType::kMatmulBiasGelu, CUBLASLT_EPILOGUE_GELU_BIAS}, - {MatmulFusedType::kMatmulBiasReluWithReservedData, - CUBLASLT_EPILOGUE_RELU_AUX_BIAS}, - {MatmulFusedType::kMatmulBiasGeluWithReservedData, - CUBLASLT_EPILOGUE_GELU_AUX_BIAS}, - {MatmulFusedType::kMatmulReluGrad, CUBLASLT_EPILOGUE_DRELU}, - {MatmulFusedType::kMatmulGeluGrad, CUBLASLT_EPILOGUE_DGELU}, - {MatmulFusedType::kMatmulBiasGradToA, CUBLASLT_EPILOGUE_BGRADA}, - {MatmulFusedType::kMatmulBiasGradToB, CUBLASLT_EPILOGUE_BGRADB}}; - - return fused_type_map[fused_type]; -} - -enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 }; - -template -struct FusedGEMMGradTrait; - -template <> -struct FusedGEMMGradTrait { - static constexpr auto kXGradA = FusedGEMMGradInType::kDZ; - static constexpr auto kXGradB = FusedGEMMGradInType::kDY; - static constexpr auto kXGradATrans = false; - static constexpr auto kXGradBTrans = true; - - static constexpr auto kYGradA = FusedGEMMGradInType::kDX; - static constexpr auto kYGradB = FusedGEMMGradInType::kDZ; - static constexpr auto kYGradATrans = true; - static constexpr auto kYGradBTrans = false; -}; - -template <> -struct FusedGEMMGradTrait { - static constexpr auto kXGradA = FusedGEMMGradInType::kDY; - static constexpr auto kXGradB = FusedGEMMGradInType::kDZ; - static constexpr auto kXGradATrans = false; - static constexpr auto kXGradBTrans = true; - - static constexpr auto kYGradA = FusedGEMMGradInType::kDX; - static constexpr auto kYGradB = FusedGEMMGradInType::kDZ; - static constexpr auto kYGradATrans = false; - static constexpr auto kYGradBTrans = false; -}; - -template <> -struct FusedGEMMGradTrait { - static constexpr auto kXGradA = FusedGEMMGradInType::kDZ; - static constexpr auto kXGradB = FusedGEMMGradInType::kDY; - static constexpr auto kXGradATrans = false; - static constexpr auto kXGradBTrans = false; - - static constexpr auto kYGradA = FusedGEMMGradInType::kDZ; - static constexpr auto kYGradB = FusedGEMMGradInType::kDX; - static constexpr auto kYGradATrans = true; - static constexpr auto kYGradBTrans = false; -}; - -template <> -struct FusedGEMMGradTrait { - static constexpr auto kXGradA = FusedGEMMGradInType::kDY; - static constexpr auto kXGradB = FusedGEMMGradInType::kDZ; - static constexpr auto kXGradATrans = true; - static constexpr auto kXGradBTrans = true; - - static constexpr auto kYGradA = FusedGEMMGradInType::kDZ; - static constexpr auto kYGradB = FusedGEMMGradInType::kDX; - static constexpr auto kYGradATrans = true; - static constexpr auto kYGradBTrans = true; -}; - -// To tell any matmul or fused matmul operation from each other. -struct MatmulPlanner { - public: - const void* bias{nullptr}; - void* aux_data{nullptr}; - - MatmulPlanner() {} - MatmulPlanner(const std::vector& x_dims, - const std::vector& y_dims, - const bool trans_x, - const bool trans_y, - phi::DataType dtype, - MatmulFusedType fused_type, - const void* bias_data = nullptr, - void* reserve_data = nullptr, // Commonly for ReLu bit-mask. - bool use_addto = false, - bool no_exchange = true) - : bias(bias_data), aux_data(reserve_data), fused_type_(fused_type) { - use_addto_ = use_addto; - key_ = phi::autotune::GenKey(x_dims, - y_dims, - static_cast(trans_x), - static_cast(trans_y), - static_cast(dtype), - static_cast(fused_type_), - static_cast(use_addto_), - static_cast(no_exchange)); - } - - bool UseAddTo() const { return use_addto_; } - size_t GetKey() const { return key_; } - MatmulFusedType GetFusedType() const { return fused_type_; } - - size_t GenSubKey() const { return key_; } - - private: - MatmulFusedType fused_type_; - bool use_addto_; - size_t key_; -}; - -template -cublasComputeType_t GetCudaComputeType() { - if (std::is_same::value) { - return CUBLAS_COMPUTE_64F; - } else if (std::is_same::value) { - return CUBLAS_COMPUTE_32I; - } else { - return CUBLAS_COMPUTE_32F; - } -} - -struct MatmulDescriptor { - public: - cublasLtMatmulDesc_t op_desc{nullptr}; - cublasLtMatrixLayout_t x_desc{nullptr}; - cublasLtMatrixLayout_t y_desc{nullptr}; - cublasLtMatrixLayout_t out_desc{nullptr}; - cublasLtMatmulAlgo_t* algo{nullptr}; - bool is_cached{false}; - - MatmulDescriptor() {} - MatmulDescriptor(const MatmulDescriptor& obj) { - algo = obj.algo; - x_desc = obj.x_desc; - y_desc = obj.y_desc; - op_desc = obj.op_desc; - out_desc = obj.out_desc; - is_cached = obj.is_cached; - } - - MatmulDescriptor& operator=(const MatmulDescriptor& obj) { - algo = obj.algo; - x_desc = obj.x_desc; - y_desc = obj.y_desc; - op_desc = obj.op_desc; - out_desc = obj.out_desc; - is_cached = obj.is_cached; - - return *this; - } - - ~MatmulDescriptor() PADDLE_MAY_THROW { - if (!is_cached) { - PADDLE_WARN_GPU_SUCCESS(dynload::cublasLtMatmulDescDestroy(op_desc)); - PADDLE_WARN_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(y_desc)); - PADDLE_WARN_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(x_desc)); - PADDLE_WARN_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(out_desc)); - delete algo; - - op_desc = nullptr; - x_desc = nullptr; - y_desc = nullptr; - out_desc = nullptr; - algo = nullptr; - } - } - - // x_desc, y_desc, op_desc are allocated in heap memory. - template - void Create(const int64_t M, - const int64_t N, - const int64_t K, - const bool trans_x, - const bool trans_y, - phi::funcs::MatmulPlanner* planner, - const int batch_size = 1, - const int64_t stride_x = 0, - const int64_t stride_y = 0, - const int64_t stride_out = 0, - bool grad_for_dx = true) { - using MT = typename phi::dtype::MPTypeTrait::Type; - cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType(); - cudaDataType_t out_mat_type = phi::backends::gpu::ToCudaDataType(); - cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType(); - cublasComputeType_t compute_type = GetCudaComputeType(); - - if (std::is_same::value) { - out_mat_type = phi::backends::gpu::ToCudaDataType(); - scale_type = phi::backends::gpu::ToCudaDataType(); - } - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t for - // details about defaults; just need to set the transforms for A and B - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type)); - SetFusedEpilogueOpDescriptor(planner, trans_x, trans_y, N); - - // Create matrix descriptors - CreateMatrixLayout(&x_desc, mat_type, M, K, trans_x); - CreateMatrixLayout(&y_desc, mat_type, K, N, trans_y); - CreateMatrixLayout(&out_desc, out_mat_type, M, N, false); - - // Config batch size and stride. - if (batch_size > 1) { - SetBatchAndStride(x_desc, batch_size, stride_x); - SetBatchAndStride(y_desc, batch_size, stride_y); - SetBatchAndStride(out_desc, batch_size, stride_out); - } - } - - cublasLtMatmulAlgo_t* SetAlgo() { - // while entering this function, the desc shall be cached. - is_cached = true; - algo = new cublasLtMatmulAlgo_t; - return algo; - } - - template - void SetFusedEpiloguePtr(phi::funcs::MatmulPlanner* planner) { - if (planner->bias != nullptr) { - const T* bias_data = static_cast(planner->bias); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute( - op_desc, - CUBLASLT_MATMUL_DESC_BIAS_POINTER, - &bias_data, - sizeof(bias_data))); - } - if (planner->aux_data != nullptr) { - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute( - op_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, - &(planner->aux_data), - sizeof(planner->aux_data))); - } - } - - std::string GetDescResultString(std::string prefix, - bool has_algo = true) const { - std::ostringstream out; - out << prefix << " \n"; -#define GET_DESC_DATA_STRING(src) \ - do { \ - out << " " << #src << " = ["; \ - int num = sizeof((*src)) / sizeof(src->data[0]); \ - for (int i = 0; i < num; ++i) { \ - if (i == 0) { \ - out << src->data[i]; \ - } else { \ - out << ", " << src->data[i]; \ - } \ - } \ - out << "]\n"; \ - } while (0); - - if (has_algo) { - GET_DESC_DATA_STRING(algo); - } - GET_DESC_DATA_STRING(x_desc); - GET_DESC_DATA_STRING(y_desc); - GET_DESC_DATA_STRING(out_desc); - GET_DESC_DATA_STRING(op_desc); -#undef GET_DESC_DATA_STRING - return out.str(); - } - - void ExchangeXYDesc(bool no_exchange) {} - - protected: - void SetFusedEpilogueOpDescriptor(phi::funcs::MatmulPlanner* planner, - const bool trans_x, - const bool trans_y, - int64_t lead_dim) { - cublasOperation_t cublas_trans_x = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t cublas_trans_y = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescSetAttribute(op_desc, - CUBLASLT_MATMUL_DESC_TRANSB, - &cublas_trans_x, - sizeof(cublas_trans_x))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescSetAttribute(op_desc, - CUBLASLT_MATMUL_DESC_TRANSA, - &cublas_trans_y, - sizeof(cublas_trans_y))); - MatmulFusedType fused_type = planner->GetFusedType(); - if (fused_type != MatmulFusedType::kMatmul) { - cublasLtEpilogue_t cublaslt_fused_type = ConvertFusedType(fused_type); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescSetAttribute(op_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &cublaslt_fused_type, - sizeof(fused_type))); - } - if (planner->aux_data) { - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute( - op_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, - &lead_dim, - sizeof(lead_dim))); - } - } - - void CreateMatrixLayout(cublasLtMatrixLayout_t* desc, - cudaDataType type, - uint64_t rows, - uint64_t cols, - bool trans) { - if (trans) { - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatrixLayoutCreate(desc, type, rows, cols, rows)); - } else { - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatrixLayoutCreate(desc, type, cols, rows, cols)); - } - } - - void SetBatchAndStride(cublasLtMatrixLayout_t desc, - int batch_size, - int64_t stride) { - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutSetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batch_size, - sizeof(batch_size))); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutSetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, - &stride, - sizeof(stride))); - } -}; - -struct MatmulGradDescriptor : MatmulDescriptor { - public: - MatmulGradDescriptor() {} - - template - void Create(const int64_t M, - const int64_t N, - const int64_t K, - const bool trans_x, - const bool trans_y, - phi::funcs::MatmulPlanner* planner, - const int batch_size = 1, - int64_t stride_x = 0, - int64_t stride_y = 0, - int64_t stride_out = 0, - bool grad_for_dx = true) { - using MT = typename phi::dtype::MPTypeTrait::Type; - cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType(); - cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType(); - cublasComputeType_t compute_type = GetCudaComputeType(); - - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type)); - this->SetFusedEpilogueOpDescriptor( - planner, trans_x, trans_y, TransX ? M : K); - - // Create operation desciriptor; see cublasLtMatmulDescAttributes_t for - // details about defaults; just need to set the transforms for A and B - this->CreateMatrixLayout(&x_desc, mat_type, N, M, true); - if (grad_for_dx) { - this->CreateMatrixLayout(&y_desc, mat_type, K, N, TransY); - this->CreateMatrixLayout( - &out_desc, phi::backends::gpu::ToCudaDataType(), M, K, TransX); - } else { - this->CreateMatrixLayout(&y_desc, mat_type, M, K, TransX); - this->CreateMatrixLayout( - &out_desc, phi::backends::gpu::ToCudaDataType(), K, N, TransY); - } - } - - void ExchangeXYDesc(bool no_exchange) { - if (no_exchange) { - return; - } - auto* temp = y_desc; - y_desc = x_desc; - x_desc = temp; - } -}; - -template -struct CublasLtBase { - public: - using MT = typename phi::dtype::MPTypeTrait::Type; - static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx, - size_t workspace_size) { - return phi::memory_utils::Alloc( - ctx.GetPlace(), - workspace_size, - phi::Stream(reinterpret_cast(ctx.stream()))); - } - - static void RunImpl(const phi::GPUContext& ctx, - MatmulDescT* desc, - const size_t sub_key, - const T* x_ptr, - const T* y_ptr, - OutT* out_ptr, - phi::funcs::MatmulPlanner* planner) { - MT alpha = static_cast(1); - MT beta = planner->UseAddTo() ? static_cast(1) : static_cast(0); - cublasLtHandle_t cublaslt_handle = ctx.cublaslt_handle(); - - // NOTE(limingshu): As workspace_size varies from different DL framework, - // I wonder is there any smarter idea for workspace setting, currently I - // just followed the settings from the NVIDIA colleague`s setting. - size_t workspace_size = static_cast(4) * 1024 * 1024; - phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size); - - if (planner != nullptr) { - if (phi::autotune::AutoTuneStatus::Instance().UseAutoTune() && - (!desc->is_cached)) { - SearchBestAlgo(ctx, - cublaslt_handle, - desc, - static_cast(&alpha), - static_cast(&beta), - y_ptr, - x_ptr, - out_ptr, - workspace->ptr(), - workspace_size); - MatmulDescT* best_desc = new MatmulDescT(*desc); - VLOG(6) << best_desc->GetDescResultString( - "[Searched CublasltDescriptor] "); - - auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul(); - cache.SetSubKey(sub_key, reinterpret_cast(best_desc)); - } - } - - VLOG(7) << desc->GetDescResultString("[Impl CublasltDescriptor] "); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmul(cublaslt_handle, - desc->op_desc, - static_cast(&alpha), - y_ptr, - desc->y_desc, - x_ptr, - desc->x_desc, - static_cast(&beta), - out_ptr, - desc->out_desc, - out_ptr, - desc->out_desc, - desc->algo, - workspace->ptr(), - workspace_size, - ctx.stream())); - } - - static void SearchBestAlgo(const phi::GPUContext& ctx, - const cublasLtHandle_t& lt_handle, - MatmulDescT* desc, - const void* alpha, - const void* beta, - const void* y_data, - const void* x_data, - void* out_data, - void* workspace_ptr, - size_t workspace_size) { - cublasLtMatmulPreference_t preference; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulPreferenceCreate(&preference)); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulPreferenceSetAttribute( - preference, - CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &workspace_size, - sizeof(workspace_size))); - - int returned_results = 0; - constexpr int requested_algo_count = 10; - std::vector heuristic_results( - requested_algo_count); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoGetHeuristic(lt_handle, - desc->op_desc, - desc->y_desc, - desc->x_desc, - desc->out_desc, - desc->out_desc, - preference, - requested_algo_count, - heuristic_results.data(), - &returned_results)); - PADDLE_ENFORCE_GT(returned_results, - 0, - phi::errors::Unavailable("No GEMM algorithm avaliable.")); - int best_algo_idx = -1; - if (returned_results == 1 || FLAGS_cublaslt_exhaustive_search_times <= 0) { - best_algo_idx = 0; - } else { - float min_time_cost = std::numeric_limits::max(); - for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) { - float cur_time_cost = - RunAndMeasureAlgo(ctx, - lt_handle, - desc, - alpha, - beta, - y_data, - x_data, - out_data, - workspace_ptr, - workspace_size, - &(heuristic_results[algo_idx].algo)); - VLOG(6) << "[MatmulWithCublaslt] algo[" << algo_idx - << "] time: " << cur_time_cost << " s"; - - if ((best_algo_idx == 0 && (1.05 * cur_time_cost < min_time_cost)) || - (cur_time_cost < min_time_cost)) { - best_algo_idx = algo_idx; - min_time_cost = cur_time_cost; - } - } - } - VLOG(6) << "[MatmulWithCublaslt] best_algo_idx: " << best_algo_idx; - - cublasLtMatmulAlgo_t* best_algo = desc->SetAlgo(); - *best_algo = heuristic_results[best_algo_idx].algo; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulPreferenceDestroy(preference)); - } - - static float RunAndMeasureAlgo(const phi::GPUContext& ctx, - const cublasLtHandle_t& lt_handle, - MatmulDescT* desc, - const void* alpha, - const void* beta, - const void* y_data, - const void* x_data, - void* out_data, - void* workspace_ptr, - size_t workspace_size, - cublasLtMatmulAlgo_t* algo) { - int repeats = FLAGS_cublaslt_exhaustive_search_times; - if (repeats <= 0) { - return std::numeric_limits::max(); - } - - phi::GpuTimer timer; - float time_cost = 0.f; - const auto& stream = ctx.stream(); - - for (int i = 0; i < repeats; ++i) { - timer.Start(stream); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmul(lt_handle, - desc->op_desc, - alpha, - y_data, - desc->y_desc, - x_data, - desc->x_desc, - beta, - out_data, - desc->out_desc, - out_data, - desc->out_desc, - algo, - workspace_ptr, - workspace_size, - stream)); - timer.Stop(stream); - ctx.Wait(); - auto time = timer.ElapsedTime(); - if (i > 0) { - // Exclude the warmup runtime. - time_cost += time; - } - } - return (time_cost / (repeats - 1)); - } -}; - -template <> -struct CublasLtBase { - public: - static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx, - size_t workspace_size) { - return phi::memory_utils::Alloc( - ctx.GetPlace(), - workspace_size, - phi::Stream(reinterpret_cast(ctx.stream()))); - } - - static void RunImpl(const phi::GPUContext& ctx, - MatmulDescriptor* desc, - const size_t sub_key, - const int8_t* x_ptr, - const int8_t* y_ptr, - int32_t* out_ptr, - phi::funcs::MatmulPlanner* planner) { - int32_t alpha = 1; - int32_t beta = - planner->UseAddTo() ? static_cast(1) : static_cast(0); - cublasLtHandle_t cublaslt_handle = ctx.cublaslt_handle(); - - size_t workspace_size = static_cast(4) * 1024 * 1024; - phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size); - - if (planner != nullptr) { - if (phi::autotune::AutoTuneStatus::Instance().UseAutoTune() && - (!desc->is_cached)) { - SearchBestAlgo(ctx, - cublaslt_handle, - desc, - static_cast(&alpha), - static_cast(&beta), - y_ptr, - x_ptr, - out_ptr, - workspace->ptr(), - workspace_size); - MatmulDescriptor* best_desc = new MatmulDescriptor(*desc); - VLOG(6) << best_desc->GetDescResultString( - "[Searched CublasltDescriptor] "); - - auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul(); - cache.SetSubKey(sub_key, reinterpret_cast(best_desc)); - } - } - - VLOG(7) << desc->GetDescResultString("[Impl CublasltDescriptor] "); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmul(cublaslt_handle, - desc->op_desc, - static_cast(&alpha), - y_ptr, - desc->y_desc, - x_ptr, - desc->x_desc, - static_cast(&beta), - out_ptr, - desc->out_desc, - out_ptr, - desc->out_desc, - desc->algo, - workspace->ptr(), - workspace_size, - ctx.stream())); - } - - static void SearchBestAlgo(const phi::GPUContext& ctx, - const cublasLtHandle_t& lt_handle, - MatmulDescriptor* desc, - const void* alpha, - const void* beta, - const void* y_data, - const void* x_data, - void* out_data, - void* workspace_ptr, - size_t workspace_size) { - cublasLtMatmulPreference_t preference; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulPreferenceCreate(&preference)); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulPreferenceSetAttribute( - preference, - CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &workspace_size, - sizeof(workspace_size))); - - int returned_results = 0; - constexpr int requested_algo_count = 10; - std::vector heuristic_results( - requested_algo_count); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoGetHeuristic(lt_handle, - desc->op_desc, - desc->y_desc, - desc->x_desc, - desc->out_desc, - desc->out_desc, - preference, - requested_algo_count, - heuristic_results.data(), - &returned_results)); - PADDLE_ENFORCE_GT(returned_results, - 0, - phi::errors::Unavailable("No GEMM algorithm avaliable.")); - int best_algo_idx = -1; - if (returned_results == 1 || FLAGS_cublaslt_exhaustive_search_times <= 0) { - best_algo_idx = 0; - } else { - float min_time_cost = std::numeric_limits::max(); - for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) { - float cur_time_cost = - RunAndMeasureAlgo(ctx, - lt_handle, - desc, - alpha, - beta, - y_data, - x_data, - out_data, - workspace_ptr, - workspace_size, - &(heuristic_results[algo_idx].algo)); - VLOG(6) << "[MatmulWithCublaslt] algo[" << algo_idx - << "] time: " << cur_time_cost << " s"; - - if ((best_algo_idx == 0 && (1.05 * cur_time_cost < min_time_cost)) || - (cur_time_cost < min_time_cost)) { - best_algo_idx = algo_idx; - min_time_cost = cur_time_cost; - } - } - } - VLOG(6) << "[MatmulWithCublaslt] best_algo_idx: " << best_algo_idx; - - cublasLtMatmulAlgo_t* best_algo = desc->SetAlgo(); - *best_algo = heuristic_results[best_algo_idx].algo; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulPreferenceDestroy(preference)); - } - - static float RunAndMeasureAlgo(const phi::GPUContext& ctx, - const cublasLtHandle_t& lt_handle, - MatmulDescriptor* desc, - const void* alpha, - const void* beta, - const void* y_data, - const void* x_data, - void* out_data, - void* workspace_ptr, - size_t workspace_size, - cublasLtMatmulAlgo_t* algo) { - int repeats = FLAGS_cublaslt_exhaustive_search_times; - if (repeats <= 0) { - return std::numeric_limits::max(); - } - - phi::GpuTimer timer; - float time_cost = 0.f; - const auto& stream = ctx.stream(); - - for (int i = 0; i < repeats; ++i) { - timer.Start(stream); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmul(lt_handle, - desc->op_desc, - alpha, - y_data, - desc->y_desc, - x_data, - desc->x_desc, - beta, - out_data, - desc->out_desc, - out_data, - desc->out_desc, - algo, - workspace_ptr, - workspace_size, - stream)); - timer.Stop(stream); - ctx.Wait(); - auto time = timer.ElapsedTime(); - if (i > 0) { - // Exclude the warmup runtime. - time_cost += time; - } - } - return (time_cost / (repeats - 1)); - } -}; - -// To judge if desc is cached or not. -template -struct DescriptorSetter { - public: - DescT desc; - size_t sub_key{std::numeric_limits::min()}; - - DescriptorSetter(phi::funcs::MatmulPlanner* planner, - const int64_t M, - const int64_t N, - const int64_t K, - const bool trans_x, - const bool trans_y, - const int batch_size = 1, - int64_t stride_x = 0, - int64_t stride_y = 0, - int64_t stride_out = 0, - const bool no_exchange = true, - bool grad_for_dx = true) { - if (std::is_same::value) { - if (!trans_x && !trans_y) { - PADDLE_ENFORCE_EQ( - (N % 4 == 0 || N == 1), - true, - phi::errors::InvalidArgument( - "The dimension size N used in int8 matmul must be 1 or a " - "multiple of 4 does not " - "match the size (%d) currently contained in the container.", - N)); - PADDLE_ENFORCE_EQ( - (K % 4 == 0), - true, - phi::errors::InvalidArgument( - "The dimension size K used in int8 matmul must be a multiple " - "of 4 does not " - "match the size (%d) currently contained in the container.", - K)); - } else if (!trans_x && trans_y) { - PADDLE_ENFORCE_EQ( - (K % 4 == 0), - true, - phi::errors::InvalidArgument( - "The dimension size K used in int8 matmul must be a multiple " - "of 4 does not " - "match the size (%d) currently contained in the container.", - K)); - } else if (trans_x && !trans_y) { - PADDLE_ENFORCE_EQ( - (M % 4 == 0 || M == 1), - true, - phi::errors::InvalidArgument( - "The dimension size M used in int8 matmul must be 1 or a " - "multiple of 4 does not " - "match the size (%d) currently contained in the container.", - M)); - PADDLE_ENFORCE_EQ( - (N % 4 == 0 || N == 1), - true, - phi::errors::InvalidArgument( - "The dimension size N used in int8 matmul must be 1 or a " - "multiple of 4 does not " - "match the size (%d) currently contained in the container.", - N)); - } else { - PADDLE_ENFORCE_EQ( - (M % 4 == 0 || M == 1), - true, - phi::errors::InvalidArgument( - "The dimension size M used in int8 matmul must be 1 or a " - "multiple of 4 does not " - "match the size (%d) currently contained in the container.", - M)); - PADDLE_ENFORCE_EQ( - (K % 4 == 0), - true, - phi::errors::InvalidArgument( - "The dimension size K used in int8 matmul must be a multiple " - "of 4 does not " - "match the size (%d) currently contained in the container.", - K)); - } - } - - if (planner != nullptr) { - sub_key = planner->GenSubKey(); - } - - auto& mamtul_cache = phi::autotune::AutoTuneCache::Instance().GetMatmul(); - if (mamtul_cache.FindSubKey(sub_key)) { - desc = *(reinterpret_cast(mamtul_cache.GetSubKey(sub_key))); - desc.template SetFusedEpiloguePtr(planner); - VLOG(7) << desc.GetDescResultString("[Heap CublasltDescriptor] "); - } else { - desc.template Create(M, - N, - K, - trans_x, - trans_y, - planner, - batch_size, - stride_x, - stride_y, - stride_out, - grad_for_dx); - desc.ExchangeXYDesc(no_exchange); - if (planner != nullptr) { - desc.template SetFusedEpiloguePtr(planner); - } - VLOG(7) << desc.GetDescResultString("[Stack CublasltDescriptor] ", false); - } - } -}; - -// For matmul with kernels autotune -template -struct MatmulWithCublasLt : public CublasLtBase { - public: - static void Run(const phi::GPUContext& ctx, - const T* x_data, - const T* y_data, - OutT* out_data, - const int64_t M, - const int64_t N, - const int64_t K, - const bool trans_x, - const bool trans_y, - phi::funcs::MatmulPlanner* planner = nullptr) { - auto setter = DescriptorSetter( - planner, M, N, K, trans_x, trans_y); - CublasLtBase::RunImpl( - ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner); - } - - static void RunWithBatch(const phi::GPUContext& ctx, - const T* x_data, - const T* y_data, - OutT* out_data, - const int64_t M, - const int64_t N, - const int64_t K, - bool trans_x, - bool trans_y, - int batch_size, - int64_t stride_x, - int64_t stride_y, - int64_t stride_out, - phi::funcs::MatmulPlanner* planner = nullptr) { - auto setter = DescriptorSetter(planner, - M, - N, - K, - trans_x, - trans_y, - batch_size, - stride_x, - stride_y, - stride_out); - CublasLtBase::RunImpl( - ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner); - } - - static void RunWithBatch(const phi::GPUContext& ctx, - const T** x_data, - const T** y_data, - OutT** out_data, - const int64_t M, - const int64_t N, - const int64_t K, - bool trans_x, - bool trans_y, - int batch_size, - phi::funcs::MatmulPlanner* planner = nullptr) { - for (int i = 0; i < batch_size; ++i) { - Run(ctx, - x_data[i], - y_data[i], - out_data[i], - M, - N, - K, - trans_x, - trans_y, - planner); - } - } -}; - -// As for just Linear fused ephilogue below: out = matmul(x, y) + bias. -template -struct LinearWithCublasLt : public CublasLtBase { - static void Run(const phi::GPUContext& ctx, - const phi::DenseTensor* x, - const phi::DenseTensor* y, - phi::DenseTensor* out, - const void* bias_data, - void* reserve_data, - const int64_t M, - const int64_t N, - const int64_t K, - const bool trans_x, - const bool trans_y, - const MatmulFusedType fused_type) { - auto planner = phi::funcs::MatmulPlanner(common::vectorize(x->dims()), - common::vectorize(y->dims()), - trans_x, - trans_y, - phi::CppTypeToDataType::Type(), - fused_type, - bias_data, - reserve_data); - auto setter = DescriptorSetter( - &planner, M, N, K, trans_x, trans_y); - CublasLtBase::RunImpl(ctx, - &setter.desc, - setter.sub_key, - x->data(), - y->data(), - out->data(), - &planner); - } -}; - -template -struct LinearGradWithCublasLt : public CublasLtBase { - static void Run( - const phi::GPUContext& ctx, - const phi::DenseTensor* x, - const phi::DenseTensor* y, - phi::DenseTensor* out, - const void* bias_data, - void* reserve_data, - const int64_t M, - const int64_t N, - const int64_t K, - const MatmulFusedType fused_type, - const bool trans_x, - const bool trans_y, - const bool use_addto, - const bool no_exchange, // exchange x_desc and y_desc for grad. - bool grad_for_dx = true) { - auto planner = phi::funcs::MatmulPlanner(common::vectorize(x->dims()), - common::vectorize(y->dims()), - trans_x, - trans_y, - phi::CppTypeToDataType::Type(), - fused_type, - bias_data, - reserve_data, - use_addto, - no_exchange); - auto setter = - DescriptorSetter( - &planner, - M, - N, - K, - trans_x, - trans_y, - /*batch_size=*/1, - /*stride_x=*/0, - /*stride_y=*/0, - /*stride_out=*/0, - /*exchange_x_y_desc=*/no_exchange, - /*grad_for_dx=*/grad_for_dx); - - // To setting data type for different kinda out_data. - if (grad_for_dx) { - CublasLtBase::RunImpl( - ctx, - &setter.desc, - setter.sub_key, - no_exchange ? x->data() : y->data(), - no_exchange ? y->data() : x->data(), - out->data(), - &planner); - } else { - CublasLtBase::RunImpl( - ctx, - &setter.desc, - setter.sub_key, - no_exchange ? x->data() : y->data(), - no_exchange ? y->data() : x->data(), - out->data(), - &planner); - } - } -}; -#else -// A void structure just for successfully compile. -struct MatmulPlanner {}; -#endif // (PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 && 0 - -} // namespace funcs -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/cublas.cc b/backends/metax_gpu/kernels/funcs/blas/cublas.cc deleted file mode 100644 index 77a0cced00b..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/cublas.cc +++ /dev/null @@ -1,40 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "cublas.h" // NOLINT - -namespace phi { -namespace dynload { -std::once_flag cublas_dso_flag; -void *cublas_dso_handle = nullptr; - -#define DEFINE_WRAP(__name) DynLoad__##__name __name - -CUBLAS_BLAS_ROUTINE_EACH(DEFINE_WRAP); - -#ifdef CUBLAS_BLAS_ROUTINE_EACH_R2 -CUBLAS_BLAS_ROUTINE_EACH_R2(DEFINE_WRAP); -#endif - -#ifdef CUBLAS_BLAS_ROUTINE_EACH_R3 -CUBLAS_BLAS_ROUTINE_EACH_R3(DEFINE_WRAP); -#endif - -#ifdef CUBLAS_BLAS_ROUTINE_EACH_R4 -CUBLAS_BLAS_ROUTINE_EACH_R4(DEFINE_WRAP); -#endif -} // namespace dynload -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/cublas.h b/backends/metax_gpu/kernels/funcs/blas/cublas.h deleted file mode 100755 index 776c7a1723b..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/cublas.h +++ /dev/null @@ -1,148 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -// clang-format off -#pragma once -#include -#include - -#include // NOLINT -#include - -#include "kernels/dynload/dynamic_loader.h" -#include "./port.h" // NOLINT -// clang-format on -namespace phi { -namespace dynload { - -extern std::once_flag cublas_dso_flag; -extern void* cublas_dso_handle; - -/** - * The following macro definition can generate structs - * (for each function) to dynamic load cublas routine - * via operator overloading. - * - * note: default dynamic linked libs - */ -#define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - inline auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ - using cublas_func = \ - decltype(::__name(std::declval()...)) (*)(Args...); \ - std::call_once(cublas_dso_flag, []() { \ - cublas_dso_handle = phi::dynload::GetCublasDsoHandle(); \ - }); \ - std::string replaced_name = #__name; \ - replaced_name = replaced_name.replace(0, 2, "mc"); \ - int index = replaced_name.find("_", 0); \ - if (index != -1) replaced_name = replaced_name.substr(0, index); \ - static void* p_##__name = \ - dlsym(cublas_dso_handle, replaced_name.c_str()); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ - }; \ - extern DynLoad__##__name __name - -#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ - __macro(cublasSaxpy_v2); \ - __macro(cublasDaxpy_v2); \ - __macro(cublasCaxpy_v2); \ - __macro(cublasZaxpy_v2); \ - __macro(cublasSscal_v2); \ - __macro(cublasDscal_v2); \ - __macro(cublasScopy_v2); \ - __macro(cublasDcopy_v2); \ - __macro(cublasSgemv_v2); \ - __macro(cublasDgemv_v2); \ - __macro(cublasCgemv_v2); \ - __macro(cublasZgemv_v2); \ - __macro(cublasSgemm_v2); \ - __macro(cublasDgemm_v2); \ - __macro(cublasCgemm_v2); \ - __macro(cublasZgemm_v2); \ - __macro(cublasHgemm); \ - __macro(cublasSgemmEx); \ - __macro(cublasSgeam); \ - __macro(cublasDgeam); \ - __macro(cublasStrsm_v2); \ - __macro(cublasDtrsm_v2); \ - __macro(cublasCtrsm_v2); \ - __macro(cublasZtrsm_v2); \ - __macro(cublasCreate_v2); \ - __macro(cublasDestroy_v2); \ - __macro(cublasSetStream_v2); \ - __macro(cublasSetPointerMode_v2); \ - __macro(cublasGetPointerMode_v2); \ - __macro(cublasSgemmBatched); \ - __macro(cublasDgemmBatched); \ - __macro(cublasCgemmBatched); \ - __macro(cublasZgemmBatched); \ - __macro(cublasStrsmBatched); \ - __macro(cublasDtrsmBatched); \ - __macro(cublasCtrsmBatched); \ - __macro(cublasZtrsmBatched); \ - __macro(cublasSgetrfBatched); \ - __macro(cublasSgetriBatched); \ - __macro(cublasDgetrfBatched); \ - __macro(cublasDgetriBatched); \ - __macro(cublasSmatinvBatched); \ - __macro(cublasDmatinvBatched); \ - __macro(cublasSgetrsBatched); \ - __macro(cublasDgetrsBatched); \ - __macro(cublasCgetrfBatched); \ - __macro(cublasCgetriBatched); \ - __macro(cublasCmatinvBatched); \ - __macro(cublasZgetrfBatched); \ - __macro(cublasZgetriBatched); \ - __macro(cublasZmatinvBatched); - -CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) - -// APIs available after CUDA 8.0 -#if CUDA_VERSION >= 8000 -#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) \ - __macro(cublasGemmEx); \ - __macro(cublasSgemmStridedBatched); \ - __macro(cublasDgemmStridedBatched); \ - __macro(cublasCgemmStridedBatched); \ - __macro(cublasZgemmStridedBatched); \ - __macro(cublasHgemmStridedBatched); - -CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) -#endif - -// APIs available after CUDA 9.0 -#if CUDA_VERSION >= 9000 -#define CUBLAS_BLAS_ROUTINE_EACH_R3(__macro) \ - __macro(cublasSetMathMode); \ - __macro(cublasGetMathMode); - -CUBLAS_BLAS_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) -#endif - -// APIs available after CUDA 9.1 -#if CUDA_VERSION >= 9010 -#define CUBLAS_BLAS_ROUTINE_EACH_R4(__macro) \ - __macro(cublasGemmBatchedEx); \ - __macro(cublasGemmStridedBatchedEx); - -CUBLAS_BLAS_ROUTINE_EACH_R4(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) -#endif - -#undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP -} // namespace dynload -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/cublasLt.cc b/backends/metax_gpu/kernels/funcs/blas/cublasLt.cc deleted file mode 100644 index 776f7fdd812..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/cublasLt.cc +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "cublasLt.h" - -namespace phi { -namespace dynload { -std::once_flag cublasLt_dso_flag; -void *cublasLt_dso_handle = nullptr; - -#define DEFINE_WRAP(__name) DynLoad__##__name __name - -CUBLASLT_BLAS_ROUTINE_EACH(DEFINE_WRAP); - -} // namespace dynload -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/cublasLt.h b/backends/metax_gpu/kernels/funcs/blas/cublasLt.h deleted file mode 100644 index 2f8a929dd0c..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/cublasLt.h +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include - -#include // NOLINT -#include - -#include "./port.h" -#include "kernels/dynload/dynamic_loader.h" - -namespace phi { -namespace dynload { - -extern std::once_flag cublasLt_dso_flag; -extern void* cublasLt_dso_handle; - -/** - * The following macro definition can generate structs - * (for each function) to dynamic load cublasLt routine - * via operator overloading. - * - * note: default dynamic linked libs - */ -#define DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - inline auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ - using cublasLt_func = \ - decltype(::__name(std::declval()...)) (*)(Args...); \ - std::call_once(cublasLt_dso_flag, []() { \ - cublasLt_dso_handle = phi::dynload::GetCublasLtDsoHandle(); \ - }); \ - std::string replaced_name = #__name; \ - replaced_name = replaced_name.replace(0, 2, "mc"); \ - static void* p_##__name = \ - dlsym(cublasLt_dso_handle, replaced_name.c_str()); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ - }; \ - extern DynLoad__##__name __name - -// APIs available after CUDA 11.1 -#if CUDA_VERSION >= 11010 -#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ - __macro(cublasLtCreate); \ - __macro(cublasLtDestroy); \ - __macro(cublasLtMatmul); \ - __macro(cublasLtMatmulDescCreate); \ - __macro(cublasLtMatmulDescDestroy); \ - __macro(cublasLtMatmulDescSetAttribute); \ - __macro(cublasLtMatmulDescGetAttribute); \ - __macro(cublasLtMatrixLayoutCreate); \ - __macro(cublasLtMatrixLayoutDestroy); \ - __macro(cublasLtMatrixLayoutSetAttribute); \ - __macro(cublasLtMatrixLayoutGetAttribute); \ - __macro(cublasLtMatmulPreferenceCreate); \ - __macro(cublasLtMatmulPreferenceDestroy); \ - __macro(cublasLtMatmulPreferenceSetAttribute); \ - __macro(cublasLtMatmulAlgoGetHeuristic); \ - __macro(cublasLtMatrixTransform); \ - __macro(cublasLtMatrixTransformDescCreate); \ - __macro(cublasLtMatrixTransformDescDestroy); \ - __macro(cublasLtMatrixTransformDescSetAttribute); \ - __macro(cublasLtMatmulAlgoInit); \ - __macro(cublasLtMatmulAlgoConfigSetAttribute); \ - __macro(cublasLtMatmulAlgoConfigGetAttribute); \ - __macro(cublasLtMatmulAlgoGetIds); \ - __macro(cublasLtMatmulAlgoCapGetAttribute); \ - __macro(cublasLtMatmulAlgoCheck); -// __macro(cublasLtGetCudartVersion); -#else -#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ - __macro(cublasLtCreate); \ - __macro(cublasLtDestroy); \ - __macro(cublasLtMatmul); \ - __macro(cublasLtMatmulDescCreate); \ - __macro(cublasLtMatmulDescDestroy); \ - __macro(cublasLtMatmulDescSetAttribute); \ - __macro(cublasLtMatmulDescGetAttribute); \ - __macro(cublasLtMatrixLayoutCreate); \ - __macro(cublasLtMatrixLayoutDestroy); \ - __macro(cublasLtMatrixLayoutSetAttribute); \ - __macro(cublasLtMatrixLayoutGetAttribute); \ - __macro(cublasLtMatmulPreferenceCreate); \ - __macro(cublasLtMatmulPreferenceDestroy); \ - __macro(cublasLtMatmulPreferenceSetAttribute); \ - __macro(cublasLtMatmulAlgoGetHeuristic); \ - __macro(cublasLtMatrixTransform); \ - __macro(cublasLtMatrixTransformDescCreate); \ - __macro(cublasLtMatrixTransformDescDestroy); \ - __macro(cublasLtMatrixTransformDescSetAttribute); -#endif - -CUBLASLT_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) -// #endif - -#undef DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP -} // namespace dynload -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/cublaslt.h b/backends/metax_gpu/kernels/funcs/blas/cublaslt.h deleted file mode 100755 index 24505567baf..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/cublaslt.h +++ /dev/null @@ -1,328 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include - -#include "./cublasLt.h" -#include "paddle/phi/common/float8_e4m3fn.h" -#include "paddle/phi/core/dense_tensor.h" - -namespace dyl = phi::dynload; - -namespace phi { - -struct CublasLtAlgoParam { - int algoId; - int swizzle; - int customOption; - int tile; - int splitK_val; - int reductionScheme; - int stages; - size_t workspace_size; -}; - -const std::map, CublasLtAlgoParam> AlgoParamCache{}; - -class CublasLtHelper { - public: - CublasLtHelper(int m, int k, int n, cublasLtHandle_t handle) - : handle_(handle), alpha_(1), beta_(0), m_(m), k_(k), n_(n) { - cublasStatus_t status; - - cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; - - // matmul desc - status = dyl::cublasLtMatmulDescCreate( - &matmul_desc_, cudaComputeType, CUDA_R_32I); - - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - common::errors::External( - "cublasLtMatmulDescCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); - cublasOperation_t op_transpose = CUBLAS_OP_T; - status = dyl::cublasLtMatmulDescSetAttribute(matmul_desc_, - CUBLASLT_MATMUL_DESC_TRANSA, - &op_transpose, - sizeof(op_transpose)); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - common::errors::External( - "cublasLtMatmulDescSetAttribute execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); - - // matrix desc - status = dyl::cublasLtMatrixLayoutCreate(&B_desc_, CUDA_R_8I, k, n, k); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - common::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); - - status = dyl::cublasLtMatrixLayoutCreate(&A_desc_, CUDA_R_8I, k, m, k); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - common::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); - - status = dyl::cublasLtMatrixLayoutCreate(&C_desc_, CUDA_R_32I, n, m, n); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - common::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); - -#if CUDA_VERSION >= 11020 - - int algoId = 21; - int swizzle = 0; - int customOption = 0; - int tile = 15; - int splitK_val = 0; - int reductionScheme = 0; - int stages = 23; - workspace_size_ = 0; - if (m >= 128) { - tile = 20; - stages = 17; - } - - std::tuple key(m_, k_, n_); - if (AlgoParamCache.count(key) != 0) { - auto value = AlgoParamCache.at(key); - algoId = value.algoId; - swizzle = value.swizzle; - customOption = value.customOption; - tile = value.tile; - splitK_val = value.splitK_val; - reductionScheme = value.reductionScheme; - stages = value.stages; - workspace_size_ = value.workspace_size; - } - - dyl::cublasLtMatmulAlgoInit(handle_, - cudaComputeType, - CUDA_R_32I, - CUDA_R_8I, - CUDA_R_8I, - CUDA_R_32I, - CUDA_R_32I, - algoId, - &algo_); - dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo_, - CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, - &(customOption), - sizeof(customOption)); - dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo_, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile)); - dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo_, - CUBLASLT_ALGO_CONFIG_SPLITK_NUM, - &(splitK_val), - sizeof(splitK_val)); - dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo_, - CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, - &(swizzle), - sizeof(swizzle)); - dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo_, - CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, - &(reductionScheme), - sizeof(int)); -#if CUDA_VERSION >= 11000 - dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo_, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); -#endif -#endif - } - ~CublasLtHelper() {} - - void GEMM(const int8_t* A_dev, - const int8_t* B_dev, - int32_t* C_dev, - cudaStream_t stream, - void* workspace = nullptr) { - cublasStatus_t status; - - status = dyl::cublasLtMatmul(handle_, - matmul_desc_, - &alpha_, - B_dev, - B_desc_, - A_dev, - A_desc_, - &beta_, - C_dev, - C_desc_, - C_dev, - C_desc_, -#if CUDA_VERSION >= 11020 - &algo_, - workspace, - workspace_size_, -#else - nullptr, - nullptr, - 0, -#endif - stream); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - common::errors::External( - "cublasLtMatmul execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); - } - - private: - cublasLtHandle_t handle_; - cublasLtMatmulDesc_t matmul_desc_; - cublasLtMatrixLayout_t A_desc_; - cublasLtMatrixLayout_t B_desc_; - cublasLtMatrixLayout_t C_desc_; - - cublasLtMatmulAlgo_t algo_; - - int32_t alpha_ = 1; - int32_t beta_ = 0; - - int m_ = 0; - int k_ = 0; - int n_ = 0; - - size_t workspace_size_ = 0; -}; - -template -inline cudaDataType_t GetCublasLtDataType() { - return CUDA_R_32F; -} - -template <> -inline cudaDataType_t GetCublasLtDataType() { - return CUDA_R_16F; -} - -template <> -inline cudaDataType_t GetCublasLtDataType() { - return CUDA_R_16BF; -} - -#if CUDA_VERSION >= 12010 -template -void CublasLtMatmulFP8(const phi::GPUContext& dev_ctx, - const phi::DenseTensor& mat_a, - const phi::DenseTensor& mat_b, - phi::DenseTensor* workspace, - phi::DenseTensor* out) { - int m = mat_a.dims()[0]; - int k = mat_a.dims()[1]; - int n = mat_b.dims()[1]; - - // init data structure - cublasStatus_t status; - auto A_type = CUDA_R_8F_E4M3; - auto B_type = CUDA_R_8F_E4M3; - auto C_type = GetCublasLtDataType(); - - cublasLtMatmulDesc_t matmul_desc_; - cublasLtMatrixLayout_t A_desc_; - cublasLtMatrixLayout_t B_desc_; - cublasLtMatrixLayout_t C_desc_; - float alpha_ = 1.0f; - float beta_ = 0.0f; - - cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32F; - status = - dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType, CUDA_R_32F); - cublasOperation_t op_transpose = CUBLAS_OP_T; - status = dyl::cublasLtMatmulDescSetAttribute(matmul_desc_, - CUBLASLT_MATMUL_DESC_TRANSA, - &op_transpose, - sizeof(op_transpose)); - status = dyl::cublasLtMatrixLayoutCreate(&B_desc_, B_type, k, n, k); - status = dyl::cublasLtMatrixLayoutCreate(&A_desc_, A_type, k, m, k); - status = dyl::cublasLtMatrixLayoutCreate(&C_desc_, C_type, n, m, n); - - // Need to use heuristic - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtMatmulPreference_t preference = NULL; - size_t work_space_size = workspace->numel(); - - status = dyl::cublasLtMatmulPreferenceCreate(&preference); - status = dyl::cublasLtMatmulPreferenceSetAttribute( - preference, - CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &work_space_size, - sizeof(work_space_size)); - - status = dyl::cublasLtMatmulAlgoGetHeuristic(dev_ctx.cublaslt_handle(), - matmul_desc_, - B_desc_, - A_desc_, - C_desc_, - C_desc_, - preference, - 1, - &heuristicResult, - &returnedResults); - - PADDLE_ENFORCE_NE(returnedResults, - 0, - common::errors::NotFound( - "Unable to find suitable cuBLAS GEMM algorithm")); - - status = - dyl::cublasLtMatmul(dev_ctx.cublaslt_handle(), - matmul_desc_, - &alpha_, - mat_b.data(), - B_desc_, - mat_a.data(), - A_desc_, - &beta_, - out->data(), - C_desc_, - out->data(), - C_desc_, - // nullptr, - &heuristicResult.algo, - // nullptr, - reinterpret_cast(workspace->data()), - // 0, - work_space_size, - dev_ctx.stream()); -} -#endif - -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/port.cc b/backends/metax_gpu/kernels/funcs/blas/port.cc deleted file mode 100644 index bc6d54e5c5f..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/port.cc +++ /dev/null @@ -1,163 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// clang-format off -#include "port.h" // NOLINT - -#include -#include -#include -#include -#include "glog/logging.h" -#if !defined(_WIN32) -#include // dladdr -#include -#include - -#else -#include // std::accumulate in msvc -// clang-format on -void *dlsym(void *handle, const char *symbol_name) { - FARPROC found_symbol; - found_symbol = GetProcAddress((HMODULE)handle, symbol_name); - - if (found_symbol == NULL) { - LOG(ERROR) << "Load symbol " << symbol_name << " failed."; - throw std::runtime_error(std::string(symbol_name) + " not found."); - } - return reinterpret_cast(found_symbol); -} - -void *dlopen(const char *filename, int flag) { - std::string file_name(filename); - HMODULE hModule = LoadLibrary(file_name.c_str()); - if (!hModule) { - if (flag) { - throw std::runtime_error(file_name + " not found."); - } else { - return nullptr; - } - } - return reinterpret_cast(hModule); -} - -int gettimeofday(struct timeval *tp, void *tzp) { - time_t clock; - struct tm tm; - SYSTEMTIME wtm; - - GetLocalTime(&wtm); - tm.tm_year = wtm.wYear - 1900; - tm.tm_mon = wtm.wMonth - 1; - tm.tm_mday = wtm.wDay; - tm.tm_hour = wtm.wHour; - tm.tm_min = wtm.wMinute; - tm.tm_sec = wtm.wSecond; - tm.tm_isdst = -1; - clock = mktime(&tm); - tp->tv_sec = clock; - tp->tv_usec = wtm.wMilliseconds * 1000; - - return (0); -} -#endif // !_WIN32 - -void ExecShellCommand(const std::string &cmd, std::string *message) { - std::array buffer; -#if !defined(_WIN32) - std::shared_ptr pipe(popen(cmd.c_str(), "r"), pclose); -#else - std::shared_ptr pipe(_popen(cmd.c_str(), "r"), _pclose); -#endif // _WIN32 - if (!pipe) { - LOG(ERROR) << "error running command: " << cmd; - return; - } - while (!feof(pipe.get())) { - if (fgets(buffer.data(), 128, pipe.get()) != nullptr) { - *message += buffer.data(); - } - } -} - -bool PathExists(const std::string &path) { -#if !defined(_WIN32) - struct stat statbuf; - if (stat(path.c_str(), &statbuf) != -1) { - if (S_ISDIR(statbuf.st_mode)) { - return true; - } - } -#else - struct _stat statbuf; - if (_stat(path.c_str(), &statbuf) != -1) { - if (S_ISDIR(statbuf.st_mode)) { - return true; - } - } -#endif // !_WIN32 - return false; -} - -#if !defined(_WIN32) -constexpr char kSEP = '/'; -#else -constexpr char kSEP = '\\'; -#endif // _WIN32 - -bool FileExists(const std::string &filepath) { -#if !defined(_WIN32) - struct stat buffer; - return (stat(filepath.c_str(), &buffer) == 0); -#else - struct _stat buffer; - return (_stat(filepath.c_str(), &buffer) == 0); -#endif // !_WIN32 -} - -std::string DirName(const std::string &filepath) { - auto pos = filepath.rfind(kSEP); - if (pos == std::string::npos) { - return ""; - } - return filepath.substr(0, pos); -} - -void MkDir(const char *path) { - std::string path_error(path); - path_error += " mkdir failed!"; -#if !defined(_WIN32) - if (mkdir(path, 0755)) { - if (errno != EEXIST) { - throw std::runtime_error(path_error); - } - } -#else - BOOL return_value = CreateDirectory(path, NULL); - if (!return_value) { - auto errorno = GetLastError(); - if (errorno != ERROR_ALREADY_EXISTS) { - throw std::runtime_error(path_error); - } - } -#endif // !_WIN32 -} - -void MkDirRecursively(const char *fullpath) { - if (*fullpath == '\0') return; // empty string - if (FileExists(fullpath)) return; - - MkDirRecursively(DirName(fullpath).c_str()); - MkDir(fullpath); -} diff --git a/backends/metax_gpu/kernels/funcs/blas/port.h b/backends/metax_gpu/kernels/funcs/blas/port.h deleted file mode 100644 index d2a59199bb7..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/port.h +++ /dev/null @@ -1,61 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h - -#if !defined(_WIN32) -#include // dladdr -#include - -#else -#ifndef NOMINMAX -#define NOMINMAX // msvc max/min macro conflict with std::min/max -#endif -// solve static linking error in windows -// https://github.com/google/glog/issues/301 -#define GOOGLE_GLOG_DLL_DECL -#include // _popen, _pclose -#include -#include -#include - -#ifndef S_ISDIR // windows port for sys/stat.h -#define S_ISDIR(mode) (((mode)&S_IFMT) == S_IFDIR) -#endif // S_ISDIR - -void *dlsym(void *handle, const char *symbol_name); - -void *dlopen(const char *filename, int flag); - -int gettimeofday(struct timeval *tp, void *tzp); -#endif // !_WIN32 - -void ExecShellCommand(const std::string &cmd, std::string *message); - -bool PathExists(const std::string &path); - -// TODO(yuyang18): If the functions below are needed by other files, move them -// to paddle::filesystem namespace. -bool FileExists(const std::string &filepath); - -std::string DirName(const std::string &filepath); - -void MkDir(const char *path); - -void MkDirRecursively(const char *fullpath); diff --git a/backends/metax_gpu/kernels/funcs/layer_norm_util.h b/backends/metax_gpu/kernels/funcs/layer_norm_util.h index 3e16e615b1d..0f8210d8b8f 100644 --- a/backends/metax_gpu/kernels/funcs/layer_norm_util.h +++ b/backends/metax_gpu/kernels/funcs/layer_norm_util.h @@ -18,7 +18,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/device_context.h" -#include "../funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" // clang-format on namespace phi { diff --git a/backends/metax_gpu/kernels/funcs/quant_dequant.h b/backends/metax_gpu/kernels/funcs/quant_dequant.h deleted file mode 100644 index 301ae351c40..00000000000 --- a/backends/metax_gpu/kernels/funcs/quant_dequant.h +++ /dev/null @@ -1,430 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -// clang-format off -#include -#include "paddle/common/hostdevice.h" -#include "paddle/phi/backends/gpu/gpu_launch_config.h" -#include "paddle/phi/common/float16.h" -#include "paddle/phi/common/transform.h" -#include "paddle/phi/kernels/funcs/aligned_vector.h" -#include "blas/blas.h" -// clang-format on -namespace phi { - -using backends::gpu::GpuLaunchConfig; - -constexpr int DequantKernelVecSize = 4; - -template -inline HOSTDEVICE T roundWithTiesToEven(T x) { - T xLower = floor(x); - T xUpper = ceil(x); - // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to - // even. - T dLower = x - xLower; - T dUpper = xUpper - x; - return static_cast( - (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper) - ? xLower - : xUpper); -} - -template -inline HOSTDEVICE T roundWithTiesAwayFromZero(T x) { - return static_cast(x > 0 ? ceil(x) : floor(x)); -} - -template -__forceinline__ __device__ int8_t quant_helper(const T input, - const float scale, - const int round_type, - const float max_bound, - const float min_bound) { - float quant_value = max_bound * scale * static_cast(input); - - if (round_type == 0) { - quant_value = static_cast(roundWithTiesToEven(quant_value)); - } else { - quant_value = static_cast(round(quant_value)); - } - quant_value = quant_value > max_bound ? max_bound : quant_value; - quant_value = quant_value < min_bound ? min_bound : quant_value; - return static_cast(quant_value); -} - -template -__forceinline__ __device__ int8_t -quant_helper_ties_to_even_or_away_from_zero(const T input, - const float scale, - const int round_type, - const float max_bound, - const float min_bound) { - float quant_value = max_bound * scale * static_cast(input); - - if (round_type == 0) { - quant_value = static_cast(roundWithTiesToEven(quant_value)); - } else { - quant_value = static_cast(roundWithTiesAwayFromZero(quant_value)); - } - quant_value = quant_value > max_bound ? max_bound : quant_value; - quant_value = quant_value < min_bound ? min_bound : quant_value; - return static_cast(quant_value); -} - -template -__global__ void QuantKernel(const T* input, - char4* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; - int m_id = blockIdx.y * blockDim.y + threadIdx.y; - - bool check = ((m_id < m) && (n_id < n)); - if (check) { - char4 tmp; - tmp.x = quant_helper( - input[m_id * n + n_id], scale, round_type, max_bound, min_bound); - tmp.y = quant_helper( - input[m_id * n + n_id + 1], scale, round_type, max_bound, min_bound); - tmp.z = quant_helper( - input[m_id * n + n_id + 2], scale, round_type, max_bound, min_bound); - tmp.w = quant_helper( - input[m_id * n + n_id + 3], scale, round_type, max_bound, min_bound); - output[(m_id * n + n_id) >> 2] = tmp; - } -} - -template -__global__ void QuantKernelWithVecSize(const T* input, - char4* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; - int m_id = blockIdx.y * blockDim.y + threadIdx.y; - - bool check = ((m_id < m) && (n_id < n)); - if (check) { - char4 tmp; - tmp.x = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id], scale, round_type, max_bound, min_bound); - tmp.y = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id + 1], scale, round_type, max_bound, min_bound); - tmp.z = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id + 2], scale, round_type, max_bound, min_bound); - tmp.w = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id + 3], scale, round_type, max_bound, min_bound); - output[(m_id * n + n_id) >> 2] = tmp; - } -} - -template -__global__ void QuantKernelWithVecSize(const T* input, - char3* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x) * 3; - int m_id = blockIdx.y * blockDim.y + threadIdx.y; - - bool check = ((m_id < m) && (n_id < n)); - if (check) { - char3 tmp; - tmp.x = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id], scale, round_type, max_bound, min_bound); - tmp.y = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id + 1], scale, round_type, max_bound, min_bound); - tmp.z = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id + 2], scale, round_type, max_bound, min_bound); - output[(m_id * n + n_id) / 3] = tmp; - } -} - -template -__global__ void QuantKernelWithVecSize(const T* input, - char2* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x) * 2; - int m_id = blockIdx.y * blockDim.y + threadIdx.y; - - bool check = ((m_id < m) && (n_id < n)); - if (check) { - char2 tmp; - tmp.x = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id], scale, round_type, max_bound, min_bound); - tmp.y = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id + 1], scale, round_type, max_bound, min_bound); - output[(m_id * n + n_id) >> 1] = tmp; - } -} - -template -__global__ void QuantKernelWithVecSize(const T* input, - char* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x); - int m_id = blockIdx.y * blockDim.y + threadIdx.y; - - bool check = ((m_id < m) && (n_id < n)); - if (check) { - char tmp; - tmp = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id], scale, round_type, max_bound, min_bound); - output[m_id * n + n_id] = tmp; - } -} - -template -void LaunchQuantKernel(const T* input, - int8_t* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound, - gpuStream_t stream) { - // TODO(minghaoBD): optimize the kennel launch times when m==1 or n==1 -#ifdef PADDLE_WITH_HIP - dim3 grid(((n >> 2) + 63) / 64, (m + 7) / 8); - dim3 block(64, 8); -#else - dim3 grid(((n >> 2) + 31) / 32, (m + 31) / 32); - dim3 block(32, 32); -#endif - - QuantKernel<<>>(input, - (char4*)output, // NOLINT - scale, - m, - n, - round_type, - max_bound, - min_bound); -} - -template -void LaunchQuantKernelWithVecSize(const T* input, - int8_t* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound, - gpuStream_t stream) { - int vec_size = 1; - if (n % 4 == 0) { - vec_size = 4; - } else if (n % 3 == 0) { - vec_size = 3; - } else if (n % 2 == 0) { - vec_size = 2; - } - -#ifdef PADDLE_WITH_HIP - dim3 grid(((n / vec_size) + 63) / 64, (m + 7) / 8); - dim3 block(64, 8); -#else - dim3 grid(((n / vec_size) + 31) / 32, (m + 31) / 32); - dim3 block(32, 32); -#endif - - switch (vec_size) { - case 4: - QuantKernelWithVecSize<<>>( - input, - reinterpret_cast(output), - scale, - m, - n, - round_type, - max_bound, - min_bound); - break; - case 3: - QuantKernelWithVecSize<<>>( - input, - reinterpret_cast(output), - scale, - m, - n, - round_type, - max_bound, - min_bound); - break; - case 2: - QuantKernelWithVecSize<<>>( - input, - reinterpret_cast(output), - scale, - m, - n, - round_type, - max_bound, - min_bound); - break; - case 1: - QuantKernelWithVecSize<<>>( - input, - reinterpret_cast(output), - scale, - m, - n, - round_type, - max_bound, - min_bound); - break; - default: - return; - } -} - -template -__global__ void DequantKernel(T* output, - const int32_t* input, - const int m, // batch size - const int n, // hidden - const float quant_in_scale, - const float* dequant_out_scale_data) { - int numel = m * n; - int stride = blockDim.x * gridDim.x * VecSize; - int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; - int col_id = idx % n; - - phi::AlignedVector in_vec; - phi::AlignedVector out_scale_vec; - phi::AlignedVector out_vec; - - for (; idx < numel; idx += stride) { - phi::Load(input + idx, &in_vec); - phi::Load(dequant_out_scale_data + col_id, &out_scale_vec); - -#pragma unroll - for (int i = 0; i < VecSize; ++i) { - out_vec[i] = - static_cast(static_cast(in_vec[i]) * out_scale_vec[i]); - } - - phi::Store(out_vec, output + idx); - } -} - -template -void LaunchDequantKernel(const int32_t* input, - T* output, - const int m, // m - const int n, // n - gpuStream_t stream, - GpuLaunchConfig* gpu_config, - const float quant_in_scale, - const float* dequant_out_scale_data) { - DequantKernel - <<block_per_grid, gpu_config->thread_per_block, 0, stream>>>( - output, input, m, n, quant_in_scale, dequant_out_scale_data); -} - -template -__global__ void DequantKernelWithScaleOfInputAndWeight( - T* output, - const int32_t* input, - const int m, // batch size - const int n, // hidden - const float quant_in_scale, - const float* quant_weight_scale, - float quant_max_bound) { - int numel = m * n; - int stride = blockDim.x * gridDim.x * VecSize; - int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; - int col_id = idx % n; - - phi::AlignedVector in_vec; - phi::AlignedVector out_scale_vec; - phi::AlignedVector out_vec; - - for (; idx < numel; idx += stride) { - phi::Load(input + idx, &in_vec); - phi::Load(quant_weight_scale + col_id, &out_scale_vec); - -#pragma unroll - for (int i = 0; i < VecSize; ++i) { - out_vec[i] = static_cast(static_cast(in_vec[i]) / - (quant_max_bound * quant_max_bound * - quant_in_scale * out_scale_vec[i])); - } - - phi::Store(out_vec, output + idx); - } -} - -template -void LaunchDequantKernelWithScaleOfInputAndWeight( - const int32_t* input, - T* output, - const int m, // m - const int n, // n - gpuStream_t stream, - GpuLaunchConfig* gpu_config, - const float quant_in_scale, - const float* quant_weight_scale, - float quant_max_bound) { - if (n % DequantKernelVecSize != 0) { - DequantKernelWithScaleOfInputAndWeight<<block_per_grid, - gpu_config->thread_per_block, - 0, - stream>>>(output, - input, - m, - n, - quant_in_scale, - quant_weight_scale, - quant_max_bound); - return; - } - DequantKernelWithScaleOfInputAndWeight - <<block_per_grid, gpu_config->thread_per_block, 0, stream>>>( - output, - input, - m, - n, - quant_in_scale, - quant_weight_scale, - quant_max_bound); -} - -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/softmax.cu b/backends/metax_gpu/kernels/funcs/softmax.cu index 44bfd02a308..a587f9ed016 100644 --- a/backends/metax_gpu/kernels/funcs/softmax.cu +++ b/backends/metax_gpu/kernels/funcs/softmax.cu @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "glog/logging.h" #include "kernels/metax_kernel/metax_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_dnn.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/softmax.h" #include "paddle/phi/kernels/funcs/softmax_impl.h" - namespace phi { namespace funcs { @@ -38,6 +38,7 @@ void SoftmaxCUDNNFunctor::operator()( ScopedTensorDescriptor yDesc; std::vector cudnn_tensor_dims = common::vectorize(X->dims()); DataLayout layout = DataLayout::kNCHW; + VLOG(0) << "Enter softmax Kernel22."; if (cudnn_tensor_dims.size() == 5) { layout = DataLayout::kNCDHW; } diff --git a/backends/metax_gpu/kernels/gpudnn/cudnn.cc b/backends/metax_gpu/kernels/gpudnn/cudnn.cc deleted file mode 100644 index dc403282c1c..00000000000 --- a/backends/metax_gpu/kernels/gpudnn/cudnn.cc +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/phi/backends/dynload/cudnn.h" // NOLINT - -#include "paddle/phi/core/enforce.h" - -namespace phi::dynload { - -std::once_flag cudnn_dso_flag; -void* cudnn_dso_handle = nullptr; - -#define DEFINE_WRAP(__name) DynLoad__##__name __name - -CUDNN_DNN_ROUTINE_EACH(DEFINE_WRAP); - -#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R7_LESS_R8 -CUDNN_DNN_ROUTINE_EACH_AFTER_R7_LESS_R8(DEFINE_WRAP); -#endif - -#ifdef CUDNN_DNN_ROUTINE_EACH_R7 -CUDNN_DNN_ROUTINE_EACH_R7(DEFINE_WRAP); -#endif - -#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7 -CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(DEFINE_WRAP); -#endif - -#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R7 -CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DEFINE_WRAP); -#endif - -#ifdef CUDNN_DNN_ROUTINE_EACH_R8 -CUDNN_DNN_ROUTINE_EACH_R8(DEFINE_WRAP); -#endif - -#ifdef CUDNN_DNN_ROUTINE_EACH_FRONTEND -CUDNN_DNN_ROUTINE_EACH_FRONTEND(DEFINE_WRAP); -#endif - -#ifdef CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9 -CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9(DEFINE_WRAP); -#endif - -#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9 -CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9(DEFINE_WRAP); -#endif - -#ifdef CUDNN_DNN_ROUTINE_EACH_R9 -CUDNN_DNN_ROUTINE_EACH_R9(DEFINE_WRAP); -#endif - -bool HasCUDNN() { - std::call_once(cudnn_dso_flag, - []() { cudnn_dso_handle = GetCUDNNDsoHandle(); }); - return cudnn_dso_handle != nullptr; -} - -void EnforceCUDNNLoaded(const char* fn_name) { - PADDLE_ENFORCE_NOT_NULL( - cudnn_dso_handle, - common::errors::PreconditionNotMet( - "Cannot load cudnn shared library. Cannot invoke method %s.", - fn_name)); -} - -} // namespace phi::dynload diff --git a/backends/metax_gpu/kernels/gpudnn/cudnn.h b/backends/metax_gpu/kernels/gpudnn/cudnn.h deleted file mode 100644 index 65cb6b338b7..00000000000 --- a/backends/metax_gpu/kernels/gpudnn/cudnn.h +++ /dev/null @@ -1,218 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#ifdef PADDLE_WITH_CUDA -#include - -#include // NOLINT - -#include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/common/port.h" - -namespace phi { -namespace dynload { - -extern std::once_flag cudnn_dso_flag; -extern void* cudnn_dso_handle; -extern bool HasCUDNN(); - -extern void EnforceCUDNNLoaded(const char* fn_name); -#define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ - using cudnn_func = decltype(&::__name); \ - std::call_once(cudnn_dso_flag, []() { \ - cudnn_dso_handle = phi::dynload::GetCUDNNDsoHandle(); \ - }); \ - EnforceCUDNNLoaded(#__name); \ - std::string replaced_name = #__name; \ - replaced_name = replaced_name.replace(0, 2, "mc"); \ - static void* p_##__name = \ - dlsym(cudnn_dso_handle, replaced_name.c_str()); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ - }; \ - extern struct DynLoad__##__name __name - -/** - * include all needed cudnn functions in HPPL - * different cudnn version has different interfaces - **/ -#define CUDNN_DNN_ROUTINE_EACH(__macro) \ - __macro(cudnnSetTensor4dDescriptor); \ - __macro(cudnnSetTensor4dDescriptorEx); \ - __macro(cudnnSetTensorNdDescriptor); \ - __macro(cudnnGetTensorNdDescriptor); \ - __macro(cudnnGetConvolutionNdForwardOutputDim); \ - __macro(cudnnCreateTensorDescriptor); \ - __macro(cudnnDestroyTensorDescriptor); \ - __macro(cudnnCreateFilterDescriptor); \ - __macro(cudnnSetFilter4dDescriptor); \ - __macro(cudnnSetFilterNdDescriptor); \ - __macro(cudnnGetFilterNdDescriptor); \ - __macro(cudnnSetPooling2dDescriptor); \ - __macro(cudnnSetPoolingNdDescriptor); \ - __macro(cudnnGetPoolingNdDescriptor); \ - __macro(cudnnDestroyFilterDescriptor); \ - __macro(cudnnCreateConvolutionDescriptor); \ - __macro(cudnnCreatePoolingDescriptor); \ - __macro(cudnnDestroyPoolingDescriptor); \ - __macro(cudnnSetConvolution2dDescriptor); \ - __macro(cudnnDestroyConvolutionDescriptor); \ - __macro(cudnnSetConvolutionNdDescriptor); \ - __macro(cudnnGetConvolutionNdDescriptor); \ - __macro(cudnnDeriveBNTensorDescriptor); \ - __macro(cudnnCreateSpatialTransformerDescriptor); \ - __macro(cudnnSetSpatialTransformerNdDescriptor); \ - __macro(cudnnDestroySpatialTransformerDescriptor); \ - __macro(cudnnSpatialTfGridGeneratorForward); \ - __macro(cudnnSpatialTfGridGeneratorBackward); \ - __macro(cudnnSpatialTfSamplerForward); \ - __macro(cudnnSpatialTfSamplerBackward); \ - __macro(cudnnCreate); \ - __macro(cudnnDestroy); \ - __macro(cudnnSetStream); \ - __macro(cudnnActivationForward); \ - __macro(cudnnActivationBackward); \ - __macro(cudnnConvolutionForward); \ - __macro(cudnnConvolutionBackwardBias); \ - __macro(cudnnGetConvolutionForwardWorkspaceSize); \ - __macro(cudnnTransformTensor); \ - __macro(cudnnPoolingForward); \ - __macro(cudnnPoolingBackward); \ - __macro(cudnnSoftmaxBackward); \ - __macro(cudnnSoftmaxForward); \ - __macro(cudnnGetVersion); \ - __macro(cudnnFindConvolutionForwardAlgorithmEx); \ - __macro(cudnnFindConvolutionBackwardFilterAlgorithmEx); \ - __macro(cudnnFindConvolutionBackwardFilterAlgorithm); \ - __macro(cudnnFindConvolutionBackwardDataAlgorithmEx); \ - __macro(cudnnGetErrorString); \ - __macro(cudnnCreateDropoutDescriptor); \ - __macro(cudnnDropoutGetStatesSize); \ - __macro(cudnnSetDropoutDescriptor); \ - __macro(cudnnRestoreDropoutDescriptor); \ - __macro(cudnnCreateRNNDescriptor); \ - __macro(cudnnGetRNNParamsSize); \ - __macro(cudnnGetRNNWorkspaceSize); \ - __macro(cudnnGetRNNTrainingReserveSize); \ - __macro(cudnnRNNForwardTraining); \ - __macro(cudnnRNNBackwardData); \ - __macro(cudnnRNNBackwardWeights); \ - __macro(cudnnRNNForwardInference); \ - __macro(cudnnDestroyDropoutDescriptor); \ - __macro(cudnnDestroyRNNDescriptor); \ - __macro(cudnnSetTensorNdDescriptorEx); \ - __macro(cudnnAddTensor); \ - __macro(cudnnConvolutionBackwardData); \ - __macro(cudnnConvolutionBackwardFilter); \ - __macro(cudnnGetConvolutionBackwardFilterWorkspaceSize); \ - __macro(cudnnGetConvolutionBackwardDataWorkspaceSize); \ - __macro(cudnnBatchNormalizationForwardTraining); \ - __macro(cudnnBatchNormalizationForwardInference); \ - __macro(cudnnBatchNormalizationBackward); \ - __macro(cudnnCreateActivationDescriptor); \ - __macro(cudnnSetActivationDescriptor); \ - __macro(cudnnGetActivationDescriptor); \ - __macro(cudnnDestroyActivationDescriptor); \ - __macro(cudnnSetRNNDescriptor_v6); -CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) - -#if CUDNN_VERSION >= 7000 && CUDNN_VERSION < 8000 -#define CUDNN_DNN_ROUTINE_EACH_AFTER_R7_LESS_R8(__macro) \ - __macro(cudnnGetConvolutionBackwardFilterAlgorithm); \ - __macro(cudnnGetConvolutionForwardAlgorithm); \ - __macro(cudnnGetConvolutionBackwardDataAlgorithm); \ - __macro(cudnnSetRNNDescriptor); -CUDNN_DNN_ROUTINE_EACH_AFTER_R7_LESS_R8(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) -#endif - -#if CUDNN_VERSION >= 7001 -#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \ - __macro(cudnnSetConvolutionGroupCount); \ - __macro(cudnnSetConvolutionMathType); \ - __macro(cudnnConvolutionBiasActivationForward); \ - __macro(cudnnCreateCTCLossDescriptor); \ - __macro(cudnnDestroyCTCLossDescriptor); \ - __macro(cudnnGetCTCLossDescriptor); \ - __macro(cudnnSetCTCLossDescriptor); \ - __macro(cudnnGetCTCLossWorkspaceSize); \ - __macro(cudnnCTCLoss); \ - __macro(cudnnGetConvolutionBackwardDataAlgorithm_v7); \ - __macro(cudnnGetConvolutionBackwardFilterAlgorithm_v7); \ - __macro(cudnnGetConvolutionForwardAlgorithm_v7); \ - __macro(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount); -CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) -#endif - -#if CUDNN_VERSION >= 7201 -#define CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(__macro) \ - __macro(cudnnCreateRNNDataDescriptor); \ - __macro(cudnnDestroyRNNDataDescriptor); \ - __macro(cudnnSetRNNDataDescriptor); \ - __macro(cudnnSetRNNPaddingMode); \ - __macro(cudnnRNNForwardTrainingEx); \ - __macro(cudnnRNNBackwardDataEx); \ - __macro(cudnnRNNBackwardWeightsEx); \ - __macro(cudnnRNNForwardInferenceEx); -CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) -#endif - -#if CUDNN_VERSION >= 7401 -#define CUDNN_DNN_ROUTINE_EACH_AFTER_R7(__macro) \ - __macro(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize); \ - __macro(cudnnBatchNormalizationForwardTrainingEx); \ - __macro(cudnnGetBatchNormalizationBackwardExWorkspaceSize); \ - __macro(cudnnBatchNormalizationBackwardEx); \ - __macro(cudnnGetBatchNormalizationTrainingExReserveSpaceSize); -CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) -#endif - -#if CUDNN_VERSION >= 8000 -#define CUDNN_DNN_ROUTINE_EACH_R8(__macro) \ - __macro(cudnnSetRNNDescriptor_v8); \ - __macro(cudnnCreateFusedOpsPlan); \ - __macro(cudnnCreateFusedOpsConstParamPack); \ - __macro(cudnnCreateFusedOpsVariantParamPack); \ - __macro(cudnnDestroyFusedOpsPlan); \ - __macro(cudnnDestroyFusedOpsConstParamPack); \ - __macro(cudnnDestroyFusedOpsVariantParamPack); \ - __macro(cudnnFusedOpsExecute); \ - __macro(cudnnSetFusedOpsConstParamPackAttribute); \ - __macro(cudnnSetFusedOpsVariantParamPackAttribute); \ - __macro(cudnnMakeFusedOpsPlan); -CUDNN_DNN_ROUTINE_EACH_R8(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) -#endif - -#ifdef PADDLE_WITH_CUDNN_FRONTEND -#define CUDNN_DNN_ROUTINE_EACH_FRONTEND(__macro) \ - __macro(cudnnBackendCreateDescriptor); \ - __macro(cudnnBackendDestroyDescriptor); \ - __macro(cudnnBackendExecute); \ - __macro(cudnnBackendFinalize); \ - __macro(cudnnBackendGetAttribute); \ - __macro(cudnnBackendSetAttribute); \ - __macro(cudnnGetStream); \ - __macro(cudnnReorderFilterAndBias); -CUDNN_DNN_ROUTINE_EACH_FRONTEND(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) -#endif - -} // namespace dynload -} // namespace phi - -#endif diff --git a/backends/metax_gpu/kernels/gpudnn/softmax_kernel_dnn.cu b/backends/metax_gpu/kernels/gpudnn/softmax_kernel_dnn.cu new file mode 100644 index 00000000000..b51f92c96a4 --- /dev/null +++ b/backends/metax_gpu/kernels/gpudnn/softmax_kernel_dnn.cu @@ -0,0 +1,70 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "kernels/gpudnn/softmax_gpudnn.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/softmax_kernel.h" + +namespace phi { + +template +void SoftmaxGPUDNNKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + DenseTensor* out) { + dev_ctx.template Alloc(out); + if (x.numel() == 0) return; + + const int rank = x.dims().size(); + // For 0D Tensor + if (rank == 0) { + phi::funcs::set_constant(dev_ctx, out, static_cast(1.0)); + return; + } + + SoftmaxForwardCUDAKernelDriver(dev_ctx, x, axis, out); +} + +} // namespace phi + +#ifdef PADDLE_WITH_HIP +PD_REGISTER_PLUGIN_KERNEL(softmax, + metax_gpu, + ALL_LAYOUT, + phi::SoftmaxGPUDNNKernel, + float, + phi::float16, + phi::bfloat16) {} +#else +#if CUDNN_VERSION_MIN(8, 1, 0) +PD_REGISTER_PLUGIN_KERNEL(softmax, + metax_gpu, + ALL_LAYOUT, + phi::SoftmaxGPUDNNKernel, + float, + double, + phi::float16, + phi::bfloat16) {} +#else +PD_REGISTER_PLUGIN_KERNEL(softmax, + metax_gpu, + ALL_LAYOUT, + phi::SoftmaxGPUDNNKernel, + float, + double, + phi::float16) {} +#endif +#endif diff --git a/backends/metax_gpu/kernels/impl/addmm_kernel_impl.h b/backends/metax_gpu/kernels/impl/addmm_kernel_impl.h index b517b719d49..a2c69b6adf0 100644 --- a/backends/metax_gpu/kernels/impl/addmm_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/addmm_kernel_impl.h @@ -20,7 +20,7 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/phi/kernels/addmm_kernel.h" -#include "../funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" // clang-format on diff --git a/backends/metax_gpu/kernels/impl/baddbmm_kernel_impl.h b/backends/metax_gpu/kernels/impl/baddbmm_kernel_impl.h index 593c044fc76..1c52ea22e4e 100644 --- a/backends/metax_gpu/kernels/impl/baddbmm_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/baddbmm_kernel_impl.h @@ -17,9 +17,9 @@ limitations under the License. */ #include #include "glog/logging.h" -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/kernels/baddbmm_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" diff --git a/backends/metax_gpu/kernels/impl/bilinear_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/bilinear_grad_kernel_impl.h index ef61d48202f..b64f94bc7ef 100644 --- a/backends/metax_gpu/kernels/impl/bilinear_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/bilinear_grad_kernel_impl.h @@ -14,8 +14,8 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { diff --git a/backends/metax_gpu/kernels/impl/bilinear_kernel_impl.h b/backends/metax_gpu/kernels/impl/bilinear_kernel_impl.h index c124e84eb6d..48861d48932 100644 --- a/backends/metax_gpu/kernels/impl/bilinear_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/bilinear_kernel_impl.h @@ -14,8 +14,8 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/utils/optional.h" diff --git a/backends/metax_gpu/kernels/impl/bmm_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/bmm_grad_kernel_impl.h index 543df3ee964..cd5978ae59f 100644 --- a/backends/metax_gpu/kernels/impl/bmm_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/bmm_grad_kernel_impl.h @@ -14,9 +14,9 @@ #pragma once -#include "kernels/funcs/blas/blas.h" -#include "kernels/impl/matmul_grad_kernel_impl.h" #include "paddle/phi/kernels/bmm_grad_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h" namespace phi { diff --git a/backends/metax_gpu/kernels/impl/bmm_kernel_impl.h b/backends/metax_gpu/kernels/impl/bmm_kernel_impl.h index 7b4164032b2..ce493b4908a 100644 --- a/backends/metax_gpu/kernels/impl/bmm_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/bmm_kernel_impl.h @@ -14,8 +14,8 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/bmm_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" namespace phi { diff --git a/backends/metax_gpu/kernels/impl/cholesky_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/cholesky_grad_kernel_impl.h index 02332652660..5d146dae8d5 100644 --- a/backends/metax_gpu/kernels/impl/cholesky_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/cholesky_grad_kernel_impl.h @@ -14,8 +14,8 @@ limitations under the License. */ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/cholesky_grad_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/for_range.h" namespace phi { diff --git a/backends/metax_gpu/kernels/impl/cholesky_solve_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/cholesky_solve_grad_kernel_impl.h index 62115e9ee6a..098092767c4 100644 --- a/backends/metax_gpu/kernels/impl/cholesky_solve_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/cholesky_solve_grad_kernel_impl.h @@ -14,7 +14,6 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/cholesky_solve_grad_kernel.h" #include "paddle/phi/kernels/cholesky_solve_kernel.h" #include "paddle/phi/kernels/complex_kernel.h" @@ -22,6 +21,7 @@ #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/expand_kernel.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/for_range.h" diff --git a/backends/metax_gpu/kernels/impl/conv_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/conv_grad_kernel_impl.h index 25e0d93a6a4..6066720ab07 100644 --- a/backends/metax_gpu/kernels/impl/conv_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/conv_grad_kernel_impl.h @@ -14,10 +14,10 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/batch_norm_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/im2col.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/vol2col.h" diff --git a/backends/metax_gpu/kernels/impl/conv_kernel_impl.h b/backends/metax_gpu/kernels/impl/conv_kernel_impl.h index 2cf5fa166e7..4395e5d5782 100644 --- a/backends/metax_gpu/kernels/impl/conv_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/conv_kernel_impl.h @@ -14,11 +14,11 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/conv_kernel.h" #include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/batch_norm_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/im2col.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/vol2col.h" diff --git a/backends/metax_gpu/kernels/impl/conv_transpose_kernel_impl.h b/backends/metax_gpu/kernels/impl/conv_transpose_kernel_impl.h index c7c002d4e9e..aadc5d2b8a0 100644 --- a/backends/metax_gpu/kernels/impl/conv_transpose_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/conv_transpose_kernel_impl.h @@ -14,12 +14,12 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/common/ddim.h" #include "paddle/common/layout.h" #include "paddle/phi/kernels/conv_transpose_kernel.h" #include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/im2col.h" #include "paddle/phi/kernels/funcs/slice.h" diff --git a/backends/metax_gpu/kernels/impl/deformable_conv_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/deformable_conv_grad_kernel_impl.h index d2419966342..b9931a89978 100644 --- a/backends/metax_gpu/kernels/impl/deformable_conv_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/deformable_conv_grad_kernel_impl.h @@ -14,11 +14,11 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/common/hostdevice.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/deformable_conv_functor.h" namespace phi { diff --git a/backends/metax_gpu/kernels/impl/elementwise.h b/backends/metax_gpu/kernels/impl/elementwise.h index 52a7709424b..b9f3d8af1c9 100644 --- a/backends/metax_gpu/kernels/impl/elementwise.h +++ b/backends/metax_gpu/kernels/impl/elementwise.h @@ -14,9 +14,9 @@ limitations under the License. */ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/eigen/common.h" diff --git a/backends/metax_gpu/kernels/impl/flatten2_kernel_impl.h b/backends/metax_gpu/kernels/impl/flatten2_kernel_impl.h index d4526922c7b..dc4059a7225 100644 --- a/backends/metax_gpu/kernels/impl/flatten2_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/flatten2_kernel_impl.h @@ -15,10 +15,10 @@ #pragma once #include -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/flatten_grad_kernel.h" #include "paddle/phi/kernels/flatten_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/flatten2_utils.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/backends/metax_gpu/kernels/impl/gru_unit_kernel_impl.h b/backends/metax_gpu/kernels/impl/gru_unit_kernel_impl.h index 0929a327035..ef12141f911 100644 --- a/backends/metax_gpu/kernels/impl/gru_unit_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/gru_unit_kernel_impl.h @@ -16,10 +16,10 @@ #include -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/activation_functor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/utils/optional.h" namespace phi { diff --git a/backends/metax_gpu/kernels/impl/index_select_impl.h b/backends/metax_gpu/kernels/impl/index_select_impl.h index 78284107d34..ac39cab2704 100644 --- a/backends/metax_gpu/kernels/impl/index_select_impl.h +++ b/backends/metax_gpu/kernels/impl/index_select_impl.h @@ -15,9 +15,9 @@ #pragma once #include "glog/logging.h" -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/backends/metax_gpu/kernels/impl/inverse_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/inverse_grad_kernel_impl.h index 85aff008b4e..64b56f2cd1c 100644 --- a/backends/metax_gpu/kernels/impl/inverse_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/inverse_grad_kernel_impl.h @@ -14,10 +14,10 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/complex_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/matrix_inverse.h" #include "paddle/phi/kernels/inverse_grad_kernel.h" diff --git a/backends/metax_gpu/kernels/impl/lstm_kernel_impl.h b/backends/metax_gpu/kernels/impl/lstm_kernel_impl.h index 079548b4ad0..4a061fe4716 100644 --- a/backends/metax_gpu/kernels/impl/lstm_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/lstm_kernel_impl.h @@ -15,8 +15,8 @@ #pragma once #include -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/detail/activation_functions.h" #include "paddle/phi/kernels/funcs/lstm_compute.h" #include "paddle/phi/kernels/funcs/lstm_utils.h" diff --git a/backends/metax_gpu/kernels/impl/lu_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/lu_grad_kernel_impl.h index e9ef47490bc..5a2e5d48a11 100644 --- a/backends/metax_gpu/kernels/impl/lu_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/lu_grad_kernel_impl.h @@ -14,7 +14,7 @@ #pragma once -#include "kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/impl/lu_kernel_impl.h" #include "paddle/phi/kernels/triangular_solve_kernel.h" diff --git a/backends/metax_gpu/kernels/impl/lu_solve_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/lu_solve_grad_kernel_impl.h index 21c711c53ef..24dee650dfe 100644 --- a/backends/metax_gpu/kernels/impl/lu_solve_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/lu_solve_grad_kernel_impl.h @@ -15,9 +15,9 @@ #pragma once #include "paddle/phi/infermeta/binary.h" -// #include "paddle/phi/kernels/funcs/blas/blas.h" +// #include "paddle/phi/paddle/phi/kernels/funcs/blas/blas.h" -#include "kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/matrix_solve.h" #include "paddle/phi/kernels/impl/lu_kernel_impl.h" diff --git a/backends/metax_gpu/kernels/impl/matmul_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/matmul_grad_kernel_impl.h deleted file mode 100644 index 823851666f1..00000000000 --- a/backends/metax_gpu/kernels/impl/matmul_grad_kernel_impl.h +++ /dev/null @@ -1,2042 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -// clang-format off -#include "glog/logging.h" - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/kernels/complex_kernel.h" -#include "paddle/phi/kernels/empty_kernel.h" -#include "paddle/phi/kernels/full_kernel.h" -#include "paddle/phi/kernels/funcs/reduce_function.h" -#include "paddle/phi/kernels/funcs/reduce_functor.h" -#include "paddle/phi/kernels/impl/dot_grad_kernel_impl.h" -// #include "paddle/phi/kernels/impl/matmul_kernel_impl.h" -#include "paddle/phi/kernels/reduce_sum_kernel.h" - -#include "../impl/matmul_kernel_impl.h" -// clang-format on - -#if defined(__NVCC__) || defined(__HIPCC__) -#include "paddle/phi/kernels/gpu/reduce.h" -#endif - -namespace phi { - -template -struct ReduceSumForMatmulGrad { - void operator()(const Context& dev_ctx, - const DenseTensor& input, - DenseTensor* output, - const std::vector& reduce_dims); -}; - -template -struct ReduceSumForMatmulGrad { - void operator()(const CPUContext& dev_ctx, - const DenseTensor& input, - DenseTensor* output, - const std::vector& reduce_dims) { - std::vector reduce_dims_tmp(reduce_dims.begin(), - reduce_dims.end()); - funcs::ReduceKernelImpl( - dev_ctx, input, output, reduce_dims_tmp, true, false); - } -}; - -#if defined(__NVCC__) || defined(__HIPCC__) -template -struct ReduceSumForMatmulGrad { - void operator()(const GPUContext& dev_ctx, - const DenseTensor& input, - DenseTensor* output, - const std::vector& reduce_dims) { - phi::SumKernel( - dev_ctx, input, reduce_dims, input.dtype(), false, output); - } -}; -#endif - -// Reshape a rank-3 tensor from P x M x N to (P * M) x N. -// Identity op if the tensor is not of rank 3. -static DenseTensor FoldInitDims(const DenseTensor& input) { - DenseTensor output = input; - auto in_dims = input.dims(); - if (in_dims.size() == 3) { - output.Resize({in_dims[0] * in_dims[1], in_dims[2]}); - } - return output; -} - -// Reshape a rank-3 tensor from P x M x N to M x (P * N). -// (Warning: This requires transposing data and writes into new memory.) -// Identity op if the tensor is not of rank 3. -template -static DenseTensor FoldHeadAndLastDims(const Context& dev_ctx, - const DenseTensor& input) { - auto in_dims = input.dims(); - if (in_dims.size() != 3) { - return input; - } - DenseTensor output = EmptyLike(dev_ctx, input); - output.Resize({in_dims[1], in_dims[0], in_dims[2]}); - std::vector axis = {1, 0, 2}; - funcs::Transpose trans; - trans(dev_ctx, input, &output, axis); - output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); - return output; -} - -template -typename std::enable_if::value>::type MatMul( - const Context& dev_ctx, - const DenseTensor& a, - bool trans_a, - const DenseTensor& b, - bool trans_b, - DenseTensor* out, - bool flag = false) { - dev_ctx.template Alloc(out); - auto blas = phi::funcs::GetBlas(dev_ctx); - auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a); - auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b); - if (a.dims().size() == 3 && b.dims().size() <= 2) { - // the transpose_X must be false, if is true, the transpose cost much time - if (!trans_a) { - mat_dim_a.height_ *= mat_dim_a.batch_size_; - mat_dim_a.batch_size_ = 0; - } - } - blas.MatMul(a.data(), - mat_dim_a, - b.data(), - mat_dim_b, - static_cast(1), - dev_ctx.template Alloc(out), - static_cast(flag)); -} - -/** - * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the - * original x_dim is returned. - */ -static DDim RowMatrixFromVector(const DDim& x_dim) { - if (x_dim.size() > 1) { - return x_dim; - } - return common::make_ddim({1, x_dim[0]}); -} - -/** - * Get column matrix shape from a vector shape. If the ran of y_dim > 1, the - * original y_dim is returned. - */ -static DDim ColumnMatrixFromVector(const DDim& y_dim) { - if (y_dim.size() > 1) { - return y_dim; - } - return common::make_ddim({y_dim[0], 1}); -} - -/** - * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor. - * - * The shape would be [BatchSize, H, W] or [H, W]. - * If transposed, `H,W` will be swapped. - */ -static void ReshapeTensorIntoMatrixSequence( - DenseTensor* x, const phi::funcs::MatDescriptor& descriptor) { - int64_t h, w; - h = descriptor.height_; - w = descriptor.width_; - if (descriptor.trans_) { - std::swap(w, h); - } - if (descriptor.batch_size_) { - x->Resize({descriptor.batch_size_, h, w}); - } else { - x->Resize({h, w}); - } -} - -static void ReshapeXYOutIntoMatrixSequence(DenseTensor* x, - DenseTensor* y, - DenseTensor* out, - bool trans_x, - bool trans_y) { - auto x_dim = RowMatrixFromVector(x->dims()); - auto y_dim = ColumnMatrixFromVector(y->dims()); - auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x); - auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y); - if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { - out->Resize({mat_dim_x.height_, mat_dim_y.width_}); - } else { - out->Resize({(std::max)(mat_dim_x.batch_size_, mat_dim_y.batch_size_), - mat_dim_x.height_, - mat_dim_y.width_}); - } - - ReshapeTensorIntoMatrixSequence(x, mat_dim_x); - ReshapeTensorIntoMatrixSequence(y, mat_dim_y); -} - -template -void CalcInputGrad(const Context& dev_ctx, - const DenseTensor& a, - bool trans_a, - bool is_fold_init_dims_a, - const DenseTensor& b, - bool trans_b, - bool is_fold_init_dims_b, - DenseTensor* out, - bool flag = false) { - if (out == nullptr) return; - bool need_combine = - (a.dims().size() == 3 || b.dims().size() == 3) && out->dims().size() == 2; - if (!need_combine) { - MatMul(dev_ctx, a, trans_a, b, trans_b, out, flag); - } else { - MatMul( - dev_ctx, - is_fold_init_dims_a ? FoldInitDims(a) - : FoldHeadAndLastDims(dev_ctx, a), - trans_a, - is_fold_init_dims_b ? FoldInitDims(b) - : FoldHeadAndLastDims(dev_ctx, b), - trans_b, - out, - flag); - } -} - -template -void MatmulGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out_grad, - bool transpose_x, - bool transpose_y, - DenseTensor* dx, - DenseTensor* dy) { - // get dims - std::vector x_dims = common::vectorize(x.dims()); - std::vector y_dims = common::vectorize(y.dims()); - std::vector dout_dims = common::vectorize(out_grad.dims()); - - int x_ndim = x_dims.size(); - int y_ndim = y_dims.size(); - int ndim = dout_dims.size(); - - // Case1 : x's or y's dim = 1 - if (x_ndim == 1 && y_ndim == 1) { - if (dx) dev_ctx.template Alloc(dx); - if (dy) dev_ctx.template Alloc(dy); - if (out_grad.numel() == 1) { - DotGradFunction()(dev_ctx, &x, &y, &out_grad, dx, dy); - return; - } - } - - bool is_broadcast = true; - if (x_ndim <= 2 || y_ndim <= 2) { - is_broadcast = false; - } else if (x_ndim != y_ndim) { - is_broadcast = true; - } else { - is_broadcast = !std::equal( - x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); - } - - // for complex - DenseTensor x_conj; - DenseTensor y_conj; - - // Case2: no broadcast or no batch size, it aims to speed and it is same as - // matmul in old version. - if (!is_broadcast) { - DenseTensor x_help = x; - DenseTensor y_help = y; - DenseTensor out_grad_help = out_grad; - - ReshapeXYOutIntoMatrixSequence( - &x_help, &y_help, &out_grad_help, transpose_x, transpose_y); - - DDim dx_dims; - if (dx) { - dx_dims = dx->dims(); - if (dx_dims != x_help.dims()) { - dx->Resize(x_help.dims()); - } - - y_conj = Conj(dev_ctx, y_help); - } - - DDim dy_dims; - if (dy) { - dy_dims = dy->dims(); - if (dy_dims != y_help.dims()) { - dy->Resize(y_help.dims()); - } - - x_conj = Conj(dev_ctx, x_help); - } - - if (transpose_x && transpose_y) { - CalcInputGrad( - dev_ctx, y_conj, true, true, out_grad_help, true, false, dx); - CalcInputGrad( - dev_ctx, out_grad_help, true, true, x_conj, true, false, dy); - } else if (transpose_x) { - CalcInputGrad( - dev_ctx, y_conj, false, false, out_grad_help, true, false, dx); - CalcInputGrad( - dev_ctx, x_conj, false, false, out_grad_help, false, true, dy); - } else if (transpose_y) { - CalcInputGrad( - dev_ctx, out_grad_help, false, false, y_conj, false, true, dx); - CalcInputGrad( - dev_ctx, out_grad_help, true, true, x_conj, false, true, dy); - } else { - CalcInputGrad( - dev_ctx, out_grad_help, false, false, y_conj, true, false, dx); - CalcInputGrad( - dev_ctx, x_conj, true, true, out_grad_help, false, true, dy); - } - - if (dx) { - if (dx_dims != x_help.dims()) { - dx->Resize(dx_dims); - } - } - if (dy) { - if (dy_dims != y_help.dims()) { - dy->Resize(dy_dims); - } - } - } else { - // Case3: broadcast. It need cost much time to reduce sum for the - // broadcast and wastes the memory. - // So we should avoid the case in reality. - VLOG(3) << "It need cost much time to reduce sum for the broadcast and " - "wastes the memory. So we should avoid the case in reality"; - x_conj = Conj(dev_ctx, x); - y_conj = Conj(dev_ctx, y); - - DenseTensor dx_help; - DenseTensor dy_help; - - if (transpose_x) { - if (transpose_y) { - // X'Y': dA = Y'G', dB = G'X' - if (dx) - MatMulFunction(dev_ctx, - y_conj, - out_grad, - y_dims, - dout_dims, - &dx_help, - true, - true); - if (dy) - MatMulFunction(dev_ctx, - out_grad, - x_conj, - dout_dims, - x_dims, - &dy_help, - true, - true); - } else { - // X'Y: dX = YG', dY = XG - if (dx) - MatMulFunction(dev_ctx, - y_conj, - out_grad, - y_dims, - dout_dims, - &dx_help, - false, - true); - if (dy) - MatMulFunction(dev_ctx, - x_conj, - out_grad, - x_dims, - dout_dims, - &dy_help, - false, - false); - } - } else { - if (transpose_y) { - // XY': dX = GY, dY = G'X - if (dx) - MatMulFunction(dev_ctx, - out_grad, - y_conj, - dout_dims, - y_dims, - &dx_help, - false, - false); - if (dy) - MatMulFunction(dev_ctx, - out_grad, - x_conj, - dout_dims, - x_dims, - &dy_help, - true, - false); - } else { - // XY: dX = GY', dY = X'G - if (dx) - MatMulFunction(dev_ctx, - out_grad, - y_conj, - dout_dims, - y_dims, - &dx_help, - false, - true); - if (dy) - MatMulFunction(dev_ctx, - x_conj, - out_grad, - x_dims, - dout_dims, - &dy_help, - true, - false); - } - } - - // get help dims - const std::vector dx_help_dims = - common::vectorize(dx_help.dims()); - const std::vector dy_help_dims = - common::vectorize(dy_help.dims()); - - std::vector dx_broadcast_dims(ndim); - std::vector dy_broadcast_dims(ndim); - - std::fill( - dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1); - std::fill( - dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1); - std::copy(x_dims.data(), - x_dims.data() + x_ndim, - dx_broadcast_dims.data() + ndim - x_ndim); - std::copy(y_dims.data(), - y_dims.data() + y_ndim, - dy_broadcast_dims.data() + ndim - y_ndim); - - std::vector dx_reduce_dims; - std::vector dy_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { - dx_reduce_dims.push_back(idx); - } - if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { - dy_reduce_dims.push_back(idx); - } - } - // reduce sum to get grad by ReduceSum - if (dx) { - if (dx_reduce_dims.empty()) { - *dx = std::move(dx_help); - } else { - ReduceSumForMatmulGrad()( - dev_ctx, dx_help, dx, dx_reduce_dims); - } - dx->Resize(x.dims()); - } - if (dy) { - if (dy_reduce_dims.empty()) { - *dy = std::move(dy_help); - } else { - ReduceSumForMatmulGrad()( - dev_ctx, dy_help, dy, dy_reduce_dims); - } - dy->Resize(y.dims()); - } - // Get the OutputGrad(out) - } -} - -template -void MatmulDoubleGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& dout, - const paddle::optional& ddx, - const paddle::optional& ddy, - bool transpose_x, - bool transpose_y, - DenseTensor* dx, - DenseTensor* dy, - DenseTensor* ddout) { - // Get dims from the input x, y, output_grad - std::vector x_dims = common::vectorize(x.dims()); - std::vector y_dims = common::vectorize(y.dims()); - std::vector dout_dims = common::vectorize(dout.dims()); - - int x_ndim = x_dims.size(); - int y_ndim = y_dims.size(); - int ndim = dout_dims.size(); - - // Case1 : x's or y's dim = 1 - if (x_ndim == 1 && y_ndim == 1) { - DotDoubleGradFunction()( - dev_ctx, &x, &y, &dout, &ddx, &ddy, dx, dy, ddout); - return; - } - - DenseTensor x_conj; - DenseTensor y_conj; - DenseTensor dout_conj; - - bool is_broadcast = true; - if (x_ndim <= 2 || y_ndim <= 2) { - is_broadcast = false; - } else if (x_ndim != y_ndim) { - is_broadcast = true; - } else { - is_broadcast = !std::equal( - x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); - } - - if (!is_broadcast) { - // Case2: no broadcast or no batch size - DenseTensor x_help = x; - DenseTensor y_help = y; - DenseTensor dout_help = dout; - ReshapeXYOutIntoMatrixSequence( - &x_help, &y_help, &dout_help, transpose_x, transpose_y); - DDim dx_dims; - - if (dx) { - dx_dims = dx->dims(); - if (dx_dims != x_help.dims()) { - dx->Resize(x_help.dims()); - } - } - - DDim dy_dims; - if (dy) { - dy_dims = dy->dims(); - if (dy_dims != y_help.dims()) { - dy->Resize(y_help.dims()); - } - } - - DDim ddout_dims; - if (ddout) { - ddout_dims = ddout->dims(); - if (ddout_dims != dout_help.dims()) { - ddout->Resize(dout_help.dims()); - } - - x_conj = Conj(dev_ctx, x_help); - y_conj = Conj(dev_ctx, y_help); - } - - if (dx || dy) { - dout_conj = Conj(dev_ctx, dout_help); - } - - bool ddout_flag = false; - if (ddx) { - auto ddx_mat = ddx.get(); - if (ddx_mat.dims() != x_help.dims()) { - ddx_mat.Resize(x_help.dims()); - } - if (dy) { - if (transpose_x && transpose_y) { - // dy = dout' * ddx' - CalcInputGrad( - dev_ctx, dout_conj, true, true, ddx_mat, true, false, dy, false); - } else if (transpose_x) { - // dy = ddx * dout - CalcInputGrad(dev_ctx, - ddx_mat, - false, - false, - dout_conj, - false, - true, - dy, - false); - } else if (transpose_y) { - // dy = dout' * ddx - CalcInputGrad( - dev_ctx, dout_conj, true, true, ddx_mat, false, true, dy, false); - } else { - // dy = ddx' * dout - CalcInputGrad( - dev_ctx, ddx_mat, true, true, dout_conj, false, true, dy, false); - } - } - - if (ddout) { - CalcInputGrad(dev_ctx, - ddx_mat, - transpose_x, - true, - y_conj, - transpose_y, - false, - ddout, - ddout_flag); - ddout_flag = true; - } - } else if (!ddx && dy) { - FullLikeKernel(dev_ctx, y, Scalar(0.0), y.dtype(), dy); - } - if (ddy) { - auto ddy_mat = ddy.get(); - if (ddy_mat.dims() != y_help.dims()) { - ddy_mat.Resize(y_help.dims()); - } - if (dx) { - if (transpose_x && transpose_y) { - // dx = ddy' * dout' - CalcInputGrad( - dev_ctx, ddy_mat, true, true, dout_conj, true, false, dx, false); - } else if (transpose_x) { - // dx = ddy * dout' - CalcInputGrad(dev_ctx, - ddy_mat, - false, - false, - dout_conj, - true, - false, - dx, - false); - } else if (transpose_y) { - // dx = dout * ddy - CalcInputGrad(dev_ctx, - dout_conj, - false, - false, - ddy_mat, - false, - true, - dx, - false); - } else { - // dx = dout * ddy' - CalcInputGrad(dev_ctx, - dout_conj, - false, - false, - ddy_mat, - true, - false, - dx, - false); - } - } - - if (ddout) { - CalcInputGrad(dev_ctx, - x_conj, - transpose_x, - true, - ddy_mat, - transpose_y, - false, - ddout, - ddout_flag); - } - } else if (!ddy && dx) { - FullLikeKernel(dev_ctx, x, Scalar(0.0), x.dtype(), dx); - } - if (ddout && !ddx && !ddy) { - FullLikeKernel( - dev_ctx, dout, Scalar(0.0), dout.dtype(), ddout); - } - - if (dx) { - if (dx_dims != x_help.dims()) { - dx->Resize(dx_dims); - } - } - - if (dy) { - if (dy_dims != y_help.dims()) { - dy->Resize(dy_dims); - } - } - - if (ddout) { - if (ddout_dims != dout_help.dims()) { - ddout->Resize(ddout_dims); - } - } - } else { - // Case3: broadcast. It need cost much time to reduce sum for the - // broadcast and wastes the memory. - // So we should avoid the case in reality. - VLOG(3) << "It need cost much time to reduce sum for the broadcast and " - "wastes the memory. So we should avoid the case in reality"; - if (dx || dy) { - dout_conj = Conj(dev_ctx, dout); - } - if (ddout) { - x_conj = Conj(dev_ctx, x); - y_conj = Conj(dev_ctx, y); - } - - DenseTensor dx_help; - DenseTensor dy_help; - - if (transpose_x) { - if (transpose_y) { - if (dx && ddy) { - MatMulFunction(dev_ctx, - ddy.get(), - dout_conj, - y_dims, - dout_dims, - &dx_help, - true, - true); - } - if (dy && ddx) { - MatMulFunction(dev_ctx, - dout_conj, - ddx.get(), - dout_dims, - x_dims, - &dy_help, - true, - true); - } - } else { - if (dx && ddy) { - MatMulFunction(dev_ctx, - ddy.get(), - dout_conj, - y_dims, - dout_dims, - &dx_help, - false, - true); - } - if (dy && ddx) { - MatMulFunction(dev_ctx, - ddx.get(), - dout_conj, - x_dims, - dout_dims, - &dy_help, - false, - false); - } - } - } else { - if (transpose_y) { - if (dx && ddy) { - MatMulFunction(dev_ctx, - dout_conj, - ddy.get(), - dout_dims, - y_dims, - &dx_help, - false, - false); - } - if (dy && ddx) { - MatMulFunction(dev_ctx, - dout_conj, - ddx.get(), - dout_dims, - x_dims, - &dy_help, - true, - false); - } - } else { - if (dx && ddy) { - MatMulFunction(dev_ctx, - dout_conj, - ddy.get(), - dout_dims, - y_dims, - &dx_help, - false, - true); - } - if (dy && ddx) { - MatMulFunction(dev_ctx, - ddx.get(), - dout_conj, - x_dims, - dout_dims, - &dy_help, - true, - false); - } - } - } - - // get help dims - const std::vector dx_help_dims = - common::vectorize(dx_help.dims()); - const std::vector dy_help_dims = - common::vectorize(dy_help.dims()); - - std::vector dx_broadcast_dims(ndim); - std::vector dy_broadcast_dims(ndim); - - std::fill( - dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1); - std::fill( - dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1); - std::copy(x_dims.data(), - x_dims.data() + x_ndim, - dx_broadcast_dims.data() + ndim - x_ndim); - std::copy(y_dims.data(), - y_dims.data() + y_ndim, - dy_broadcast_dims.data() + ndim - y_ndim); - - std::vector dx_reduce_dims; - std::vector dy_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { - dx_reduce_dims.push_back(idx); - } - if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { - dy_reduce_dims.push_back(idx); - } - } - // Reduce sum to get grad by ReduceSum - if (dx && dx_help.initialized()) { - if (dx_reduce_dims.empty()) { - *dx = std::move(dx_help); - } else { - ReduceSumForMatmulGrad()( - dev_ctx, dx_help, dx, dx_reduce_dims); - } - dx->Resize(x.dims()); - } else if (dx && !dx_help.initialized()) { - FullLikeKernel(dev_ctx, x, Scalar(0.0), x.dtype(), dx); - } - if (dy && dy_help.initialized()) { - if (dy_reduce_dims.empty()) { - *dy = std::move(dy_help); - } else { - ReduceSumForMatmulGrad()( - dev_ctx, dy_help, dy, dy_reduce_dims); - } - dy->Resize(y.dims()); - } else if (dy && !dy_help.initialized()) { - FullLikeKernel(dev_ctx, y, Scalar(0.0), y.dtype(), dy); - } - - if (ddout) { - // Calculate the gradient of OutputGrad(Out) - if (ddx) { - MatMulFunction(dev_ctx, - ddx.get(), - y_conj, - x_dims, - y_dims, - ddout, - transpose_x, - transpose_y); - } - - if (ddy) { - MatMulFunction(dev_ctx, - x_conj, - ddy.get(), - x_dims, - y_dims, - ddout, - transpose_x, - transpose_y, - true); - } - } - } -} - -template -void MatmulTripleGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& dout, - const paddle::optional& ddx, - const paddle::optional& ddy, - const paddle::optional& d_dx, - const paddle::optional& d_dy, - const paddle::optional& d_ddout, - bool transpose_x, - bool transpose_y, - DenseTensor* out_d_x, - DenseTensor* out_d_y, - DenseTensor* out_d_dout, - DenseTensor* out_d_ddx, - DenseTensor* out_d_ddy) { - // Get dims from the input x, y, output_grad - std::vector x_dims = common::vectorize(x.dims()); - std::vector y_dims = common::vectorize(y.dims()); - std::vector dout_dims = common::vectorize(dout.dims()); - - int x_ndim = x_dims.size(); - int y_ndim = y_dims.size(); - int ndim = dout_dims.size(); - - // Case1 : x's and y's dim = 1 - if (x_ndim == 1 && y_ndim == 1) { - VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 1"; - DotTripleGradFunction()(dev_ctx, - &x, - &y, - &dout, - &ddx, - &ddy, - &d_dx, - &d_dy, - &d_ddout, - out_d_x, - out_d_y, - out_d_dout, - out_d_ddx, - out_d_ddy); - return; - } - - DenseTensor x_conj; - DenseTensor y_conj; - DenseTensor dout_conj; - DenseTensor ddx_conj; - DenseTensor ddy_conj; - - bool is_broadcast = true; - if (x_ndim <= 2 || y_ndim <= 2) { - is_broadcast = false; - } else if (x_ndim != y_ndim) { - is_broadcast = true; - } else { - is_broadcast = !std::equal( - x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); - } - - if (!is_broadcast) { - // Case2: no broadcast or no batch size - VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 2"; - DenseTensor x_help = x; - DenseTensor y_help = y; - DenseTensor dout_help = dout; - - DenseTensor ddx_help; - DenseTensor ddy_help; - ReshapeXYOutIntoMatrixSequence( - &x_help, &y_help, &dout_help, transpose_x, transpose_y); - if (ddx) { - ddx_help = ddx.get(); - if (ddx_help.dims() != x_help.dims()) { - ddx_help.Resize(x_help.dims()); - } - } - - if (ddy) { - ddy_help = ddy.get(); - if (ddy_help.dims() != y_help.dims()) { - ddy_help.Resize(y_help.dims()); - } - } - - DDim out_dx_dims; - if (out_d_x) { - out_dx_dims = out_d_x->dims(); - if (out_dx_dims != x_help.dims()) { - out_d_x->Resize(x_help.dims()); - } - if (ddy) { - ddy_conj = Conj(dev_ctx, ddy_help); - } - } - DDim out_dy_dims; - if (out_d_y) { - out_dy_dims = out_d_y->dims(); - if (out_dy_dims != y_help.dims()) { - out_d_y->Resize(y_help.dims()); - } - if (ddx) { - ddx_conj = Conj(dev_ctx, ddx_help); - } - } - DDim out_d_dout_dims; - if (out_d_dout) { - out_d_dout_dims = out_d_dout->dims(); - if (out_d_dout_dims != dout_help.dims()) { - out_d_dout->Resize(dout_help.dims()); - } - if (ddx && !ddx_conj.IsInitialized()) { - ddx_conj = Conj(dev_ctx, ddx_help); - } - if (ddy && !ddy_conj.IsInitialized()) { - ddy_conj = Conj(dev_ctx, ddy_help); - } - } - DDim out_d_ddx_dims; - if (out_d_ddx) { - out_d_ddx_dims = out_d_ddx->dims(); - if (out_d_ddx_dims != x_help.dims()) { - out_d_ddx->Resize(x_help.dims()); - } - dout_conj = Conj(dev_ctx, dout_help); - y_conj = Conj(dev_ctx, y_help); - } - DDim out_d_ddy_dims; - if (out_d_ddy) { - out_d_ddy_dims = out_d_ddy->dims(); - if (out_d_ddy_dims != y_help.dims()) { - out_d_ddy->Resize(y_help.dims()); - } - if (!dout_conj.IsInitialized()) { - dout_conj = Conj(dev_ctx, dout_help); - } - x_conj = Conj(dev_ctx, x_help); - } - - bool d_dout_flag = false; - bool d_ddx_flag = false; - bool d_ddy_flag = false; - if (d_ddout) { - auto d_ddout_mat = d_ddout.get(); - if (d_ddout_mat.dims() != dout_help.dims()) { - d_ddout_mat.Resize(dout_help.dims()); - } - - if (out_d_y && ddx) { - if (transpose_x && transpose_y) { - // out_d_y = d_ddout' * ddx' - CalcInputGrad(dev_ctx, - d_ddout_mat, - true, - true, - ddx_conj, - true, - false, - out_d_y, - false); - } else if (transpose_x) { - // out_d_y = ddx * d_ddout - CalcInputGrad(dev_ctx, - ddx_conj, - false, - false, - d_ddout_mat, - false, - true, - out_d_y, - false); - } else if (transpose_y) { - // out_d_y = d_ddout' * ddx - CalcInputGrad(dev_ctx, - d_ddout_mat, - true, - true, - ddx_conj, - false, - true, - out_d_y, - false); - } else { - // out_d_y = ddx' * d_ddout - CalcInputGrad(dev_ctx, - ddx_conj, - true, - true, - d_ddout_mat, - false, - true, - out_d_y, - false); - } - } else if (out_d_y) { - FullLikeKernel(dev_ctx, y, Scalar(0.0), y.dtype(), out_d_y); - } - if (out_d_x && ddy) { - if (transpose_x && transpose_y) { - // out_d_x = ddy' * d_ddout' - CalcInputGrad(dev_ctx, - ddy_conj, - true, - true, - d_ddout_mat, - true, - false, - out_d_x, - false); - } else if (transpose_x) { - // out_d_x = ddy * d_ddout' - CalcInputGrad(dev_ctx, - ddy_conj, - false, - false, - d_ddout_mat, - true, - false, - out_d_x, - false); - } else if (transpose_y) { - // out_d_x = d_ddout * ddy - CalcInputGrad(dev_ctx, - d_ddout_mat, - false, - false, - ddy_conj, - false, - true, - out_d_x, - false); - } else { - // out_d_x = d_ddout * ddy' - CalcInputGrad(dev_ctx, - d_ddout_mat, - false, - false, - ddy_conj, - true, - false, - out_d_x, - false); - } - } else if (out_d_x) { - FullLikeKernel(dev_ctx, x, Scalar(0.0), x.dtype(), out_d_x); - } - - // equations: - // d_ddx = DOut * D_DY + Y * D_DDOut - // Let: d_ddx1 = Y * D_DDOut - // Let: d_ddx2 = DOut * D_DY - - // d_ddy = DOut * D_DX + X * D_DDOut - // Let: d_ddy1 = X * D_DDOut - // Let: d_ddy2 = DOut * D_DX - - // d_dout = DDY * D_DX + DDX * D_DY - // Let: d_dout1 = DDX * D_DY - // Let: d_dout2 = DDY * D_DX - - // compute d_ddx1 - if (out_d_ddx) { - if (transpose_x && transpose_y) { - // out_d_ddx1 = y' * d_ddout' - CalcInputGrad(dev_ctx, - y_conj, - true, - true, - d_ddout_mat, - true, - false, - out_d_ddx, - d_ddx_flag); - } else if (transpose_x) { - // out_d_ddx1 = y * d_ddout' - CalcInputGrad(dev_ctx, - y_conj, - false, - false, - d_ddout_mat, - true, - false, - out_d_ddx, - d_ddx_flag); - } else if (transpose_y) { - // out_d_ddx1 = d_ddout * y - CalcInputGrad(dev_ctx, - d_ddout_mat, - false, - false, - y_conj, - false, - true, - out_d_ddx, - d_ddx_flag); - } else { - // out_d_ddx1 = d_ddout * y' - CalcInputGrad(dev_ctx, - d_ddout_mat, - false, - false, - y_conj, - true, - false, - out_d_ddx, - d_ddx_flag); - } - d_ddx_flag = true; - } - - // compute d_ddy1 - if (out_d_ddy) { - if (transpose_x && transpose_y) { - // out_d_ddy1 = d_ddout' * x' - CalcInputGrad(dev_ctx, - d_ddout_mat, - true, - true, - x_conj, - true, - false, - out_d_ddy, - false); - } else if (transpose_x) { - // out_d_ddy1 = x * d_ddout - CalcInputGrad(dev_ctx, - x_conj, - false, - false, - d_ddout_mat, - false, - true, - out_d_ddy, - false); - } else if (transpose_y) { - // out_d_ddy1 = d_ddout' * x - CalcInputGrad(dev_ctx, - d_ddout_mat, - true, - true, - x_conj, - false, - true, - out_d_ddy, - false); - } else { - // out_d_ddy1 = x' * d_ddout - CalcInputGrad(dev_ctx, - x_conj, - true, - true, - d_ddout_mat, - false, - true, - out_d_ddy, - false); - } - d_ddy_flag = true; - } - } else { - // d_ddout is none - if (out_d_x) { - FullLikeKernel(dev_ctx, x, Scalar(0.0), x.dtype(), out_d_x); - } - - if (out_d_y) { - FullLikeKernel(dev_ctx, y, Scalar(0.0), y.dtype(), out_d_y); - } - } - - if (d_dy) { - auto d_dy_mat = d_dy.get(); - if (d_dy_mat.dims() != y_help.dims()) { - d_dy_mat.Resize(y_help.dims()); - } - - // compute d_dout1 - if (out_d_dout && ddx) { - CalcInputGrad(dev_ctx, - ddx_conj, - transpose_x, - true, - d_dy_mat, - transpose_y, - false, - out_d_dout, - d_dout_flag); - d_dout_flag = true; - } - - // compute d_ddx2 - if (out_d_ddx) { - if (transpose_x && transpose_y) { - // out_d_ddx2 = D_DY' * DOut' - CalcInputGrad(dev_ctx, - d_dy_mat, - true, - true, - dout_conj, - true, - false, - out_d_ddx, - d_ddx_flag); - } else if (transpose_x) { - // out_d_ddx2 = D_DY * Dout' - CalcInputGrad(dev_ctx, - d_dy_mat, - false, - false, - dout_conj, - true, - false, - out_d_ddx, - d_ddx_flag); - } else if (transpose_y) { - // out_d_ddx2 = Dout * D_DY - CalcInputGrad(dev_ctx, - dout_conj, - false, - false, - d_dy_mat, - false, - true, - out_d_ddx, - d_ddx_flag); - } else { - // out_d_ddx2 = Dout * D_DY' - CalcInputGrad(dev_ctx, - dout_conj, - false, - false, - d_dy_mat, - true, - false, - out_d_ddx, - d_ddx_flag); - } - } - } - - if (d_dx) { - auto d_dx_mat = d_dx.get(); - if (d_dx_mat.dims() != x_help.dims()) { - d_dx_mat.Resize(x_help.dims()); - } - - // compute d_dout2 - if (out_d_dout && ddy) { - CalcInputGrad(dev_ctx, - d_dx_mat, - transpose_x, - true, - ddy_conj, - transpose_y, - false, - out_d_dout, - d_dout_flag); - } - - // compute d_ddy2 - if (out_d_ddy) { - if (transpose_x && transpose_y) { - // out_d_ddy2 = dout' * d_dx' - CalcInputGrad(dev_ctx, - dout_conj, - true, - true, - d_dx_mat, - true, - false, - out_d_ddy, - d_ddy_flag); - } else if (transpose_x) { - // out_d_ddy2 = d_dx * dout - CalcInputGrad(dev_ctx, - d_dx_mat, - false, - false, - dout_conj, - false, - true, - out_d_ddy, - d_ddy_flag); - } else if (transpose_y) { - // out_d_ddy2 = dout' * d_dx - CalcInputGrad(dev_ctx, - dout_conj, - true, - true, - d_dx_mat, - false, - true, - out_d_ddy, - d_ddy_flag); - } else { - // out_d_ddy2 = d_dx' * dout - CalcInputGrad(dev_ctx, - d_dx_mat, - true, - true, - dout_conj, - false, - true, - out_d_ddy, - d_ddy_flag); - } - } - } - - if (out_d_x) { - if (out_dx_dims != x_help.dims()) { - out_d_x->Resize(out_dx_dims); - } - } - - if (out_d_y) { - if (out_dy_dims != y_help.dims()) { - out_d_y->Resize(out_dy_dims); - } - } - - if (out_d_dout) { - if (out_d_dout_dims != dout_help.dims()) { - out_d_dout->Resize(out_d_dout_dims); - } - } - - if (out_d_ddx) { - if (out_d_ddx_dims != x_help.dims()) { - out_d_ddx->Resize(out_d_ddx_dims); - } - } - - if (out_d_ddy) { - if (out_d_ddy_dims != y_help.dims()) { - out_d_ddy->Resize(out_d_ddy_dims); - } - } - - if (out_d_dout && !out_d_dout->IsInitialized()) { - FullLikeKernel( - dev_ctx, dout, Scalar(0.0), dout.dtype(), out_d_dout); - } - - if (out_d_ddx && !out_d_ddx->IsInitialized()) { - FullLikeKernel(dev_ctx, x, Scalar(0.0), x.dtype(), out_d_ddx); - } - - if (out_d_ddy && !out_d_ddy->IsInitialized()) { - FullLikeKernel(dev_ctx, y, Scalar(0.0), y.dtype(), out_d_ddy); - } - } else { - // Case3: broadcast. It need cost much time to reduce sum for the - // broadcast and wastes the memory. - // So we should avoid the case in reality. - VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 3"; - VLOG(3) << "It need cost much time to reduce sum for the broadcast and " - "wastes the memory. So we should avoid the case in reality"; - - DenseTensor out_dx_help; - DenseTensor out_dy_help; - DenseTensor out_d_ddx_help; - DenseTensor out_d_ddy_help; - - if (out_d_dout) { - if (ddx) { - ddx_conj = Conj(dev_ctx, ddx.get()); - } - if (ddy) { - ddy_conj = Conj(dev_ctx, ddy.get()); - } - } - if (out_d_ddx || out_d_ddy) { - x_conj = Conj(dev_ctx, x); - y_conj = Conj(dev_ctx, y); - dout_conj = Conj(dev_ctx, dout); - } - - if (transpose_x) { - if (transpose_y) { - // dX = ddY' d_ddout’, dY = d_ddout’ ddX' - if (out_d_x && ddy && d_ddout) - MatMulFunction(dev_ctx, - ddy_conj, - d_ddout.get(), - y_dims, - dout_dims, - &out_dx_help, - true, - true); - if (out_d_y && ddx && d_ddout) - MatMulFunction(dev_ctx, - d_ddout.get(), - ddx_conj, - dout_dims, - x_dims, - &out_dy_help, - true, - true); - } else { - // dX = ddY d_ddout', dY = ddX d_ddout - if (out_d_x && ddy && d_ddout) - MatMulFunction(dev_ctx, - ddy_conj, - d_ddout.get(), - y_dims, - dout_dims, - &out_dx_help, - false, - true); - if (out_d_y && ddx && d_ddout) - MatMulFunction(dev_ctx, - ddx_conj, - d_ddout.get(), - x_dims, - dout_dims, - &out_dy_help, - false, - false); - } - - } else { - if (transpose_y) { - // dX = d_ddout ddY, dY = d_ddout’ ddX - if (out_d_x && ddy && d_ddout) - MatMulFunction(dev_ctx, - d_ddout.get(), - ddy_conj, - dout_dims, - y_dims, - &out_dx_help, - false, - false); - if (out_d_y && ddx && d_ddout) - MatMulFunction(dev_ctx, - d_ddout.get(), - ddx_conj, - dout_dims, - x_dims, - &out_dy_help, - true, - false); - } else { - // dX = d_ddout ddY', dY = ddX' d_ddout - if (out_d_x && ddy && d_ddout) - MatMulFunction(dev_ctx, - d_ddout.get(), - ddy_conj, - dout_dims, - y_dims, - &out_dx_help, - false, - true); - if (out_d_y && ddx && d_ddout) - MatMulFunction(dev_ctx, - ddx_conj, - d_ddout.get(), - x_dims, - dout_dims, - &out_dy_help, - true, - false); - } - } - - // get help dims - const std::vector dx_help_dims = - common::vectorize(out_dx_help.dims()); - const std::vector dy_help_dims = - common::vectorize(out_dx_help.dims()); - - std::vector dx_broadcast_dims(ndim); - std::vector dy_broadcast_dims(ndim); - - std::fill( - dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1); - std::fill( - dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1); - std::copy(x_dims.data(), - x_dims.data() + x_ndim, - dx_broadcast_dims.data() + ndim - x_ndim); - std::copy(y_dims.data(), - y_dims.data() + y_ndim, - dy_broadcast_dims.data() + ndim - y_ndim); - - std::vector dx_reduce_dims; - std::vector dy_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { - dx_reduce_dims.push_back(idx); - } - if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { - dy_reduce_dims.push_back(idx); - } - } - - // Reduce sum to get grad by ReduceSum - if (out_d_x && out_dx_help.initialized()) { - if (dx_reduce_dims.empty()) { - *out_d_x = std::move(out_dx_help); - } else { - ReduceSumForMatmulGrad()( - dev_ctx, out_dx_help, out_d_x, dx_reduce_dims); - } - out_d_x->Resize(x.dims()); - } else if (out_d_x) { - FullLikeKernel(dev_ctx, x, Scalar(0.0), x.dtype(), out_d_x); - } - - if (out_d_y && out_dy_help.initialized()) { - if (dy_reduce_dims.empty()) { - *out_d_y = std::move(out_dy_help); - } else { - ReduceSumForMatmulGrad()( - dev_ctx, out_dy_help, out_d_y, dy_reduce_dims); - } - out_d_y->Resize(y.dims()); - } else if (out_d_y) { - FullLikeKernel(dev_ctx, y, Scalar(0.0), y.dtype(), out_d_y); - } - - // compute d_dout - if (out_d_dout) { - if (d_dx && ddy) { - MatMulFunction(dev_ctx, - d_dx.get(), - ddy_conj, - x_dims, - y_dims, - out_d_dout, - transpose_x, - transpose_y); - } - if (d_dy && ddx) { - MatMulFunction(dev_ctx, - ddx_conj, - d_dy.get(), - x_dims, - y_dims, - out_d_dout, - transpose_x, - transpose_y, - true); - } - - if (!out_d_dout->initialized()) { - FullLikeKernel( - dev_ctx, dout, Scalar(0.0), dout.dtype(), out_d_dout); - } - } - - // compute d_ddx - if (out_d_ddx) { - if (transpose_x && transpose_y) { - // out_d_ddx1 = y' * d_ddout' - if (d_ddout) { - MatMulFunction(dev_ctx, - y_conj, - d_ddout.get(), - y_dims, - dout_dims, - &out_d_ddx_help, - true, - true); - } - - // out_d_ddx2 = D_DY' * DOut' - if (d_dy) { - MatMulFunction(dev_ctx, - d_dy.get(), - dout_conj, - y_dims, - dout_dims, - &out_d_ddx_help, - true, - true, - true); - } - - } else if (transpose_x) { - // out_d_ddx1 = y * d_ddout' - if (d_ddout) { - MatMulFunction(dev_ctx, - y_conj, - d_ddout.get(), - y_dims, - dout_dims, - &out_d_ddx_help, - false, - true); - } - - // out_d_ddx2 = D_DY * Dout' - if (d_dy) { - MatMulFunction(dev_ctx, - d_dy.get(), - dout_conj, - y_dims, - dout_dims, - &out_d_ddx_help, - false, - true, - true); - } - - } else if (transpose_y) { - // out_d_ddx1 = d_ddout * y - if (d_ddout) { - MatMulFunction(dev_ctx, - d_ddout.get(), - y_conj, - dout_dims, - y_dims, - &out_d_ddx_help, - false, - false); - } - - // out_d_ddx2 = Dout * D_DY - if (d_dy) { - MatMulFunction(dev_ctx, - dout_conj, - d_dy.get(), - dout_dims, - y_dims, - &out_d_ddx_help, - false, - false, - true); - } - } else { - // out_d_ddx1 = d_ddout * y' - if (d_ddout) { - MatMulFunction(dev_ctx, - d_ddout.get(), - y_conj, - dout_dims, - y_dims, - &out_d_ddx_help, - false, - true); - } - - // out_d_ddx2 = Dout * D_DY' - if (d_dy) { - MatMulFunction(dev_ctx, - dout_conj, - d_dy.get(), - dout_dims, - y_dims, - &out_d_ddx_help, - false, - true, - true); - } - } - if (out_d_ddx_help.initialized()) { - if (dx_reduce_dims.empty()) { - *out_d_ddx = std::move(out_d_ddx_help); - } else { - ReduceSumForMatmulGrad()( - dev_ctx, out_d_ddx_help, out_d_ddx, dx_reduce_dims); - } - } else { - FullLikeKernel( - dev_ctx, x, Scalar(0.0), x.dtype(), out_d_ddx); - } - - out_d_ddx->Resize(x.dims()); - } - - // compute d_ddy - if (out_d_ddy) { - if (transpose_x && transpose_y) { - // out_d_ddy1 = d_ddout' * x' - if (d_ddout) { - MatMulFunction(dev_ctx, - d_ddout.get(), - x_conj, - dout_dims, - x_dims, - &out_d_ddy_help, - true, - true); - } - - // out_d_ddy2 = dout' * d_dx' - if (d_dx) { - MatMulFunction(dev_ctx, - dout_conj, - d_dx.get(), - dout_dims, - x_dims, - &out_d_ddy_help, - true, - true, - true); - } - - } else if (transpose_x) { - // out_d_ddy1 = x * d_ddout - if (d_ddout) { - MatMulFunction(dev_ctx, - x_conj, - d_ddout.get(), - x_dims, - dout_dims, - &out_d_ddy_help, - false, - false); - } - - // out_d_ddy2 = d_dx * dout - if (d_dx) { - MatMulFunction(dev_ctx, - d_dx.get(), - dout_conj, - x_dims, - dout_dims, - &out_d_ddy_help, - false, - false, - true); - } - - } else if (transpose_y) { - // out_d_ddy1 = d_ddout' * x - if (d_ddout) { - MatMulFunction(dev_ctx, - d_ddout.get(), - x_conj, - dout_dims, - x_dims, - &out_d_ddy_help, - true, - false); - } - - // out_d_ddy2 = dout' * d_dx - if (d_dx) { - MatMulFunction(dev_ctx, - dout_conj, - d_dx.get(), - dout_dims, - x_dims, - &out_d_ddy_help, - true, - false, - true); - } - - } else { - // out_d_ddy1 = x' * d_ddout - if (d_ddout) { - MatMulFunction(dev_ctx, - x_conj, - d_ddout.get(), - x_dims, - dout_dims, - &out_d_ddy_help, - true, - false); - } - - // out_d_ddy2 = d_dx' * dout - if (d_dx) { - MatMulFunction(dev_ctx, - d_dx.get(), - dout_conj, - x_dims, - dout_dims, - &out_d_ddy_help, - true, - false, - true); - } - } - - if (out_d_ddy_help.initialized()) { - if (dy_reduce_dims.empty()) { - *out_d_ddy = std::move(out_d_ddy_help); - } else { - ReduceSumForMatmulGrad()( - dev_ctx, out_d_ddy_help, out_d_ddy, dy_reduce_dims); - } - } else { - FullLikeKernel( - dev_ctx, y, Scalar(0.0), y.dtype(), out_d_ddy); - } - - out_d_ddy->Resize(y.dims()); - } - } -} - -template -void MatmulWithFlattenGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out_grad, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* x_grad, - DenseTensor* y_grad) { - auto x_matrix = - x.dims().size() > 2 ? phi::ReshapeToMatrix(x, x_num_col_dims) : x; - auto y_matrix = - y.dims().size() > 2 ? phi::ReshapeToMatrix(y, y_num_col_dims) : y; - auto* dout = &out_grad; - - DenseTensor dout_mat(*dout); - dout_mat.Resize({common::flatten_to_2d(x.dims(), x_num_col_dims)[0], - common::flatten_to_2d(y.dims(), y_num_col_dims)[1]}); - - auto* dx = x_grad; - auto* dy = y_grad; - - if (dx != nullptr) { - dx->set_lod(x.lod()); - } - if (dy != nullptr) { - dy->set_lod(y.lod()); - } - - auto blas = phi::funcs::GetBlas(dev_ctx); - if (dx) { - dev_ctx.template Alloc(dx); - DenseTensor dx_matrix = - dx->dims().size() > 2 ? phi::ReshapeToMatrix(*dx, x_num_col_dims) : *dx; - - // dx = dout * y'. dx: M x K, dout : M x N, y : K x N - blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix); - } - if (dy) { - dev_ctx.template Alloc(dy); - DenseTensor dy_matrix = - dy->dims().size() > 2 ? phi::ReshapeToMatrix(*dy, y_num_col_dims) : *dy; - // dy = x' * dout. dy K x N, dout : M x N, x : M x K - blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix); - } -} - -template -void MatmulWithFlattenDoubleGradKernel( - const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out_grad, - const paddle::optional& x_grad_grad, - const paddle::optional& y_grad_grad, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* x_grad, - DenseTensor* y_grad, - DenseTensor* out_grad_grad) { - auto x_mat = - x.dims().size() > 2 ? phi::ReshapeToMatrix(x, x_num_col_dims) : x; - auto y_mat = - y.dims().size() > 2 ? phi::ReshapeToMatrix(y, y_num_col_dims) : y; - - const int m = common::flatten_to_2d(x.dims(), x_num_col_dims)[0]; - const int n = common::flatten_to_2d(y.dims(), y_num_col_dims)[1]; - - auto* dout = &out_grad; - DenseTensor dout_mat(*dout); - dout_mat.Resize({m, n}); - - auto* ddx = x_grad_grad.get_ptr(); - auto* ddy = y_grad_grad.get_ptr(); - - auto* dx = x_grad; - auto* dy = y_grad; - auto* ddout = out_grad_grad; - - DenseTensor ddout_mat; - if (ddout) { - ddout->set_lod(dout->lod()); - // allocate and reshape ddout - dev_ctx.template Alloc(ddout); - ddout_mat.ShareDataWith(*ddout); - ddout_mat.Resize({m, n}); - } - - auto blas = phi::funcs::GetBlas(dev_ctx); - // a flag to specify whether ddout value has been set, if flag - // is false, MatMul beta should be 0 to set ddout, if flag is - // true, MatMul beta should be 1 to add result to ddout. - bool ddout_flag = false; - if (ddx) { - auto ddx_mat = ddx->dims().size() > 2 - ? phi::ReshapeToMatrix(*ddx, x_num_col_dims) - : static_cast(*ddx); - - // dy = ddx' * dout. dy : K x M, ddx' : K x M, dout : M x N - if (dy) { - dy->set_lod(y.lod()); - // allocate and reshape dy - dev_ctx.template Alloc(dy); - DenseTensor dy_mat = dy->dims().size() > 2 - ? phi::ReshapeToMatrix(*dy, y_num_col_dims) - : *dy; - blas.MatMul(ddx_mat, true, dout_mat, false, &dy_mat); - } - // ddout1 = ddx * y. ddx : M x K, y : K x N, ddout1 : M x N - if (ddout) { - blas.MatMul(ddx_mat, - false, - y_mat, - false, - static_cast(1.0), - &ddout_mat, - static_cast(ddout_flag)); - ddout_flag = true; - } - } - if (ddy) { - auto ddy_mat = ddy->dims().size() > 2 - ? phi::ReshapeToMatrix(*ddy, y_num_col_dims) - : static_cast(*ddy); - // dx = dout * ddy'. dout : M x N, ddy' : N x K, dx : M x K - if (dx) { - dx->set_lod(x.lod()); - // allocate and reshape dx - dev_ctx.template Alloc(dx); - DenseTensor dx_mat = dx->dims().size() > 2 - ? phi::ReshapeToMatrix(*dx, x_num_col_dims) - : *dx; - blas.MatMul(dout_mat, false, ddy_mat, true, &dx_mat); - } - // ddout2 = x * ddy. x : M x K, ddy : K x N, ddout2 : M x N - if (ddout) { - blas.MatMul(x_mat, - false, - ddy_mat, - false, - static_cast(1.0), - &ddout_mat, - static_cast(ddout_flag)); - } - } -} -template -void LegacyMatmulGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out_grad, - bool transpose_x, - bool transpose_y, - float alpha, - DenseTensor* dx, - DenseTensor* dy) { - MatmulGradKernel( - dev_ctx, x, y, out_grad, transpose_x, transpose_y, dx, dy); - if (std::fabs(alpha - 1.f) > 1e-6f) { - ScaleKernel(dev_ctx, *dx, Scalar(alpha), Scalar(0), false, dx); - ScaleKernel(dev_ctx, *dy, Scalar(alpha), Scalar(0), false, dy); - } -} -} // namespace phi diff --git a/backends/metax_gpu/kernels/impl/matmul_kernel_impl.h b/backends/metax_gpu/kernels/impl/matmul_kernel_impl.h deleted file mode 100755 index 5221bd93ba9..00000000000 --- a/backends/metax_gpu/kernels/impl/matmul_kernel_impl.h +++ /dev/null @@ -1,1717 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -// clang-format off -#include "glog/logging.h" - -#include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/autotune/cache_base.h" -#include "paddle/phi/kernels/cast_kernel.h" -#include "../funcs/blas/blas.h" -#ifdef PADDLE_WITH_HIP -#include "paddle/phi/kernels/funcs/blas/blaslt_impl.hip.h" -#else -#include "../funcs/blas/blaslt_impl.cu.h" -#endif -#include "paddle/phi/kernels/funcs/complex_functors.h" -#include "paddle/phi/kernels/scale_kernel.h" -#if defined(PADDLE_WITH_CUDA) -// #include "paddle/phi/kernels/funcs/cublaslt.h" -#include "paddle/phi/kernels/gpu/cuda_gemm_kernel.h" -#include "paddle/phi/kernels/transpose_kernel.h" -#elif defined(PADDLE_WITH_HIP) -#include "paddle/phi/kernels/funcs/hipblaslt.h" -#endif -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 && 0 -#include "paddle/phi/kernels/autotune/auto_tune_base.h" -#endif -#include "paddle/phi/kernels/full_kernel.h" -// clang-format on -namespace phi { - -static void GetBroadcastFromDims(const int x_ndim, - const std::int64_t* x_dims, - const int y_ndim, - const std::int64_t* y_dims, - std::int64_t* x_bd_dims, - std::int64_t* y_bd_dims, - std::int64_t* out_bd_dims) { - const int ndim = (std::max)(x_ndim, y_ndim); - std::fill(x_bd_dims, x_bd_dims + ndim - x_ndim, 1); - std::fill(y_bd_dims, y_bd_dims + ndim - y_ndim, 1); - std::copy(x_dims, x_dims + x_ndim, x_bd_dims + ndim - x_ndim); - std::copy(y_dims, y_dims + y_ndim, y_bd_dims + ndim - y_ndim); - - for (int i = 0; i < ndim; ++i) { - PADDLE_ENFORCE_EQ( - x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] <= 1 || y_bd_dims[i] <= 1, - true, - phi::errors::InvalidArgument( - "Input(X) and Input(Y) has error dim. " - "X_broadcast's shape[%s] must be equal to Y_broadcast's shape[%s], " - "or X_broadcast's shape[%s] <= 1, or Y_broadcast's shape[%s] <= 1, " - "but received X_broadcast's shape[%s] = [%s]" - "received Y_broadcast's shape[%s] = [%s].", - i, - i, - i, - i, - i, - x_bd_dims[i], - i, - y_bd_dims[i])); - if (x_bd_dims[i] == 0 || y_bd_dims[i] == 0) { - out_bd_dims[i] = 0; - } else { - out_bd_dims[i] = (std::max)(x_bd_dims[i], y_bd_dims[i]); - } - } -} - -static int64_t GetIndexMessage(const int n, - const int64_t* dims, - const int64_t* index) { - int64_t sum = 0; - for (int i = 0; i < n; ++i) { - if (dims[i] > 1) { - sum = sum * dims[i] + index[i]; - } - } - return sum; -} - -static void IndexIncreaseFromDims(const int ndim, - const int64_t* dims, - int64_t* index) { - for (int i = ndim - 1; i >= 0; --i) { - ++index[i]; - if (index[i] >= dims[i]) { - index[i] -= dims[i]; - } else { - break; - } - } -} - -// The general implementation with blas. -template -void MatMulFunctionImplWithBlas( - const Context& dev_ctx, - const DenseTensor& X, - const DenseTensor& Y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* Out, - bool trans_x, - bool trans_y, - bool flag = false, - phi::funcs::MatmulPlanner* matmul_planner UNUSED = nullptr) { - const int x_ndim = x_dims.size(); - const int y_ndim = y_dims.size(); - - // Get data ptr - const T* x_data = X.data(); - const T* y_data = Y.data(); - - auto blas = phi::funcs::GetBlas(dev_ctx); - - if (x_ndim == 1 && y_ndim == 1) { - const int M = X.numel(); - const int N = Y.numel(); - PADDLE_ENFORCE_EQ( - M, - N, - phi::errors::InvalidArgument( - "X's numbers must be equal to Y's numbers, " - "when X/Y's dims =1. But received X has [%d] elements, " - "received Y has [%d] elements.", - M, - N)); - VLOG(3) << "MatMul's case 1"; - Out->Resize(common::make_ddim({})); - dev_ctx.template Alloc(Out); - blas.GEMM(CblasNoTrans, - CblasTrans, - 1, - 1, - M, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - return; - } - - if (x_ndim == 1) { - const int N = X.numel(); - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - N, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - N, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - std::vector out_dims(y_ndim - 1); - if (trans_y) { - std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); - } else { - std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); - out_dims.back() = y_dims.back(); - } - Out->ResizeAndAllocate(common::make_ddim(out_dims)); - dev_ctx.template Alloc(Out); - if (trans_y) { - const int M = Y.numel() / N; - VLOG(3) << "MatMul's case 2"; - blas.GEMV(false, - M, - N, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - const int M = y_dims[y_ndim - 1]; - const int batch_size = Y.numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul's case 3"; - blas.GEMV(true, - N, - M, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - VLOG(3) << "MatMul's case 4"; - blas.BatchedGEMM(CblasTrans, - CblasNoTrans, - M, - 1, - N, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - batch_size, - M * N, - 0); - } - } - return; - } - - if (y_ndim == 1) { - const int N = Y.numel(); - if (trans_x) { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 2], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 2, - N, - x_ndim - 2, - x_dims[x_ndim - 2])); - } else { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 1], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 1, - N, - x_ndim - 1, - x_dims[x_ndim - 1])); - } - std::vector out_dims(x_ndim - 1); - if (trans_x) { - std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); - out_dims.back() = x_dims.back(); - } else { - std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); - } - Out->ResizeAndAllocate(common::make_ddim(out_dims)); - dev_ctx.template Alloc(Out); - - if (trans_x) { - const int M = x_dims[x_ndim - 1]; - const int batch_size = X.numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul's case 5"; - blas.GEMV(true, - N, - M, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - VLOG(3) << "MatMul's case 6"; - blas.BatchedGEMM(CblasTrans, - CblasNoTrans, - M, - 1, - N, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - batch_size, - M * N, - 0); - } - } else { - const int M = X.numel() / N; - VLOG(3) << "MatMul's case 7"; - blas.GEMV(false, - M, - N, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } - return; - } - - const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; - const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - K, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - K, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; - const int ndim = (std::max)(x_ndim, y_ndim); - std::vector x_broadcast_dims(ndim); - std::vector y_broadcast_dims(ndim); - std::vector out_broadcast_dims(ndim); - - GetBroadcastFromDims(x_ndim - 2, - x_dims.data(), - y_ndim - 2, - y_dims.data(), - x_broadcast_dims.data(), - y_broadcast_dims.data(), - out_broadcast_dims.data()); - out_broadcast_dims[ndim - 2] = M; - out_broadcast_dims[ndim - 1] = N; - - Out->ResizeAndAllocate(common::make_ddim(out_broadcast_dims)); - dev_ctx.template Alloc(Out); - - const int batch_dim = ndim - 2; - // broadcast message - const bool is_broadcast_dims = - !std::equal(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - y_broadcast_dims.cbegin()); - - const std::int64_t x_batch_size = - std::accumulate(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t y_batch_size = - std::accumulate(y_broadcast_dims.cbegin(), - y_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t out_batch_size = - std::accumulate(out_broadcast_dims.cbegin(), - out_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - if (out_batch_size == 0) return; - if (x_batch_size == 1 && y_batch_size == 1) { - VLOG(3) << "MatMul's case 8"; - blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else if (x_batch_size == 1) { - if (M == 1 && trans_y) { - VLOG(3) << "MatMul's case 9"; - blas.GEMV(false, - y_batch_size * N, - K, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - VLOG(3) << "MatMul's case 10"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - 0, - K * N); - } - } else if (y_batch_size == 1) { - if (!trans_x) { - VLOG(3) << "MatMul's case 11"; - blas.GEMM(CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - x_batch_size * M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - VLOG(3) << "MatMul's case 12"; - blas.BatchedGEMM(CblasTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - M * K, - 0); - } - } else if (!is_broadcast_dims) { - VLOG(3) << "MatMul's case 13"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - M * K, - K * N); - } else { - // in the case, can't use stridedgemm - std::vector x_ptr(out_batch_size); - std::vector y_ptr(out_batch_size); - std::vector out_ptr(out_batch_size); - std::vector index(batch_dim, 0); - for (std::int64_t i = 0; i < out_batch_size; ++i) { - // using the index to get offset - const std::int64_t x_index = - GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); - const std::int64_t y_index = - GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); - - x_ptr[i] = x_data + x_index * M * K; - y_ptr[i] = y_data + y_index * K * N; - out_ptr[i] = dev_ctx.template Alloc(Out) + i * M * N; - IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); - } - VLOG(3) << "MatMul's case 14"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_ptr.data(), - y_ptr.data(), - static_cast(flag), - out_ptr.data(), - out_batch_size); - } -} - -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 && 0 -// This is almost a copy from MatMulFunctionImplWithBlas, -// compare cublas with cublasLt kernels when Matmul autotune is on -template -void MatMulFunctionImplWithCublasLt( - const Context& dev_ctx, - const DenseTensor& X, - const DenseTensor& Y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* Out, - bool trans_x, - bool trans_y, - bool flag = false, - phi::funcs::MatmulPlanner* matmul_planner = nullptr) { - const int x_ndim = x_dims.size(); - const int y_ndim = y_dims.size(); - const T* x_data = X.data(); - const T* y_data = Y.data(); - using blaslt = phi::funcs::MatmulWithCublasLt; - - if (x_ndim == 1 && y_ndim == 1) { - const int M = X.numel(); - const int N = Y.numel(); - PADDLE_ENFORCE_EQ( - M, - N, - phi::errors::InvalidArgument( - "X's numbers must be equal to Y's numbers," - "when X/Y's dims =1. But received X has [%d] elements," - "received Y has [%d] elements", - M, - N)); - - // MatMul's case 0 => vector * vector - Out->Resize(common::make_ddim({})); - dev_ctx.template Alloc(Out); - VLOG(3) << "MatMul with blaslt case 1"; - blaslt::Run(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - 1, - 1, - M, - false, - true, - matmul_planner); - return; - } - - if (x_ndim == 1) { - const int N = X.numel(); - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - N, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - N, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - std::vector out_dims(y_ndim - 1); - if (trans_y) { - std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); - } else { - std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); - out_dims.back() = y_dims.back(); - } - Out->ResizeAndAllocate(common::make_ddim(out_dims)); - dev_ctx.template Alloc(Out); - if (trans_y) { - const int M = Y.numel() / N; - VLOG(3) << "MatMul with blaslt 2"; - blaslt::Run(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - false, - false, - matmul_planner); - } else { - const int M = y_dims[y_ndim - 1]; - const int batch_size = Y.numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul with blaslt 3"; - blaslt::Run(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - true, - false, - matmul_planner); - } else { - VLOG(3) << "MatMul with blaslt 4"; - blaslt::RunWithBatch(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - true, - false, - batch_size, - M * N, - 0, - M, - matmul_planner); - } - } - return; - } - - if (y_ndim == 1) { - const int N = Y.numel(); - if (trans_x) { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 2], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 2, - N, - x_ndim - 2, - x_dims[x_ndim - 2])); - } else { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 1], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 1, - N, - x_ndim - 1, - x_dims[x_ndim - 1])); - } - std::vector out_dims(x_ndim - 1); - if (trans_x) { - std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); - out_dims.back() = x_dims.back(); - } else { - std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); - } - Out->ResizeAndAllocate(common::make_ddim(out_dims)); - dev_ctx.template Alloc(Out); - - if (trans_x) { - const int M = x_dims[x_ndim - 1]; - const int batch_size = X.numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul with blaslt 5"; - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - true, - false, - matmul_planner); - } else { - VLOG(3) << "MatMul with blaslt 6"; - blaslt::RunWithBatch(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - true, - false, - batch_size, - M * N, - 0, - M, - matmul_planner); - } - } else { - const int M = X.numel() / N; - VLOG(3) << "MatMul with blaslt 7"; - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - false, - false, - matmul_planner); - } - return; - } - - const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; - const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - K, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - K, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; - const int ndim = (std::max)(x_ndim, y_ndim); - std::vector x_broadcast_dims(ndim); - std::vector y_broadcast_dims(ndim); - std::vector out_broadcast_dims(ndim); - - GetBroadcastFromDims(x_ndim - 2, - x_dims.data(), - y_ndim - 2, - y_dims.data(), - x_broadcast_dims.data(), - y_broadcast_dims.data(), - out_broadcast_dims.data()); - out_broadcast_dims[ndim - 2] = M; - out_broadcast_dims[ndim - 1] = N; - - Out->ResizeAndAllocate(common::make_ddim(out_broadcast_dims)); - dev_ctx.template Alloc(Out); - - const int batch_dim = ndim - 2; - // broadcast message - const bool is_broadcast_dims = - !std::equal(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - y_broadcast_dims.cbegin()); - - const std::int64_t x_batch_size = - std::accumulate(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t y_batch_size = - std::accumulate(y_broadcast_dims.cbegin(), - y_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t out_batch_size = - std::accumulate(out_broadcast_dims.cbegin(), - out_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - if (out_batch_size == 0) return; - if (x_batch_size == 1 && y_batch_size == 1) { - VLOG(3) << "MatMul with blaslt 8"; - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - N, - K, - trans_x, - trans_y, - matmul_planner); - } else if (x_batch_size == 1) { - if (M == 1 && trans_y) { - VLOG(3) << "MatMul with blaslt 9"; - blaslt::Run(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - y_batch_size * N, - 1, - K, - false, - false, - matmul_planner); - } else { - VLOG(3) << "MatMul with blaslt 10"; - blaslt::RunWithBatch(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - 0, - K * N, - M * N, - matmul_planner); - } - } else if (y_batch_size == 1) { - if (!trans_x) { - VLOG(3) << "MatMul with blaslt 11"; - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - x_batch_size * M, - N, - K, - false, - trans_y, - matmul_planner); - } else { - VLOG(3) << "MatMul with blaslt 12"; - blaslt::RunWithBatch(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - N, - K, - true, - trans_y, - out_batch_size, - M * K, - 0, - M * N, - matmul_planner); - } - } else if (!is_broadcast_dims) { - VLOG(3) << "MatMul with blaslt 13"; - blaslt::RunWithBatch(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - M * K, - K * N, - M * N, - matmul_planner); - } else { - // in the case, can't use stridedgemm - std::vector x_ptr(out_batch_size); - std::vector y_ptr(out_batch_size); - std::vector out_ptr(out_batch_size); - std::vector index(batch_dim, 0); - for (std::int64_t i = 0; i < out_batch_size; ++i) { - // using the index to get offset - const std::int64_t x_index = - GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); - const std::int64_t y_index = - GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); - - x_ptr[i] = x_data + x_index * M * K; - y_ptr[i] = y_data + y_index * K * N; - out_ptr[i] = dev_ctx.template Alloc(Out) + i * M * N; - IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); - } - VLOG(3) << "MatMul with blaslt 14"; - blaslt::RunWithBatch(dev_ctx, - x_ptr.data(), - y_ptr.data(), - out_ptr.data(), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - matmul_planner); - } -} -#endif - -template -struct MatMulDispatcher { - void operator()(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y, - bool flag = false) { - MatMulFunctionImplWithBlas( - ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag); - } -}; - -#ifdef PADDLE_WITH_CUDA -template -struct MatMulDispatcher { - void operator()(const phi::GPUContext& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y, - bool flag = false) { -#if CUDA_VERSION >= 11060 && 0 - auto* tuner = phi::autotune::MakeMatmulTuner( - MatMulFunctionImplWithBlas); - tuner->AddCallBack(MatMulFunctionImplWithCublasLt); - phi::funcs::MatmulPlanner matmul_planner(x_dims, - y_dims, - trans_x, - trans_y, - phi::CppTypeToDataType::Type(), - funcs::MatmulFusedType::kMatmul, - /* bias_data */ nullptr, - /* reserve_data */ nullptr, - /* use_addto */ flag, - /* no_exchange */ true); - tuner->Run(ctx, - matmul_planner.GetKey(), - ctx, - x, - y, - x_dims, - y_dims, - out, - trans_x, - trans_y, - flag, - &matmul_planner); -#else - MatMulFunctionImplWithBlas( - ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag); -#endif - } -}; - -#endif // PADDLE_WITH_CUDA - -template -void MatMulFunction(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y, - bool flag = false) { - MatMulDispatcher()( - ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag); -} - -template -bool MatMulInt8Function(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y) { - return false; -} - -#ifdef PADDLE_WITH_CUDA -template <> -bool inline MatMulInt8Function(const phi::GPUContext& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y) { - if (x.dtype() != DataType::INT8 || y.dtype() != DataType::INT8) { - return false; - } -#if CUDA_VERSION >= 11060 && 0 - const int x_ndim = x_dims.size(); - const int y_ndim = y_dims.size(); - const int8_t* x_data = x.data(); - const int8_t* y_data = y.data(); - using blaslt = phi::funcs::MatmulWithCublasLt; - - phi::funcs::MatmulPlanner matmul_planner( - x_dims, - y_dims, - trans_x, - trans_y, - phi::CppTypeToDataType::Type(), - funcs::MatmulFusedType::kMatmul, - /* bias_data */ nullptr, - /* reserve_data */ nullptr, - /* use_addto */ false, - /* no_exchange */ true); - - if (x_ndim == 1 && y_ndim == 1) { - const int M = x.numel(); - const int N = y.numel(); - PADDLE_ENFORCE_EQ( - M, - N, - phi::errors::InvalidArgument( - "X's numbers must be equal to Y's numbers," - "when X/Y's dims =1. But received X has [%d] elements," - "received Y has [%d] elements", - M, - N)); - if (!(M % 4 == 0)) { - return false; - } - - out->Resize(common::make_ddim({})); - ctx.template Alloc(out); - blaslt::Run(ctx, - y_data, - x_data, - ctx.template Alloc(out), - 1, - 1, - M, - false, - true, - &matmul_planner); - return true; - } - if (x_ndim == 1) { - const int N = x.numel(); - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - N, - y_ndim - 1, - y_dims[y_ndim - 1])); - if (!(N % 4 == 0)) { - return false; - } - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - N, - y_ndim - 2, - y_dims[y_ndim - 2])); - const int M = y.numel() / N; - if (!(M == 1 || M % 4 == 0)) { - return false; - } - } - std::vector out_dims(y_ndim - 1); - if (trans_y) { - std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); - } else { - std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); - out_dims.back() = y_dims.back(); - } - out->ResizeAndAllocate(common::make_ddim(out_dims)); - ctx.template Alloc(out); - if (trans_y) { - const int M = y.numel() / N; - blaslt::Run(ctx, - y_data, - x_data, - ctx.template Alloc(out), - M, - 1, - N, - false, - false, - &matmul_planner); - } else { - const int M = y_dims[y_ndim - 1]; - const int batch_size = y.numel() / (M * N); - if (batch_size == 1) { - blaslt::Run(ctx, - y_data, - x_data, - ctx.template Alloc(out), - M, - 1, - N, - true, - false, - &matmul_planner); - } else { - blaslt::RunWithBatch(ctx, - y_data, - x_data, - ctx.template Alloc(out), - M, - 1, - N, - true, - false, - batch_size, - M * N, - 0, - M, - &matmul_planner); - } - } - return true; - } - - if (y_ndim == 1) { - const int N = y.numel(); - if (trans_x) { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 2], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 2, - N, - x_ndim - 2, - x_dims[x_ndim - 2])); - const int M = x.numel() / N; - if (!((M == 1 || M % 4 == 0))) { - return false; - } - } else { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 1], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 1, - N, - x_ndim - 1, - x_dims[x_ndim - 1])); - if (N % 4 != 0) { - return false; - } - } - std::vector out_dims(x_ndim - 1); - if (trans_x) { - std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); - out_dims.back() = x_dims.back(); - } else { - std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); - } - out->ResizeAndAllocate(common::make_ddim(out_dims)); - ctx.template Alloc(out); - - if (trans_x) { - const int M = x_dims[x_ndim - 1]; - const int batch_size = x.numel() / (M * N); - if (batch_size == 1) { - blaslt::Run(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - 1, - N, - true, - false, - &matmul_planner); - } else { - blaslt::RunWithBatch(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - 1, - N, - true, - false, - batch_size, - M * N, - 0, - M, - &matmul_planner); - } - } else { - const int M = x.numel() / N; - blaslt::Run(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - 1, - N, - false, - false, - &matmul_planner); - } - return true; - } - - const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; - const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - K, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - K, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; - const int ndim = (std::max)(x_ndim, y_ndim); - std::vector x_broadcast_dims(ndim); - std::vector y_broadcast_dims(ndim); - std::vector out_broadcast_dims(ndim); - GetBroadcastFromDims(x_ndim - 2, - x_dims.data(), - y_ndim - 2, - y_dims.data(), - x_broadcast_dims.data(), - y_broadcast_dims.data(), - out_broadcast_dims.data()); - out_broadcast_dims[ndim - 2] = M; - out_broadcast_dims[ndim - 1] = N; - - out->ResizeAndAllocate(common::make_ddim(out_broadcast_dims)); - ctx.template Alloc(out); - - const int batch_dim = ndim - 2; - // broadcast message - const bool is_broadcast_dims = - !std::equal(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - y_broadcast_dims.cbegin()); - - const std::int64_t x_batch_size = - std::accumulate(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t y_batch_size = - std::accumulate(y_broadcast_dims.cbegin(), - y_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t out_batch_size = - std::accumulate(out_broadcast_dims.cbegin(), - out_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - if (out_batch_size == 0) return true; - - if (x_batch_size == 1 && M == 1 && trans_y) { - if (!(K % 4 == 0)) { - return false; - } - } else if (!trans_x && !trans_y) { - if (!(N % 4 == 0 || N == 1) || !(K % 4 == 0) || (M == 1 && N == 1)) { - return false; - } - } else if (!trans_x && trans_y) { - if (!(K % 4 == 0)) { - return false; - } - } else if (trans_x && !trans_y) { - if (!(M % 4 == 0 || M == 1) || !(N % 4 == 0 || N == 1)) { - return false; - } - } else { - if (!(M % 4 == 0 || M == 1) || !(K % 4 == 0)) { - return false; - } - } - if (x_batch_size == 1 && y_batch_size == 1) { - blaslt::Run(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - N, - K, - trans_x, - trans_y, - &matmul_planner); - } else if (x_batch_size == 1) { - if (M == 1 && trans_y) { - blaslt::Run(ctx, - y_data, - x_data, - ctx.template Alloc(out), - y_batch_size * N, - 1, - K, - false, - false, - &matmul_planner); - } else { - blaslt::RunWithBatch(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - 0, - K * N, - M * N, - &matmul_planner); - } - } else if (y_batch_size == 1) { - if (!trans_x) { - blaslt::Run(ctx, - x_data, - y_data, - ctx.template Alloc(out), - x_batch_size * M, - N, - K, - false, - trans_y, - &matmul_planner); - } else { - blaslt::RunWithBatch(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - N, - K, - true, - trans_y, - out_batch_size, - M * K, - 0, - M * N, - &matmul_planner); - } - } else if (!is_broadcast_dims) { - blaslt::RunWithBatch(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - M * K, - K * N, - M * N, - &matmul_planner); - } else { - // in the case, can't use stridedgemm - std::vector x_ptr(out_batch_size); - std::vector y_ptr(out_batch_size); - std::vector out_ptr(out_batch_size); - std::vector index(batch_dim, 0); - for (std::int64_t i = 0; i < out_batch_size; ++i) { - // using the index to get offset - const std::int64_t x_index = - GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); - const std::int64_t y_index = - GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); - - x_ptr[i] = x_data + x_index * M * K; - y_ptr[i] = y_data + y_index * K * N; - out_ptr[i] = ctx.template Alloc(out) + i * M * N; - IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); - } - blaslt::RunWithBatch(ctx, - x_ptr.data(), - y_ptr.data(), - out_ptr.data(), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - &matmul_planner); - } - return true; -#else - return false; -#endif -} -#endif - -template -typename std::enable_if::value>::type -MatmulJudgeDtypeKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool transpose_x, - bool transpose_y) { - bool try_matmul_int8 = MatMulInt8Function( - ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); - if (try_matmul_int8) { - return; - } - auto x_tmp = phi::Cast(ctx, x, phi::DataType::FLOAT32); - auto y_tmp = phi::Cast(ctx, y, phi::DataType::FLOAT32); - DenseTensor out_tmp; - MatMulFunction( - ctx, x_tmp, y_tmp, x_dims, y_dims, &out_tmp, transpose_x, transpose_y); - if (x.dtype() == phi::DataType::INT8) { - phi::CastKernel(ctx, out_tmp, phi::DataType::INT32, out); - return; - } - phi::CastKernel(ctx, out_tmp, x.dtype(), out); -} - -template -typename std::enable_if::value>::type -MatmulJudgeDtypeKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool transpose_x, - bool transpose_y) { - MatMulFunction( - ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); -} - -template -void MatmulKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - bool transpose_x, - bool transpose_y, - DenseTensor* out) { - if (x.numel() == 0 || y.numel() == 0) { - // input shape [1, 1, 5, 0], [1, 1, 0, 5], result shape is [1, 1, 5, 5] - phi::Full( - ctx, phi::IntArray(common::vectorize(out->dims())), 0, out); - return; - } - PADDLE_ENFORCE_GE( - common::product(x.dims()), - 0, - common::errors::InvalidArgument( - "The dims of Input(X) should be greater than or equal to 0.")); - PADDLE_ENFORCE_GE( - common::product(y.dims()), - 0, - common::errors::InvalidArgument( - "The dims of Input(Y) should be greater than or equal to 0.")); - const std::vector x_dims = common::vectorize(x.dims()); - const std::vector y_dims = common::vectorize(y.dims()); - MatmulJudgeDtypeKernel( - ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); -} - -template -void MatmulWithFlattenKernelImpl(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - const DenseTensor x_matrix = - x.dims().size() > 2 ? phi::ReshapeToMatrix(x, x_num_col_dims) : x; - const DenseTensor y_matrix = - y.dims().size() > 2 ? phi::ReshapeToMatrix(y, y_num_col_dims) : y; - - dev_ctx.template Alloc(out); - auto z_dim = out->dims(); - if (z_dim.size() != 2) { - out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); - } - - auto blas = phi::funcs::GetBlas(dev_ctx); - - blas.MatMul(x_matrix, y_matrix, out); - if (z_dim.size() != 2) { - out->Resize(z_dim); - } -} - -#ifdef PADDLE_WITH_CUDA - -template -void MatmulWithFlattenKernelInt8Impl(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - PADDLE_ENFORCE_EQ( - x.dtype(), - DataType::INT8, - phi::errors::InvalidArgument( - "The type of input(x) used in int8 mul must be (%s) " - "does not match the " - "type of data (%s) currently contained in the container.", - phi::CppTypeToDataType::Type(), - x.dtype())); - PADDLE_ENFORCE_EQ( - y.dtype(), - DataType::INT8, - phi::errors::InvalidArgument( - "The type of input(y) used in int8 mul must be (%s) " - "does not match the " - "type of data (%s) currently contained in the container.", - phi::CppTypeToDataType::Type(), - y.dtype())); - - const DenseTensor x_matrix = - x.dims().size() > 2 ? phi::ReshapeToMatrix(x, x_num_col_dims) : x; - const DenseTensor y_matrix = - y.dims().size() > 2 ? phi::ReshapeToMatrix(y, y_num_col_dims) : y; - - PADDLE_ENFORCE_EQ( - x_matrix.dims()[1], - y_matrix.dims()[0], - phi::errors::InvalidArgument( - "X's numbers of columns must be equal to Y's numbers of rows." - "But received X has [%d] columns," - "received Y has [%d] rows", - x_matrix.dims()[1], - y_matrix.dims()[0])); - - PADDLE_ENFORCE_EQ((y_matrix.dims()[1] % 4 == 0 || y_matrix.dims()[1] == 1), - true, - phi::errors::InvalidArgument( - "The dimension size N used in int8 mul must be 1" - "or a multiple of 4 does not match the size (%d)" - "currently contained in the container.", - y_matrix.dims()[1])); - PADDLE_ENFORCE_EQ((x_matrix.dims()[1] % 4 == 0), - true, - phi::errors::InvalidArgument( - "The dimension size K used in int8 mul must be a" - "multiple of 4 does not match the size (%d) currently" - "contained in the container.", - x_matrix.dims()[1])); - - dev_ctx.template Alloc(out); - auto z_dim = out->dims(); - if (z_dim.size() != 2) { - out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); - } - -#if CUDA_VERSION >= 11060 && 0 - using blaslt = phi::funcs::MatmulWithCublasLt; - - const int8_t* x_data = x_matrix.data(); - const int8_t* y_data = y_matrix.data(); - - std::vector x_dims = {x_matrix.dims()[0], x_matrix.dims()[1]}; - std::vector y_dims = {y_matrix.dims()[0], y_matrix.dims()[1]}; - phi::funcs::MatmulPlanner matmul_planner( - x_dims, - y_dims, - false, - false, - phi::CppTypeToDataType::Type(), - funcs::MatmulFusedType::kMatmul, - /* bias_data */ nullptr, - /* reserve_data */ nullptr, - /* use_addto */ false, - /* no_exchange */ true); - - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(out), - x_matrix.dims()[0], - y_matrix.dims()[1], - x_matrix.dims()[1], - false, - false, - &matmul_planner); - - if (z_dim.size() != 2) { - out->Resize(z_dim); - } -#endif -} -#endif - -#ifdef PADDLE_WITH_CUDA -template -typename std::enable_if::value, - void>::type -DispatchMatmulWithFlattenInt8Kernel(const phi::GPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - MatmulWithFlattenKernelInt8Impl( - dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); -} -#endif - -template -typename std::enable_if::value, - void>::type -DispatchMatmulWithFlattenInt8Kernel(const phi::CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - PADDLE_THROW(phi::errors::Unimplemented( - "MatmulWithFlatten with CPU is NOT implemented " - "yet.")); -} - -template -typename std::enable_if::value, void>::type -DispatchMatmulFlattenKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - DispatchMatmulWithFlattenInt8Kernel( - dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); -} - -template -typename std::enable_if::value, void>::type -DispatchMatmulFlattenKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - MatmulWithFlattenKernelImpl( - dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); -} - -template -void MatmulWithFlattenKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - DispatchMatmulFlattenKernel( - dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); -} - -template -void LegacyMatmulKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - bool transpose_x, - bool transpose_y, - float alpha, - DenseTensor* out) { - MatmulKernel(ctx, x, y, transpose_x, transpose_y, out); - if (std::fabs(alpha - 1.f) > 1e-6f) { - ScaleKernel(ctx, *out, Scalar(alpha), Scalar(0), false, out); - } -} - -} // namespace phi diff --git a/backends/metax_gpu/kernels/impl/matmul_kernel_impl_maca.h b/backends/metax_gpu/kernels/impl/matmul_kernel_impl_maca.h deleted file mode 100644 index 9750abae5ca..00000000000 --- a/backends/metax_gpu/kernels/impl/matmul_kernel_impl_maca.h +++ /dev/null @@ -1,1696 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -// clang-format off -#include "glog/logging.h" - -#include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/autotune/cache_base.h" -#include "paddle/phi/kernels/cast_kernel.h" -#include "../funcs/blas/blas.h" -#ifdef PADDLE_WITH_HIP -#include "paddle/phi/kernels/funcs/blas/blaslt_impl.hip.h" -#else -#include "../funcs/blas/blaslt_impl.cu.h" -#endif -#include "paddle/phi/kernels/funcs/complex_functors.h" -#include "paddle/phi/kernels/scale_kernel.h" -#if defined(PADDLE_WITH_CUDA) -#include "paddle/phi/kernels/funcs/cublaslt.h" -#include "paddle/phi/kernels/gpu/cuda_gemm_kernel.h" -#include "paddle/phi/kernels/transpose_kernel.h" -#elif defined(PADDLE_WITH_HIP) -#include "paddle/phi/kernels/funcs/hipblaslt.h" -#endif -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 && 0 -#include "paddle/phi/kernels/autotune/auto_tune_base.h" -#endif -// clang-format on -namespace phi { - -static void GetBroadcastFromDims(const int x_ndim, - const std::int64_t* x_dims, - const int y_ndim, - const std::int64_t* y_dims, - std::int64_t* x_bd_dims, - std::int64_t* y_bd_dims, - std::int64_t* out_bd_dims) { - const int ndim = (std::max)(x_ndim, y_ndim); - std::fill(x_bd_dims, x_bd_dims + ndim - x_ndim, 1); - std::fill(y_bd_dims, y_bd_dims + ndim - y_ndim, 1); - std::copy(x_dims, x_dims + x_ndim, x_bd_dims + ndim - x_ndim); - std::copy(y_dims, y_dims + y_ndim, y_bd_dims + ndim - y_ndim); - - for (int i = 0; i < ndim; ++i) { - PADDLE_ENFORCE_EQ( - x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] <= 1 || y_bd_dims[i] <= 1, - true, - phi::errors::InvalidArgument( - "Input(X) and Input(Y) has error dim. " - "X_broadcast's shape[%s] must be equal to Y_broadcast's shape[%s], " - "or X_broadcast's shape[%s] <= 1, or Y_broadcast's shape[%s] <= 1, " - "but received X_broadcast's shape[%s] = [%s]" - "received Y_broadcast's shape[%s] = [%s].", - i, - i, - i, - i, - i, - x_bd_dims[i], - i, - y_bd_dims[i])); - if (x_bd_dims[i] == 0 || y_bd_dims[i] == 0) { - out_bd_dims[i] = 0; - } else { - out_bd_dims[i] = (std::max)(x_bd_dims[i], y_bd_dims[i]); - } - } -} - -static int64_t GetIndexMessage(const int n, - const int64_t* dims, - const int64_t* index) { - int64_t sum = 0; - for (int i = 0; i < n; ++i) { - if (dims[i] > 1) { - sum = sum * dims[i] + index[i]; - } - } - return sum; -} - -static void IndexIncreaseFromDims(const int ndim, - const int64_t* dims, - int64_t* index) { - for (int i = ndim - 1; i >= 0; --i) { - ++index[i]; - if (index[i] >= dims[i]) { - index[i] -= dims[i]; - } else { - break; - } - } -} - -// The general implementation with blas. -template -void MatMulFunctionImplWithBlas( - const Context& dev_ctx, - const DenseTensor& X, - const DenseTensor& Y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* Out, - bool trans_x, - bool trans_y, - bool flag = false, - phi::funcs::MatmulPlanner* matmul_planner UNUSED = nullptr) { - const int x_ndim = x_dims.size(); - const int y_ndim = y_dims.size(); - - // Get data ptr - const T* x_data = X.data(); - const T* y_data = Y.data(); - - auto blas = phi::funcs::GetBlas(dev_ctx); - - if (x_ndim == 1 && y_ndim == 1) { - const int M = X.numel(); - const int N = Y.numel(); - PADDLE_ENFORCE_EQ( - M, - N, - phi::errors::InvalidArgument( - "X's numbers must be equal to Y's numbers, " - "when X/Y's dims =1. But received X has [%d] elements, " - "received Y has [%d] elements.", - M, - N)); - VLOG(3) << "MatMul's case 1"; - Out->Resize(common::make_ddim({})); - dev_ctx.template Alloc(Out); - blas.GEMM(CblasNoTrans, - CblasTrans, - 1, - 1, - M, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - return; - } - - if (x_ndim == 1) { - const int N = X.numel(); - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - N, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - N, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - std::vector out_dims(y_ndim - 1); - if (trans_y) { - std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); - } else { - std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); - out_dims.back() = y_dims.back(); - } - Out->ResizeAndAllocate(common::make_ddim(out_dims)); - dev_ctx.template Alloc(Out); - if (trans_y) { - const int M = Y.numel() / N; - VLOG(3) << "MatMul's case 2"; - blas.GEMV(false, - M, - N, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - const int M = y_dims[y_ndim - 1]; - const int batch_size = Y.numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul's case 3"; - blas.GEMV(true, - N, - M, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - VLOG(3) << "MatMul's case 4"; - blas.BatchedGEMM(CblasTrans, - CblasNoTrans, - M, - 1, - N, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - batch_size, - M * N, - 0); - } - } - return; - } - - if (y_ndim == 1) { - const int N = Y.numel(); - if (trans_x) { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 2], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 2, - N, - x_ndim - 2, - x_dims[x_ndim - 2])); - } else { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 1], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 1, - N, - x_ndim - 1, - x_dims[x_ndim - 1])); - } - std::vector out_dims(x_ndim - 1); - if (trans_x) { - std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); - out_dims.back() = x_dims.back(); - } else { - std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); - } - Out->ResizeAndAllocate(common::make_ddim(out_dims)); - dev_ctx.template Alloc(Out); - - if (trans_x) { - const int M = x_dims[x_ndim - 1]; - const int batch_size = X.numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul's case 5"; - blas.GEMV(true, - N, - M, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - VLOG(3) << "MatMul's case 6"; - blas.BatchedGEMM(CblasTrans, - CblasNoTrans, - M, - 1, - N, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - batch_size, - M * N, - 0); - } - } else { - const int M = X.numel() / N; - VLOG(3) << "MatMul's case 7"; - blas.GEMV(false, - M, - N, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } - return; - } - - const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; - const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - K, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - K, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; - const int ndim = (std::max)(x_ndim, y_ndim); - std::vector x_broadcast_dims(ndim); - std::vector y_broadcast_dims(ndim); - std::vector out_broadcast_dims(ndim); - - GetBroadcastFromDims(x_ndim - 2, - x_dims.data(), - y_ndim - 2, - y_dims.data(), - x_broadcast_dims.data(), - y_broadcast_dims.data(), - out_broadcast_dims.data()); - out_broadcast_dims[ndim - 2] = M; - out_broadcast_dims[ndim - 1] = N; - - Out->ResizeAndAllocate(common::make_ddim(out_broadcast_dims)); - dev_ctx.template Alloc(Out); - - const int batch_dim = ndim - 2; - // broadcast message - const bool is_broadcast_dims = - !std::equal(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - y_broadcast_dims.cbegin()); - - const std::int64_t x_batch_size = - std::accumulate(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t y_batch_size = - std::accumulate(y_broadcast_dims.cbegin(), - y_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t out_batch_size = - std::accumulate(out_broadcast_dims.cbegin(), - out_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - if (out_batch_size == 0) return; - if (x_batch_size == 1 && y_batch_size == 1) { - VLOG(3) << "MatMul's case 8"; - blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else if (x_batch_size == 1) { - if (M == 1 && trans_y) { - VLOG(3) << "MatMul's case 9"; - blas.GEMV(false, - y_batch_size * N, - K, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - VLOG(3) << "MatMul's case 10"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - 0, - K * N); - } - } else if (y_batch_size == 1) { - if (!trans_x) { - VLOG(3) << "MatMul's case 11"; - blas.GEMM(CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - x_batch_size * M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - VLOG(3) << "MatMul's case 12"; - blas.BatchedGEMM(CblasTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - M * K, - 0); - } - } else if (!is_broadcast_dims) { - VLOG(3) << "MatMul's case 13"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - M * K, - K * N); - } else { - // in the case, can't use stridedgemm - std::vector x_ptr(out_batch_size); - std::vector y_ptr(out_batch_size); - std::vector out_ptr(out_batch_size); - std::vector index(batch_dim, 0); - for (std::int64_t i = 0; i < out_batch_size; ++i) { - // using the index to get offset - const std::int64_t x_index = - GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); - const std::int64_t y_index = - GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); - - x_ptr[i] = x_data + x_index * M * K; - y_ptr[i] = y_data + y_index * K * N; - out_ptr[i] = dev_ctx.template Alloc(Out) + i * M * N; - IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); - } - VLOG(3) << "MatMul's case 14"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_ptr.data(), - y_ptr.data(), - static_cast(flag), - out_ptr.data(), - out_batch_size); - } -} - -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 && 0 -// This is almost a copy from MatMulFunctionImplWithBlas, -// compare cublas with cublasLt kernels when Matmul autotune is on -template -void MatMulFunctionImplWithCublasLt( - const Context& dev_ctx, - const DenseTensor& X, - const DenseTensor& Y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* Out, - bool trans_x, - bool trans_y, - bool flag = false, - phi::funcs::MatmulPlanner* matmul_planner = nullptr) { - const int x_ndim = x_dims.size(); - const int y_ndim = y_dims.size(); - const T* x_data = X.data(); - const T* y_data = Y.data(); - using blaslt = phi::funcs::MatmulWithCublasLt; - - if (x_ndim == 1 && y_ndim == 1) { - const int M = X.numel(); - const int N = Y.numel(); - PADDLE_ENFORCE_EQ( - M, - N, - phi::errors::InvalidArgument( - "X's numbers must be equal to Y's numbers," - "when X/Y's dims =1. But received X has [%d] elements," - "received Y has [%d] elements", - M, - N)); - - // MatMul's case 0 => vector * vector - Out->Resize(common::make_ddim({})); - dev_ctx.template Alloc(Out); - VLOG(3) << "MatMul with blaslt case 1"; - blaslt::Run(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - 1, - 1, - M, - false, - true, - matmul_planner); - return; - } - - if (x_ndim == 1) { - const int N = X.numel(); - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - N, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - N, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - std::vector out_dims(y_ndim - 1); - if (trans_y) { - std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); - } else { - std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); - out_dims.back() = y_dims.back(); - } - Out->ResizeAndAllocate(common::make_ddim(out_dims)); - dev_ctx.template Alloc(Out); - if (trans_y) { - const int M = Y.numel() / N; - VLOG(3) << "MatMul with blaslt 2"; - blaslt::Run(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - false, - false, - matmul_planner); - } else { - const int M = y_dims[y_ndim - 1]; - const int batch_size = Y.numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul with blaslt 3"; - blaslt::Run(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - true, - false, - matmul_planner); - } else { - VLOG(3) << "MatMul with blaslt 4"; - blaslt::RunWithBatch(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - true, - false, - batch_size, - M * N, - 0, - M, - matmul_planner); - } - } - return; - } - - if (y_ndim == 1) { - const int N = Y.numel(); - if (trans_x) { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 2], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 2, - N, - x_ndim - 2, - x_dims[x_ndim - 2])); - } else { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 1], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 1, - N, - x_ndim - 1, - x_dims[x_ndim - 1])); - } - std::vector out_dims(x_ndim - 1); - if (trans_x) { - std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); - out_dims.back() = x_dims.back(); - } else { - std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); - } - Out->ResizeAndAllocate(common::make_ddim(out_dims)); - dev_ctx.template Alloc(Out); - - if (trans_x) { - const int M = x_dims[x_ndim - 1]; - const int batch_size = X.numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul with blaslt 5"; - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - true, - false, - matmul_planner); - } else { - VLOG(3) << "MatMul with blaslt 6"; - blaslt::RunWithBatch(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - true, - false, - batch_size, - M * N, - 0, - M, - matmul_planner); - } - } else { - const int M = X.numel() / N; - VLOG(3) << "MatMul with blaslt 7"; - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - false, - false, - matmul_planner); - } - return; - } - - const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; - const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - K, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - K, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; - const int ndim = (std::max)(x_ndim, y_ndim); - std::vector x_broadcast_dims(ndim); - std::vector y_broadcast_dims(ndim); - std::vector out_broadcast_dims(ndim); - - GetBroadcastFromDims(x_ndim - 2, - x_dims.data(), - y_ndim - 2, - y_dims.data(), - x_broadcast_dims.data(), - y_broadcast_dims.data(), - out_broadcast_dims.data()); - out_broadcast_dims[ndim - 2] = M; - out_broadcast_dims[ndim - 1] = N; - - Out->ResizeAndAllocate(common::make_ddim(out_broadcast_dims)); - dev_ctx.template Alloc(Out); - - const int batch_dim = ndim - 2; - // broadcast message - const bool is_broadcast_dims = - !std::equal(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - y_broadcast_dims.cbegin()); - - const std::int64_t x_batch_size = - std::accumulate(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t y_batch_size = - std::accumulate(y_broadcast_dims.cbegin(), - y_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t out_batch_size = - std::accumulate(out_broadcast_dims.cbegin(), - out_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - if (out_batch_size == 0) return; - if (x_batch_size == 1 && y_batch_size == 1) { - VLOG(3) << "MatMul with blaslt 8"; - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - N, - K, - trans_x, - trans_y, - matmul_planner); - } else if (x_batch_size == 1) { - if (M == 1 && trans_y) { - VLOG(3) << "MatMul with blaslt 9"; - blaslt::Run(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - y_batch_size * N, - 1, - K, - false, - false, - matmul_planner); - } else { - VLOG(3) << "MatMul with blaslt 10"; - blaslt::RunWithBatch(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - 0, - K * N, - M * N, - matmul_planner); - } - } else if (y_batch_size == 1) { - if (!trans_x) { - VLOG(3) << "MatMul with blaslt 11"; - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - x_batch_size * M, - N, - K, - false, - trans_y, - matmul_planner); - } else { - VLOG(3) << "MatMul with blaslt 12"; - blaslt::RunWithBatch(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - N, - K, - true, - trans_y, - out_batch_size, - M * K, - 0, - M * N, - matmul_planner); - } - } else if (!is_broadcast_dims) { - VLOG(3) << "MatMul with blaslt 13"; - blaslt::RunWithBatch(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - M * K, - K * N, - M * N, - matmul_planner); - } else { - // in the case, can't use stridedgemm - std::vector x_ptr(out_batch_size); - std::vector y_ptr(out_batch_size); - std::vector out_ptr(out_batch_size); - std::vector index(batch_dim, 0); - for (std::int64_t i = 0; i < out_batch_size; ++i) { - // using the index to get offset - const std::int64_t x_index = - GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); - const std::int64_t y_index = - GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); - - x_ptr[i] = x_data + x_index * M * K; - y_ptr[i] = y_data + y_index * K * N; - out_ptr[i] = dev_ctx.template Alloc(Out) + i * M * N; - IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); - } - VLOG(3) << "MatMul with blaslt 14"; - blaslt::RunWithBatch(dev_ctx, - x_ptr.data(), - y_ptr.data(), - out_ptr.data(), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - matmul_planner); - } -} -#endif - -template -struct MatMulDispatcher { - void operator()(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y, - bool flag = false) { - MatMulFunctionImplWithBlas( - ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag); - } -}; - -#ifdef PADDLE_WITH_CUDA -template -struct MatMulDispatcher { - void operator()(const phi::GPUContext& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y, - bool flag = false) { -#if CUDA_VERSION >= 11060 && 0 - auto* tuner = phi::autotune::MakeMatmulTuner( - MatMulFunctionImplWithBlas); - tuner->AddCallBack(MatMulFunctionImplWithCublasLt); - phi::funcs::MatmulPlanner matmul_planner(x_dims, - y_dims, - trans_x, - trans_y, - phi::CppTypeToDataType::Type(), - funcs::MatmulFusedType::kMatmul, - /* bias_data */ nullptr, - /* reserve_data */ nullptr, - /* use_addto */ flag, - /* no_exchange */ true); - tuner->Run(ctx, - matmul_planner.GetKey(), - ctx, - x, - y, - x_dims, - y_dims, - out, - trans_x, - trans_y, - flag, - &matmul_planner); -#else - MatMulFunctionImplWithBlas( - ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag); -#endif - } -}; - -#endif // PADDLE_WITH_CUDA - -template -void MatMulFunction(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y, - bool flag = false) { - MatMulDispatcher()( - ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag); -} - -template -bool MatMulInt8Function(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y) { - return false; -} - -#ifdef PADDLE_WITH_CUDA -template <> -bool inline MatMulInt8Function(const phi::GPUContext& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y) { - if (x.dtype() != DataType::INT8 || y.dtype() != DataType::INT8) { - return false; - } -#if CUDA_VERSION >= 11060 && 0 - const int x_ndim = x_dims.size(); - const int y_ndim = y_dims.size(); - const int8_t* x_data = x.data(); - const int8_t* y_data = y.data(); - using blaslt = phi::funcs::MatmulWithCublasLt; - - phi::funcs::MatmulPlanner matmul_planner( - x_dims, - y_dims, - trans_x, - trans_y, - phi::CppTypeToDataType::Type(), - funcs::MatmulFusedType::kMatmul, - /* bias_data */ nullptr, - /* reserve_data */ nullptr, - /* use_addto */ false, - /* no_exchange */ true); - - if (x_ndim == 1 && y_ndim == 1) { - const int M = x.numel(); - const int N = y.numel(); - PADDLE_ENFORCE_EQ( - M, - N, - phi::errors::InvalidArgument( - "X's numbers must be equal to Y's numbers," - "when X/Y's dims =1. But received X has [%d] elements," - "received Y has [%d] elements", - M, - N)); - if (!(M % 4 == 0)) { - return false; - } - - out->Resize(common::make_ddim({})); - ctx.template Alloc(out); - blaslt::Run(ctx, - y_data, - x_data, - ctx.template Alloc(out), - 1, - 1, - M, - false, - true, - &matmul_planner); - return true; - } - if (x_ndim == 1) { - const int N = x.numel(); - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - N, - y_ndim - 1, - y_dims[y_ndim - 1])); - if (!(N % 4 == 0)) { - return false; - } - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - N, - y_ndim - 2, - y_dims[y_ndim - 2])); - const int M = y.numel() / N; - if (!(M == 1 || M % 4 == 0)) { - return false; - } - } - std::vector out_dims(y_ndim - 1); - if (trans_y) { - std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); - } else { - std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); - out_dims.back() = y_dims.back(); - } - out->ResizeAndAllocate(common::make_ddim(out_dims)); - ctx.template Alloc(out); - if (trans_y) { - const int M = y.numel() / N; - blaslt::Run(ctx, - y_data, - x_data, - ctx.template Alloc(out), - M, - 1, - N, - false, - false, - &matmul_planner); - } else { - const int M = y_dims[y_ndim - 1]; - const int batch_size = y.numel() / (M * N); - if (batch_size == 1) { - blaslt::Run(ctx, - y_data, - x_data, - ctx.template Alloc(out), - M, - 1, - N, - true, - false, - &matmul_planner); - } else { - blaslt::RunWithBatch(ctx, - y_data, - x_data, - ctx.template Alloc(out), - M, - 1, - N, - true, - false, - batch_size, - M * N, - 0, - M, - &matmul_planner); - } - } - return true; - } - - if (y_ndim == 1) { - const int N = y.numel(); - if (trans_x) { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 2], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 2, - N, - x_ndim - 2, - x_dims[x_ndim - 2])); - const int M = x.numel() / N; - if (!((M == 1 || M % 4 == 0))) { - return false; - } - } else { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 1], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 1, - N, - x_ndim - 1, - x_dims[x_ndim - 1])); - if (N % 4 != 0) { - return false; - } - } - std::vector out_dims(x_ndim - 1); - if (trans_x) { - std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); - out_dims.back() = x_dims.back(); - } else { - std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); - } - out->ResizeAndAllocate(common::make_ddim(out_dims)); - ctx.template Alloc(out); - - if (trans_x) { - const int M = x_dims[x_ndim - 1]; - const int batch_size = x.numel() / (M * N); - if (batch_size == 1) { - blaslt::Run(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - 1, - N, - true, - false, - &matmul_planner); - } else { - blaslt::RunWithBatch(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - 1, - N, - true, - false, - batch_size, - M * N, - 0, - M, - &matmul_planner); - } - } else { - const int M = x.numel() / N; - blaslt::Run(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - 1, - N, - false, - false, - &matmul_planner); - } - return true; - } - - const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; - const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - K, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - K, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; - const int ndim = (std::max)(x_ndim, y_ndim); - std::vector x_broadcast_dims(ndim); - std::vector y_broadcast_dims(ndim); - std::vector out_broadcast_dims(ndim); - GetBroadcastFromDims(x_ndim - 2, - x_dims.data(), - y_ndim - 2, - y_dims.data(), - x_broadcast_dims.data(), - y_broadcast_dims.data(), - out_broadcast_dims.data()); - out_broadcast_dims[ndim - 2] = M; - out_broadcast_dims[ndim - 1] = N; - - out->ResizeAndAllocate(common::make_ddim(out_broadcast_dims)); - ctx.template Alloc(out); - - const int batch_dim = ndim - 2; - // broadcast message - const bool is_broadcast_dims = - !std::equal(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - y_broadcast_dims.cbegin()); - - const std::int64_t x_batch_size = - std::accumulate(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t y_batch_size = - std::accumulate(y_broadcast_dims.cbegin(), - y_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t out_batch_size = - std::accumulate(out_broadcast_dims.cbegin(), - out_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - if (out_batch_size == 0) return true; - - if (x_batch_size == 1 && M == 1 && trans_y) { - if (!(K % 4 == 0)) { - return false; - } - } else if (!trans_x && !trans_y) { - if (!(N % 4 == 0 || N == 1) || !(K % 4 == 0) || (M == 1 && N == 1)) { - return false; - } - } else if (!trans_x && trans_y) { - if (!(K % 4 == 0)) { - return false; - } - } else if (trans_x && !trans_y) { - if (!(M % 4 == 0 || M == 1) || !(N % 4 == 0 || N == 1)) { - return false; - } - } else { - if (!(M % 4 == 0 || M == 1) || !(K % 4 == 0)) { - return false; - } - } - if (x_batch_size == 1 && y_batch_size == 1) { - blaslt::Run(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - N, - K, - trans_x, - trans_y, - &matmul_planner); - } else if (x_batch_size == 1) { - if (M == 1 && trans_y) { - blaslt::Run(ctx, - y_data, - x_data, - ctx.template Alloc(out), - y_batch_size * N, - 1, - K, - false, - false, - &matmul_planner); - } else { - blaslt::RunWithBatch(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - 0, - K * N, - M * N, - &matmul_planner); - } - } else if (y_batch_size == 1) { - if (!trans_x) { - blaslt::Run(ctx, - x_data, - y_data, - ctx.template Alloc(out), - x_batch_size * M, - N, - K, - false, - trans_y, - &matmul_planner); - } else { - blaslt::RunWithBatch(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - N, - K, - true, - trans_y, - out_batch_size, - M * K, - 0, - M * N, - &matmul_planner); - } - } else if (!is_broadcast_dims) { - blaslt::RunWithBatch(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - M * K, - K * N, - M * N, - &matmul_planner); - } else { - // in the case, can't use stridedgemm - std::vector x_ptr(out_batch_size); - std::vector y_ptr(out_batch_size); - std::vector out_ptr(out_batch_size); - std::vector index(batch_dim, 0); - for (std::int64_t i = 0; i < out_batch_size; ++i) { - // using the index to get offset - const std::int64_t x_index = - GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); - const std::int64_t y_index = - GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); - - x_ptr[i] = x_data + x_index * M * K; - y_ptr[i] = y_data + y_index * K * N; - out_ptr[i] = ctx.template Alloc(out) + i * M * N; - IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); - } - blaslt::RunWithBatch(ctx, - x_ptr.data(), - y_ptr.data(), - out_ptr.data(), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - &matmul_planner); - } - return true; -#else - return false; -#endif -} -#endif - -template -typename std::enable_if::value>::type -MatmulJudgeDtypeKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool transpose_x, - bool transpose_y) { - bool try_matmul_int8 = MatMulInt8Function( - ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); - if (try_matmul_int8) { - return; - } - auto x_tmp = phi::Cast(ctx, x, phi::DataType::FLOAT32); - auto y_tmp = phi::Cast(ctx, y, phi::DataType::FLOAT32); - DenseTensor out_tmp; - MatMulFunction( - ctx, x_tmp, y_tmp, x_dims, y_dims, &out_tmp, transpose_x, transpose_y); - if (x.dtype() == phi::DataType::INT8) { - phi::CastKernel(ctx, out_tmp, phi::DataType::INT32, out); - return; - } - phi::CastKernel(ctx, out_tmp, x.dtype(), out); -} - -template -typename std::enable_if::value>::type -MatmulJudgeDtypeKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool transpose_x, - bool transpose_y) { - MatMulFunction( - ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); -} - -template -void MatmulKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - bool transpose_x, - bool transpose_y, - DenseTensor* out) { - PADDLE_ENFORCE_NE( - common::product(x.dims()), - 0, - phi::errors::InvalidArgument("The Input(X) dims size must not be equal 0," - " but reviced dims size is 0. ")); - PADDLE_ENFORCE_NE( - common::product(y.dims()), - 0, - phi::errors::InvalidArgument("The Input(Y) dims size must not be equal 0," - " but reviced dims size is 0. ")); - const std::vector x_dims = common::vectorize(x.dims()); - const std::vector y_dims = common::vectorize(y.dims()); - MatmulJudgeDtypeKernel( - ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); -} - -template -void MatmulWithFlattenKernelImpl(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - const DenseTensor x_matrix = - x.dims().size() > 2 ? phi::ReshapeToMatrix(x, x_num_col_dims) : x; - const DenseTensor y_matrix = - y.dims().size() > 2 ? phi::ReshapeToMatrix(y, y_num_col_dims) : y; - - dev_ctx.template Alloc(out); - auto z_dim = out->dims(); - if (z_dim.size() != 2) { - out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); - } - - auto blas = phi::funcs::GetBlas(dev_ctx); - - blas.MatMul(x_matrix, y_matrix, out); - if (z_dim.size() != 2) { - out->Resize(z_dim); - } -} - -#ifdef PADDLE_WITH_CUDA - -template -void MatmulWithFlattenKernelInt8Impl(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - PADDLE_ENFORCE_EQ( - x.dtype(), - DataType::INT8, - phi::errors::InvalidArgument( - "The type of input(x) used in int8 mul must be (%s) " - "does not match the " - "type of data (%s) currently contained in the container.", - phi::CppTypeToDataType::Type(), - x.dtype())); - PADDLE_ENFORCE_EQ( - y.dtype(), - DataType::INT8, - phi::errors::InvalidArgument( - "The type of input(y) used in int8 mul must be (%s) " - "does not match the " - "type of data (%s) currently contained in the container.", - phi::CppTypeToDataType::Type(), - y.dtype())); - - const DenseTensor x_matrix = - x.dims().size() > 2 ? phi::ReshapeToMatrix(x, x_num_col_dims) : x; - const DenseTensor y_matrix = - y.dims().size() > 2 ? phi::ReshapeToMatrix(y, y_num_col_dims) : y; - - PADDLE_ENFORCE_EQ( - x_matrix.dims()[1], - y_matrix.dims()[0], - phi::errors::InvalidArgument( - "X's numbers of columns must be equal to Y's numbers of rows." - "But received X has [%d] columns," - "received Y has [%d] rows", - x_matrix.dims()[1], - y_matrix.dims()[0])); - - PADDLE_ENFORCE_EQ((y_matrix.dims()[1] % 4 == 0 || y_matrix.dims()[1] == 1), - true, - phi::errors::InvalidArgument( - "The dimension size N used in int8 mul must be 1" - "or a multiple of 4 does not match the size (%d)" - "currently contained in the container.", - y_matrix.dims()[1])); - PADDLE_ENFORCE_EQ((x_matrix.dims()[1] % 4 == 0), - true, - phi::errors::InvalidArgument( - "The dimension size K used in int8 mul must be a" - "multiple of 4 does not match the size (%d) currently" - "contained in the container.", - x_matrix.dims()[1])); - - dev_ctx.template Alloc(out); - auto z_dim = out->dims(); - if (z_dim.size() != 2) { - out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); - } - -#if CUDA_VERSION >= 11060 && 0 - using blaslt = phi::funcs::MatmulWithCublasLt; - - const int8_t* x_data = x_matrix.data(); - const int8_t* y_data = y_matrix.data(); - - std::vector x_dims = {x_matrix.dims()[0], x_matrix.dims()[1]}; - std::vector y_dims = {y_matrix.dims()[0], y_matrix.dims()[1]}; - phi::funcs::MatmulPlanner matmul_planner( - x_dims, - y_dims, - false, - false, - phi::CppTypeToDataType::Type(), - funcs::MatmulFusedType::kMatmul, - /* bias_data */ nullptr, - /* reserve_data */ nullptr, - /* use_addto */ false, - /* no_exchange */ true); - - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(out), - x_matrix.dims()[0], - y_matrix.dims()[1], - x_matrix.dims()[1], - false, - false, - &matmul_planner); - - if (z_dim.size() != 2) { - out->Resize(z_dim); - } -#endif -} -#endif - -#ifdef PADDLE_WITH_CUDA -template -typename std::enable_if::value, - void>::type -DispatchMatmulWithFlattenInt8Kernel(const phi::GPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - MatmulWithFlattenKernelInt8Impl( - dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); -} -#endif - -template -typename std::enable_if::value, - void>::type -DispatchMatmulWithFlattenInt8Kernel(const phi::CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - PADDLE_THROW(phi::errors::Unimplemented( - "MatmulWithFlatten with CPU is NOT implemented " - "yet.")); -} - -template -typename std::enable_if::value, void>::type -DispatchMatmulFlattenKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - DispatchMatmulWithFlattenInt8Kernel( - dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); -} - -template -typename std::enable_if::value, void>::type -DispatchMatmulFlattenKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - MatmulWithFlattenKernelImpl( - dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); -} - -template -void MatmulWithFlattenKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - DispatchMatmulFlattenKernel( - dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); -} - -} // namespace phi diff --git a/backends/metax_gpu/kernels/impl/metax_weight_quantize_kernel_impl.h b/backends/metax_gpu/kernels/impl/metax_weight_quantize_kernel_impl.h new file mode 100644 index 00000000000..9aedba871c5 --- /dev/null +++ b/backends/metax_gpu/kernels/impl/metax_weight_quantize_kernel_impl.h @@ -0,0 +1,303 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/common/enforce.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/phi/kernels/funcs/common_shape.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { +template +void show_2d_cpu_tensor(const DenseTensor& tensor, + const int64_t row_num = 3, + const int64_t col_num = 3) { + const int64_t rows = tensor.dims()[0]; + const int64_t cols = tensor.dims()[1]; + printf("\nTensor shape = [%d, %d]\n", rows, cols); + + const DataType* cpu_ptr = tensor.data(); + + for (int r = 0; r < row_num; r++) { + for (int c = 0; c < col_num; c++) { + DataType val = *(cpu_ptr + r * cols + c); + printf("%#x ", val); + } + printf("\n"); + } + printf("\n\n"); +} +template +void show_2d_gpu_tensor(const CustomContext& dev_ctx, + const DenseTensor& tensor, + const int64_t row_num = 3, + const int64_t col_num = 3) { + phi::CPUPlace cpu_place; + + DenseTensor cpu_tensor; + phi::Copy(dev_ctx, tensor, cpu_place, true, &cpu_tensor); + + const int64_t rows = cpu_tensor.dims()[0]; + const int64_t cols = cpu_tensor.dims()[1]; + printf("\nTensor shape = [%d, %d]\n", rows, cols); + + const DataType* cpu_ptr = cpu_tensor.data(); + + for (int r = 0; r < row_num; r++) { + for (int c = 0; c < col_num; c++) { + DataType val = *(cpu_ptr + r * cols + c); + printf("%#x ", val); + } + printf("\n"); + } + printf("\n\n"); +} + +template +void show_1d_gpu_tensor(const CustomContext& dev_ctx, + const DenseTensor& tensor, + const int64_t num = 3) { + phi::CPUPlace cpu_place; + + DenseTensor cpu_tensor; + phi::Copy(dev_ctx, tensor, cpu_place, true, &cpu_tensor); + + const int64_t nums = cpu_tensor.numel(); + printf("\nTensor shape = [%d]\n", nums); + + const DataType* cpu_ptr = cpu_tensor.data(); + + for (int n = 0; n < num; n++) { + DataType val = *(cpu_ptr + n); + printf("%#x ", val); + } + printf("\n\n"); +} + +void cpu_2d_tensor_transpose(const DenseTensor& input_data, + DenseTensor* transposed_data) { + const int64_t input_data_rows = input_data.dims()[0]; + const int64_t input_data_cols = input_data.dims()[1]; + + const int8_t* input_data_ptr = input_data.data(); + int8_t* transposed_data_ptr = transposed_data->data(); + + for (int64_t r = 0; r < input_data_rows; r++) { + for (int64_t c = 0; c < input_data_cols; c++) { + *(transposed_data_ptr + r + c * input_data_rows) = + *(input_data_ptr + r * input_data_cols + c); + } + } +} + +void cpu_int4_quanted_weight_raw_unpack(const DenseTensor& packed_data, + DenseTensor* unpacked_data) { + const int64_t packed_data_rows = packed_data.dims()[0]; + const int64_t packed_data_cols = packed_data.dims()[1]; + + const int8_t* packed_data_ptr = packed_data.data(); + int8_t* unpacked_data_ptr = unpacked_data->data(); + + for (int64_t c = 0; c < packed_data_cols; c++) { + for (int64_t r = 0; r < packed_data_rows; r++) { + int8_t val = *(packed_data_ptr + r * packed_data_cols + c); + int8_t low_int4 = val & 0x0f; + int8_t hight_int4 = (val >> 4) & 0x0f; + + *(unpacked_data_ptr + (2 * r) * packed_data_cols + c) = + low_int4 >= 8 ? low_int4 - 16 : low_int4; + *(unpacked_data_ptr + (2 * r + 1) * packed_data_cols + c) = + hight_int4 >= 8 ? hight_int4 - 16 : hight_int4; + } + } +} + +void cpu_int4_quanted_weight_col_pack(const DenseTensor& unpacked_data, + DenseTensor* packed_data) { + const int64_t packed_data_rows = packed_data->dims()[0]; + const int64_t packed_data_cols = packed_data->dims()[1]; + + int8_t* packed_data_ptr = packed_data->data(); + const int8_t* unpacked_data_ptr = unpacked_data.data(); + + for (int64_t r = 0; r < packed_data_rows; r++) { + for (int64_t c = 0; c < packed_data_cols; c++) { + int8_t low_int4 = *(unpacked_data_ptr + 2 * r * packed_data_cols + 2 * c); + int8_t hight_int4 = + *(unpacked_data_ptr + 2 * r * packed_data_cols + 2 * c + 1); + + low_int4 = low_int4 < 0 ? low_int4 + 16 : low_int4; + hight_int4 = hight_int4 < 0 ? hight_int4 + 16 : hight_int4; + + *(packed_data_ptr + r * packed_data_cols + c) = + ((hight_int4 << 4) & 0xf0) | (low_int4 & 0x0f); + } + } +} + +void cpu_int4_quantized_weight_layout_trans_impl( + const CustomContext& dev_ctx, + const std::vector& shape, + DenseTensor* out) { + const int64_t m = shape[0]; + const int64_t n = shape[1]; + + phi::CPUPlace cpu_place; + + out->Resize({m / 2, n}); + + DenseTensor out_cpu_tensor; + phi::Copy(dev_ctx, (*out), cpu_place, true, &out_cpu_tensor); + + // raw unpack + DenseTensor raw_unpack_tensor; + raw_unpack_tensor.Resize({out_cpu_tensor.dims()[0] * 2, n}); + raw_unpack_tensor.mutable_data(cpu_place); + cpu_int4_quanted_weight_raw_unpack(out_cpu_tensor, &raw_unpack_tensor); + + // transpose + DenseTensor transposed_tensor; + transposed_tensor.Resize( + {raw_unpack_tensor.dims()[1], raw_unpack_tensor.dims()[0]}); + transposed_tensor.mutable_data(cpu_place); + cpu_2d_tensor_transpose(raw_unpack_tensor, &transposed_tensor); + + // col pack + out_cpu_tensor.Resize( + {transposed_tensor.dims()[0], transposed_tensor.dims()[1] / 2}); + cpu_int4_quanted_weight_col_pack(transposed_tensor, &out_cpu_tensor); + + out_cpu_tensor.Resize({n / 2, m}); + out->Resize({n / 2, m}); + phi::Copy(dev_ctx, out_cpu_tensor, dev_ctx.GetPlace(), true, out); +} + +__global__ void int4_quanted_matrix_raw_unpack_kernel(const int8_t* mat, + int8_t* unpack_mat, + int M, + int N) { + int global_idx = blockIdx.x * blockDim.x + threadIdx.x; + + int i = global_idx / N; + int j = global_idx % N; + + if (global_idx >= M * N) { + return; + } + + int8_t val = mat[global_idx]; + int8_t low = val & 0x0F; + int8_t mask = ((low & 0x80) == 0) & ((low & 0x78) != 0); + low -= 16 * mask; + + int8_t high = (val >> 4) & 0x0F; + mask = ((high & 0x80) == 0) & ((high & 0x78) != 0); + high -= 16 * mask; + + int output_global_idx0 = (2 * i) * N + j; + int output_global_idx1 = (2 * i + 1) * N + j; + + unpack_mat[output_global_idx0] = low; + unpack_mat[output_global_idx1] = high; +} + +__global__ void int4_quanted_matrix_col_pack_kernel(const int8_t* mat, + int8_t* pack_mat, + int M, + int N) { + int global_idx = blockIdx.x * blockDim.x + threadIdx.x; + + int i = global_idx / N; + int j = global_idx % N; + + if (global_idx >= M * N) { + return; + } + + int mat_global_idx0 = i * 2 * N + 2 * j; + int mat_global_idx1 = i * 2 * N + 2 * j + 1; + + int8_t low = mat[mat_global_idx0] & 0x0F; + low = low + ((low >> 3) & 1) * 16; + + int8_t high = mat[mat_global_idx1] & 0x0F; + high = high + ((high >> 3) & 1) * 16; + + pack_mat[global_idx] = ((high << 4) & 0xf0) | (low & 0x0f); +} + +void gpu_int4_quantized_weight_layout_trans_impl( + const CustomContext& dev_ctx, + const std::vector& shape, + DenseTensor* out) { + int64_t total_m = shape[0]; + int64_t total_n = shape[1]; + out->Resize({total_m / 2, total_n}); + + DenseTensor unpack_mat(out->type()); + unpack_mat.Resize({total_m, total_n}); + dev_ctx.template Alloc(&unpack_mat); + + constexpr int kBlockSize = 64; + int64_t kGridSize = (out->numel() + kBlockSize - 1) / kBlockSize; + int4_quanted_matrix_raw_unpack_kernel<<>>( + out->data(), + unpack_mat.data(), + out->dims()[0], + out->dims()[1]); + + DenseTensor transposed_tensor; + transposed_tensor.Resize({unpack_mat.dims()[1], unpack_mat.dims()[0]}); + dev_ctx.template Alloc(&transposed_tensor); + std::vector axis = {1, 0}; + funcs::Transpose trans; + trans(dev_ctx, unpack_mat, &transposed_tensor, axis); + + out->Resize({transposed_tensor.dims()[0], transposed_tensor.dims()[1] / 2}); + int4_quanted_matrix_col_pack_kernel<<>>( + transposed_tensor.data(), + out->data(), + out->dims()[0], + out->dims()[1]); + + out->Resize({total_n / 2, total_m}); +} + +template +void MetaxQuantizedWeightLayoutTrans(const Context& dev_ctx, + const std::string& algo, + const std::vector& shape, + DenseTensor* out) { + if (algo == "weight_only_int4") { + if (dev_ctx.GetPlace() == phi::CPUPlace()) { + cpu_int4_quantized_weight_layout_trans_impl(dev_ctx, shape, out); + } else { + gpu_int4_quantized_weight_layout_trans_impl(dev_ctx, shape, out); + } + + } else { + PADDLE_FATAL( + "The algo must be in ['weight_only_int4'" + "], but got[%s]", + algo); + } +} + +} // namespace phi diff --git a/backends/metax_gpu/kernels/impl/multi_dot_kernel_impl.h b/backends/metax_gpu/kernels/impl/multi_dot_kernel_impl.h index aaa7fbd8d2c..7ba97234cc1 100644 --- a/backends/metax_gpu/kernels/impl/multi_dot_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/multi_dot_kernel_impl.h @@ -14,9 +14,9 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" namespace phi { template diff --git a/backends/metax_gpu/kernels/impl/mv_kernel_impl.h b/backends/metax_gpu/kernels/impl/mv_kernel_impl.h index a87d431e250..4baee25a099 100644 --- a/backends/metax_gpu/kernels/impl/mv_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/mv_kernel_impl.h @@ -14,7 +14,7 @@ #pragma once -#include "kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" namespace phi { diff --git a/backends/metax_gpu/kernels/impl/solve_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/solve_grad_kernel_impl.h index 860bce2cba5..1dd276dde2f 100644 --- a/backends/metax_gpu/kernels/impl/solve_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/solve_grad_kernel_impl.h @@ -14,11 +14,11 @@ limitations under the License. */ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/expand_as_kernel.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/matrix_solve.h" #include "paddle/phi/kernels/funcs/reduce_function.h" diff --git a/backends/metax_gpu/kernels/impl/triangular_solve_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/triangular_solve_grad_kernel_impl.h index 08138853099..ad656b7a6c8 100644 --- a/backends/metax_gpu/kernels/impl/triangular_solve_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/triangular_solve_grad_kernel_impl.h @@ -14,10 +14,10 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/for_range.h" diff --git a/backends/metax_gpu/kernels/metax_kernel/batch_fc_grad_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/batch_fc_grad_kernel_register.cu index 51f8f6792e2..c31d82920b3 100644 --- a/backends/metax_gpu/kernels/metax_kernel/batch_fc_grad_kernel_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/batch_fc_grad_kernel_register.cu @@ -14,10 +14,10 @@ #include -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { diff --git a/backends/metax_gpu/kernels/metax_kernel/block_attn.h b/backends/metax_gpu/kernels/metax_kernel/block_attn.h index 1e1eb2c0961..a5b88e34be1 100644 --- a/backends/metax_gpu/kernels/metax_kernel/block_attn.h +++ b/backends/metax_gpu/kernels/metax_kernel/block_attn.h @@ -14,11 +14,11 @@ #pragma once -#include "kernels/funcs/quant_dequant.h" #include "kernels/metax_kernel/mmha_util.cu.h" #include "paddle/common/flags.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/kernels/funcs/quant_dequant.h" COMMON_DECLARE_bool(use_xqa_optim); COMMON_DECLARE_bool(blha_use_fp32_qk_sum); diff --git a/backends/metax_gpu/kernels/metax_kernel/elementwise.h b/backends/metax_gpu/kernels/metax_kernel/elementwise.h index 52a7709424b..b9f3d8af1c9 100644 --- a/backends/metax_gpu/kernels/metax_kernel/elementwise.h +++ b/backends/metax_gpu/kernels/metax_kernel/elementwise.h @@ -14,9 +14,9 @@ limitations under the License. */ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/eigen/common.h" diff --git a/backends/metax_gpu/kernels/metax_kernel/metax_context.h b/backends/metax_gpu/kernels/metax_kernel/metax_context.h index 7386811a236..18f1e30f191 100644 --- a/backends/metax_gpu/kernels/metax_kernel/metax_context.h +++ b/backends/metax_gpu/kernels/metax_kernel/metax_context.h @@ -17,9 +17,9 @@ #include #include -#include "kernels/funcs/blas/cublasLt.h" #include "paddle/phi/backends/context_pool.h" #include "paddle/phi/backends/custom/custom_context.h" +#include "paddle/phi/backends/dynload/cublasLt.h" #include "paddle/phi/backends/gpu/forwards.h" #include "paddle/phi/backends/gpu/gpu_decls.h" #include "paddle/phi/backends/gpu/gpu_helper.h" @@ -28,8 +28,6 @@ #include "paddle/phi/core/attribute.h" #include "paddle/phi/core/device_context.h" -cublasLtHandle_t GetBlasLtHandle(); - namespace phi { class DnnWorkspaceHandle { public: diff --git a/backends/metax_gpu/kernels/metax_kernel/mmha_util.cu.h b/backends/metax_gpu/kernels/metax_kernel/mmha_util.cu.h index aa352e600b5..187b0fc534a 100644 --- a/backends/metax_gpu/kernels/metax_kernel/mmha_util.cu.h +++ b/backends/metax_gpu/kernels/metax_kernel/mmha_util.cu.h @@ -49,10 +49,10 @@ #pragma once -#if defined(__CUDACC__) && CUDA_VERSION >= 11000 +// #if defined(__CUDACC__) && CUDA_VERSION >= 11000 #define ENABLE_BF16 #include -#endif +// #endif #ifdef PADDLE_WITH_HIP #include @@ -72,8 +72,8 @@ namespace cub = hipcub; #endif #include "paddle/phi/common/datatype_traits.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/funcs/math_function.h" - #ifdef PADDLE_WITH_HIP /// integral_constant template @@ -130,7 +130,7 @@ struct Float4_ { float2 y; }; -#if defined(ENABLE_BF16) || defined(PADDLE_WITH_HIP) +// #if defined(ENABLE_BF16) || defined(PADDLE_WITH_HIP) struct bf16_4_t { __nv_bfloat162 x; __nv_bfloat162 y; @@ -142,7 +142,7 @@ struct bf16_8_t { __nv_bfloat162 z; __nv_bfloat162 w; }; -#endif +// #endif //----------------------------------- template diff --git a/backends/metax_gpu/kernels/metax_kernel/mv_grad_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/mv_grad_kernel_register.cu index 895484324a9..8cf069c0f4b 100644 --- a/backends/metax_gpu/kernels/metax_kernel/mv_grad_kernel_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/mv_grad_kernel_register.cu @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/mv_grad_kernel.h" namespace phi { diff --git a/backends/metax_gpu/kernels/metax_kernel/quant_dequant.h b/backends/metax_gpu/kernels/metax_kernel/quant_dequant.h index a37fc8c5c57..80d325530f5 100644 --- a/backends/metax_gpu/kernels/metax_kernel/quant_dequant.h +++ b/backends/metax_gpu/kernels/metax_kernel/quant_dequant.h @@ -16,12 +16,12 @@ limitations under the License. */ #include -#include "kernels/funcs/blas/blas.h" #include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/common/transform.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" namespace phi { diff --git a/backends/metax_gpu/kernels/metax_kernel/rank_attention_grad_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/rank_attention_grad_kernel_register.cu index bee25a721fa..ba33e68aa5e 100644 --- a/backends/metax_gpu/kernels/metax_kernel/rank_attention_grad_kernel_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/rank_attention_grad_kernel_register.cu @@ -17,8 +17,8 @@ #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/core/kernel_registry.h" -// #include "paddle/phi/kernels/funcs/blas/blas.h" -#include "kernels/funcs/blas/blas.h" +// #include "paddle/phi/paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/rank_attention.cu.h" diff --git a/backends/metax_gpu/kernels/metax_kernel/rank_attention_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/rank_attention_kernel_register.cu index b6a4d2d76e9..eeb9c938888 100644 --- a/backends/metax_gpu/kernels/metax_kernel/rank_attention_kernel_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/rank_attention_kernel_register.cu @@ -17,8 +17,8 @@ #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/core/kernel_registry.h" -// #include "paddle/phi/kernels/funcs/blas/blas.h" -#include "kernels/funcs/blas/blas.h" +// #include "paddle/phi/paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/rank_attention.cu.h" diff --git a/backends/metax_gpu/kernels/metax_kernel/slogdeterminant_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/slogdeterminant_kernel_register.cu index de263c91c4d..3e9a5683ae4 100644 --- a/backends/metax_gpu/kernels/metax_kernel/slogdeterminant_kernel_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/slogdeterminant_kernel_register.cu @@ -20,12 +20,12 @@ #include #include "glog/logging.h" -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/common/complex.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/determinant_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/impl/determinant_kernel_impl.h" #include "paddle/phi/kernels/slogdeterminant_kernel.h" diff --git a/backends/metax_gpu/kernels/metax_kernel/softmax_kernel_grad_register.cu b/backends/metax_gpu/kernels/metax_kernel/softmax_kernel_grad_register.cu index 9b981029fc0..407180deca8 100644 --- a/backends/metax_gpu/kernels/metax_kernel/softmax_kernel_grad_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/softmax_kernel_grad_register.cu @@ -45,5 +45,6 @@ PD_REGISTER_PLUGIN_KERNEL(softmax_grad, ALL_LAYOUT, phi::SoftmaxGradGPUDNNKernel, float, + double, phi::dtype::float16, phi::dtype::bfloat16) {} diff --git a/backends/metax_gpu/kernels/metax_kernel/softmax_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/softmax_kernel_register.cu index 0344a81dc19..523a2e4d76b 100644 --- a/backends/metax_gpu/kernels/metax_kernel/softmax_kernel_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/softmax_kernel_register.cu @@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - +#if 0 #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" @@ -27,3 +27,5 @@ PD_REGISTER_PLUGIN_KERNEL(softmax, double, phi::dtype::float16, phi::dtype::bfloat16) {} + +#endif diff --git a/backends/metax_gpu/kernels/metax_kernel/svd_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/svd_kernel_register.cu index 5f9d6cc20e0..c8ece09bbae 100644 --- a/backends/metax_gpu/kernels/metax_kernel/svd_kernel_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/svd_kernel_register.cu @@ -15,7 +15,7 @@ #ifndef PADDLE_WITH_HIP // HIP not support cusolver -#include "kernels/impl/values_vectors_functor.h" +#include "kernels/metax_kernel/metax_context.h" #include "paddle/phi/backends/dynload/cusolver.h" #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/kernel_registry.h" @@ -60,7 +60,6 @@ void GesvdjBatched(const phi::GPUContext& dev_ctx, int ldu = m; int ldt = n; int lwork = 0; - // auto handle = dev_ctx.cusolver_dn_handle(); auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); @@ -142,7 +141,6 @@ void GesvdjBatched(const phi::GPUContext& dev_ctx, int ldu = m; int ldt = n; int lwork = 0; - // auto handle = dev_ctx.cusolver_dn_handle(); auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); @@ -205,17 +203,17 @@ void GesvdjBatched(const phi::GPUContext& dev_ctx, } template <> -void GesvdjBatched>(const phi::GPUContext& dev_ctx, - int batchSize, - int m, - int n, - int k, - phi::dtype::complex* A, - phi::dtype::complex* U, - phi::dtype::complex* V, - float* S, - int* info, - int thin_UV) { +void GesvdjBatched(const phi::GPUContext& dev_ctx, + int batchSize, + int m, + int n, + int k, + phi::complex64* A, + phi::complex64* U, + phi::complex64* V, + float* S, + int* info, + int thin_UV) { /* compute singular vectors */ const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR; /* compute singular vectors */ @@ -224,7 +222,6 @@ void GesvdjBatched>(const phi::GPUContext& dev_ctx, int ldu = m; int ldt = n; int lwork = 0; - // auto handle = dev_ctx.cusolver_dn_handle(); auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); @@ -245,10 +242,10 @@ void GesvdjBatched>(const phi::GPUContext& dev_ctx, gesvdj_params)); auto workspace = phi::memory_utils::Alloc( dev_ctx.GetPlace(), - lwork * sizeof(phi::dtype::complex), + lwork * sizeof(phi::complex64), phi::Stream(reinterpret_cast(dev_ctx.stream()))); - phi::dtype::complex* workspace_ptr = - reinterpret_cast*>(workspace->ptr()); + phi::complex64* workspace_ptr = + reinterpret_cast(workspace->ptr()); int stride_A = lda * n; int stride_U = ldu * (thin_UV ? k : m); int stride_V = ldt * (thin_UV ? k : n); @@ -289,17 +286,17 @@ void GesvdjBatched>(const phi::GPUContext& dev_ctx, } template <> -void GesvdjBatched>(const phi::GPUContext& dev_ctx, - int batchSize, - int m, - int n, - int k, - phi::dtype::complex* A, - phi::dtype::complex* U, - phi::dtype::complex* V, - double* S, - int* info, - int thin_UV) { +void GesvdjBatched(const phi::GPUContext& dev_ctx, + int batchSize, + int m, + int n, + int k, + phi::complex128* A, + phi::complex128* U, + phi::complex128* V, + double* S, + int* info, + int thin_UV) { /* compute singular vectors */ const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR; /* compute singular vectors */ @@ -308,7 +305,6 @@ void GesvdjBatched>(const phi::GPUContext& dev_ctx, int ldu = m; int ldt = n; int lwork = 0; - // auto handle = dev_ctx.cusolver_dn_handle(); auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); @@ -329,10 +325,10 @@ void GesvdjBatched>(const phi::GPUContext& dev_ctx, gesvdj_params)); auto workspace = phi::memory_utils::Alloc( dev_ctx.GetPlace(), - lwork * sizeof(phi::dtype::complex), + lwork * sizeof(phi::complex128), phi::Stream(reinterpret_cast(dev_ctx.stream()))); - phi::dtype::complex* workspace_ptr = - reinterpret_cast*>(workspace->ptr()); + phi::complex128* workspace_ptr = + reinterpret_cast(workspace->ptr()); int stride_A = lda * n; int stride_U = ldu * (thin_UV ? k : m); int stride_V = ldt * (thin_UV ? k : n); @@ -432,7 +428,7 @@ PD_REGISTER_PLUGIN_KERNEL(svd, // cuda_only phi::SvdKernel, float, double, - phi::dtype::complex, - phi::dtype::complex) {} + phi::complex64, + phi::complex128) {} #endif // not PADDLE_WITH_HIP diff --git a/backends/metax_gpu/kernels/metax_kernel/triangular_solve_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/triangular_solve_kernel_register.cu index 5ff3211fe87..ed1ed259437 100644 --- a/backends/metax_gpu/kernels/metax_kernel/triangular_solve_kernel_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/triangular_solve_kernel_register.cu @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "kernels/funcs/blas/blas.h" #include "paddle/common/ddim.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/triangular_solve_kernel.h" diff --git a/backends/metax_gpu/kernels/metax_kernel/weight_only_linear_kernel.cu b/backends/metax_gpu/kernels/metax_kernel/weight_only_linear_kernel.cu index d2f39ccf751..65cf99d3065 100644 --- a/backends/metax_gpu/kernels/metax_kernel/weight_only_linear_kernel.cu +++ b/backends/metax_gpu/kernels/metax_kernel/weight_only_linear_kernel.cu @@ -166,7 +166,7 @@ void WeightOnlyLinearKernel(const Context& dev_ctx, mctlassGemmScaleOp_w4a16_nobias::epilogueParams( reinterpret_cast(bias_data)), mctlassGemmScaleOp_w4a16_nobias::quantscaleParams( - 1, + 2, group_size, reinterpret_cast(weight_scale_data)), reinterpret_cast(x_data), @@ -191,7 +191,7 @@ void WeightOnlyLinearKernel(const Context& dev_ctx, mctlassGemmScaleOp_w4a16_bias::epilogueParams( reinterpret_cast(bias_data)), mctlassGemmScaleOp_w4a16_bias::quantscaleParams( - 1, + 2, group_size, reinterpret_cast(weight_scale_data)), reinterpret_cast(x_data), diff --git a/backends/metax_gpu/kernels/metax_kernel/weight_quantize_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/weight_quantize_kernel_register.cu index 44ac7f2fddc..efc18693e21 100644 --- a/backends/metax_gpu/kernels/metax_kernel/weight_quantize_kernel_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/weight_quantize_kernel_register.cu @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#include "../impl/metax_weight_quantize_kernel_impl.h" #include "paddle/common/enforce.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/datatype_traits.h" @@ -115,12 +116,11 @@ void WeightQuantizeKernel(const Context& dev_ctx, dev_ctx.template Alloc(scale); weight_quant_gpu(dev_ctx, x.data(), - out->data(), + quanted_x.data(), scale->data(), weight_shape, arch, algo); - out->Resize({m, n}); #ifdef PADDLE_WITH_HIP DenseTensor x_int_tmp(out->type()); x_int_tmp.Resize({m, n / 2}); @@ -141,6 +141,13 @@ void WeightQuantizeKernel(const Context& dev_ctx, // arch, // algo); #endif + quanted_x.Resize({m / 2, n}); + + std::vector axis = {1, 0}; + funcs::Transpose trans; + trans(dev_ctx, quanted_x, out, axis); + + out->Resize({n / 2, m}); } else if (algo == "w4a8") { weight_permute_gpu_w4a8(dev_ctx, x.data(), diff --git a/backends/metax_gpu/patch/paddle.patch b/backends/metax_gpu/patch/paddle.patch index 7ba32b5b399..fe0d9e104a5 100755 --- a/backends/metax_gpu/patch/paddle.patch +++ b/backends/metax_gpu/patch/paddle.patch @@ -18,6 +18,22 @@ index cfada544d4..a690e97d74 100644 endif() set(EIGEN_INCLUDE_DIR ${SOURCE_DIR}) +diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt +index 99a0116d92..2566e7c41a 100755 +--- a/paddle/fluid/operators/fused/CMakeLists.txt ++++ b/paddle/fluid/operators/fused/CMakeLists.txt +@@ -43,6 +43,11 @@ if(WITH_GPU OR WITH_ROCM) + op_library(fused_multi_transformer_int8_op) + endif() + ++ if 1 ++ op_library(fused_gemm_epilogue_op) ++ endif() ++ ++ + if(CUDA_VERSION GREATER_EQUAL 11.6) + op_library(fused_gemm_epilogue_op) + endif() diff --git a/paddle/fluid/platform/profiler/cupti_data_process.cc b/paddle/fluid/platform/profiler/cupti_data_process.cc index bff0f2bf70..9376b5781f 100644 --- a/paddle/fluid/platform/profiler/cupti_data_process.cc @@ -31,6 +47,56 @@ index bff0f2bf70..9376b5781f 100644 #include "paddle/phi/core/os_info.h" #include "paddle/phi/core/platform/device/gpu/gpu_info.h" #include "paddle/phi/core/platform/profiler/utils.h" +diff --git a/paddle/phi/backends/dynload/cublas.h b/paddle/phi/backends/dynload/cublas.h +index 62beb53cfe..0b0ac09fc0 100644 +--- a/paddle/phi/backends/dynload/cublas.h ++++ b/paddle/phi/backends/dynload/cublas.h +@@ -49,7 +49,12 @@ extern void *cublas_dso_handle; + std::call_once(cublas_dso_flag, []() { \ + cublas_dso_handle = phi::dynload::GetCublasDsoHandle(); \ + }); \ +- static void *p_##__name = dlsym(cublas_dso_handle, #__name); \ ++ std::string replaced_name = #__name; \ ++ replaced_name = replaced_name.replace(0, 2, "mc"); \ ++ int index = replaced_name.find("_", 0); \ ++ if (index != -1) replaced_name = replaced_name.substr(0, index); \ ++ static void* p_##__name = \ ++ dlsym(cublas_dso_handle, replaced_name.c_str()); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ +diff --git a/paddle/phi/backends/dynload/cublasLt.h b/paddle/phi/backends/dynload/cublasLt.h +index 8b2e08c777..ca926df151 100644 +--- a/paddle/phi/backends/dynload/cublasLt.h ++++ b/paddle/phi/backends/dynload/cublasLt.h +@@ -46,12 +46,14 @@ extern void *cublasLt_dso_handle; + std::call_once(cublasLt_dso_flag, []() { \ + cublasLt_dso_handle = phi::dynload::GetCublasLtDsoHandle(); \ + }); \ +- static void *p_##__name = dlsym(cublasLt_dso_handle, #__name); \ ++ std::string replaced_name = #__name; \ ++ replaced_name = replaced_name.replace(0, 2, "mc"); \ ++ static void* p_##__name = \ ++ dlsym(cublasLt_dso_handle, replaced_name.c_str()); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ + extern DynLoad__##__name __name +- + // APIs available after CUDA 11.1 + #if CUDA_VERSION >= 11010 || defined(PADDLE_WITH_CUSTOM_DEVICE) + #define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ +@@ -79,8 +81,8 @@ extern void *cublasLt_dso_handle; + __macro(cublasLtMatmulAlgoConfigGetAttribute); \ + __macro(cublasLtMatmulAlgoGetIds); \ + __macro(cublasLtMatmulAlgoCapGetAttribute); \ +- __macro(cublasLtMatmulAlgoCheck); \ +- __macro(cublasLtGetCudartVersion); ++ __macro(cublasLtMatmulAlgoCheck); ++ // __macro(cublasLtGetCudartVersion); + #else + #define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasLtCreate); \ diff --git a/paddle/phi/backends/dynload/cudnn.h b/paddle/phi/backends/dynload/cudnn.h index c0080f0a5e..458ca3e2e8 100644 --- a/paddle/phi/backends/dynload/cudnn.h @@ -210,6 +276,29 @@ index 8ec3cf2792..6f5460df00 100644 return reinterpret_cast(p_##__name)(args...); \ } \ }; \ +diff --git a/paddle/phi/backends/dynload/dynamic_loader.cc b/paddle/phi/backends/dynload/dynamic_loader.cc +index 859f696896..87b5100a1b 100644 +--- a/paddle/phi/backends/dynload/dynamic_loader.cc ++++ b/paddle/phi/backends/dynload/dynamic_loader.cc +@@ -18,7 +18,6 @@ limitations under the License. */ + #include + #include + #include +-#include "paddle/phi/backends/dynload/cupti_lib_path.h" + #include "paddle/phi/common/port.h" + #include "paddle/phi/core/enforce.h" + +@@ -108,6 +107,10 @@ COMMON_DECLARE_string(win_cuda_bin_dir); + #define SPARSELT_LIB_NAME "libcusparseLt.so" + #endif + ++#ifndef CUPTI_LIB_PATH ++#define CUPTI_LIB_PATH "@CUPTI_LIBRARY_PATH@" ++#endif ++ + #ifdef PADDLE_WITH_HIP + + PHI_DEFINE_string(miopen_dir, diff --git a/paddle/phi/backends/dynload/nvjpeg.h b/paddle/phi/backends/dynload/nvjpeg.h index c5309e7e11..3328571380 100644 --- a/paddle/phi/backends/dynload/nvjpeg.h @@ -346,21 +435,10 @@ index 4ff2e528a9..23f7f4b583 100644 for (int offset = warpSize / 2; offset > 0; offset /= 2) diff --git a/paddle/phi/core/enforce.h b/paddle/phi/core/enforce.h -index 024a7de73e..1e4cdf16be 100644 +index 024a7de73e..66b373d698 100644 --- a/paddle/phi/core/enforce.h +++ b/paddle/phi/core/enforce.h -@@ -45,7 +45,9 @@ limitations under the License. */ - #endif - - #ifdef PADDLE_WITH_CUDA --#include "paddle/phi/backends/dynload/cublas.h" -+// #include "paddle/phi/backends/dynload/../../../../../cublas.h" -+#include "../backends/metax_gpu/kernels/funcs/blas/cublas.h" -+// #include "paddle/phi/backends/dynload/cublas.h" - #include "paddle/phi/backends/dynload/cudnn.h" - #include "paddle/phi/backends/dynload/curand.h" - #include "paddle/phi/backends/dynload/cusolver.h" -@@ -97,7 +99,7 @@ inline bool is_error(bool stat) { return !stat; } +@@ -97,7 +97,7 @@ inline bool is_error(bool stat) { return !stat; } void ThrowWarnInternal(const std::string& message); @@ -369,75 +447,280 @@ index 024a7de73e..1e4cdf16be 100644 // For cuda, the assertions can affect performance and it is therefore // recommended to disable them in production code // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion -@@ -109,7 +111,7 @@ void ThrowWarnInternal(const std::string& message); +@@ -109,7 +109,7 @@ void ThrowWarnInternal(const std::string& message); __LINE__, \ #_IS_NOT_ERROR, \ ##__VA_ARGS__); \ - asm("trap;"); \ -+ __builtin_trap(); \ ++ __builtin_trap(); \ } \ } while (0) #elif defined(__HIPCC__) -@@ -757,4 +759,4 @@ inline void retry_sleep(unsigned millisecond) { +diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +index ae7b67de6d..9ac725314f 100644 +--- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h ++++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +@@ -218,11 +218,27 @@ struct CUBlas { + } + }; - } // namespace enforce - using namespace enforce; // NOLINT --} // namespace phi -+} // namespace phi -\ No newline at end of file -diff --git a/paddle/phi/core/platform/device/gpu/gpu_types.h b/paddle/phi/core/platform/device/gpu/gpu_types.h -index c646e487d0..325122175c 100644 ---- a/paddle/phi/core/platform/device/gpu/gpu_types.h -+++ b/paddle/phi/core/platform/device/gpu/gpu_types.h -@@ -25,8 +25,9 @@ ++template ++void print_args(Args... args) { ++ std::cout << "Arguments (" << sizeof...(args) << "): ["; ++ bool first = true; ++ auto printer = [&first](const auto& arg) { ++ if (!first) std::cout << ", "; ++ std::cout << arg; ++ first = false; ++ }; ++ (printer(args), ...); ++ std::cout << "]" << std::endl; ++} ++ + template <> + struct CUBlas { + template + static void GEMM(ARGS... args) { ++ // print_args(args...); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemm(args...)); ++ ++ + } + + template +@@ -368,7 +384,7 @@ struct CUBlas { + cudaDataType_t Ctype, + int ldc, + int batchCount, +- cudaDataType_t computeType) { ++ cublasComputeType_t computeType) { + #if CUDA_VERSION >= 8000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + #if CUDA_VERSION >= 9000 +@@ -476,7 +492,7 @@ struct CUBlas { + void *C, + cudaDataType_t Ctype, + int ldc, +- cudaDataType_t computeType) { ++ cublasComputeType_t computeType) { + #if CUDA_VERSION >= 8000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + #if CUDA_VERSION >= 9000 +@@ -532,7 +548,7 @@ struct CUBlas { + void *C, + cudaDataType_t Ctype, + int64_t ldc, +- cudaDataType_t computeType) { ++ cublasComputeType_t computeType) { + #if CUDA_VERSION >= 12030 && defined(__linux__) + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = dev_ctx->tensor_core_available(); +@@ -759,7 +775,7 @@ struct CUBlas { + void *C, + cudaDataType_t Ctype, + int ldc, +- cudaDataType_t computeType) { ++ cublasComputeType_t computeType) { + #if CUDA_VERSION >= 8000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + #if CUDA_VERSION >= 9000 +@@ -815,7 +831,7 @@ struct CUBlas { + void *C, + cudaDataType_t Ctype, + int64_t ldc, +- cudaDataType_t computeType) { ++ cublasComputeType_t computeType) { + #if CUDA_VERSION >= 12030 && defined(__linux__) + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = dev_ctx->tensor_core_available(); +@@ -1154,7 +1170,7 @@ struct CUBlas { + void *C, + cudaDataType_t Ctype, + int ldc, +- cudaDataType_t computeType) { ++ cublasComputeType_t computeType) { + #if CUDA_VERSION >= 8000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + #if CUDA_VERSION >= 9000 +@@ -1210,7 +1226,7 @@ struct CUBlas { + void *C, + cudaDataType_t Ctype, + int64_t ldc, +- cudaDataType_t computeType) { ++ cublasComputeType_t computeType) { + #if CUDA_VERSION >= 12030 && defined(__linux__) + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = dev_ctx->tensor_core_available(); +@@ -1484,7 +1500,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_R_16F, + N, +- CUDA_R_32F); ++ CUBLAS_COMPUTE_32F); #else - #include - --#include "paddle/phi/backends/dynload/cublas.h" --#include "paddle/phi/backends/dynload/cublasLt.h" -+// #include "paddle/phi/backends/dynload/cublas.h" -+#include "kernels/funcs/blas/cublas.h" -+// #include "paddle/phi/backends/dynload/cublasLt.h" - #include "paddle/phi/backends/dynload/cudnn.h" - #endif + PADDLE_THROW(common::errors::Unimplemented( + "GEMM_EX_64 is not supported on cuda < 12.3")); +@@ -1508,7 +1524,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_R_16F, + static_cast(N), +- CUDA_R_32F); ++ CUBLAS_COMPUTE_32F); + } + #else + // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm +@@ -1694,7 +1710,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_R_16F, + N, +- CUDA_R_32F); ++ CUBLAS_COMPUTE_32F); + #else + PADDLE_THROW(common::errors::Unimplemented( + "GEMM_EX_64 is not supported on cuda < 12.3")); +@@ -1719,7 +1735,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_R_16F, + static_cast(N), +- CUDA_R_32F); ++ CUBLAS_COMPUTE_32F); + #else + // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm + dev_ctx_.CublasCall([&](cublasHandle_t handle) { +@@ -1831,7 +1847,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_R_16BF, + static_cast(N), +- CUDA_R_32F, ++ CUBLAS_COMPUTE_32F, + algo)); + }); + } +@@ -1932,7 +1948,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_R_16BF, + static_cast(N), +- CUDA_R_32F, ++ CUBLAS_COMPUTE_32F, + algo)); + }); + } +@@ -2026,7 +2042,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_C_32F, + static_cast(N), +- CUDA_C_32F); ++ CUBLAS_COMPUTE_32F); -@@ -90,7 +91,7 @@ DECLARE_TYPE_FOR_GPU(gpuStreamCaptureMode, - - // TODO(Ming Huang): Since there is no blasLt handler, - // use rocblas_handle for workaround. --DECLARE_TYPE_FOR_GPU(blasLtHandle_t, cublasLtHandle_t, rocblas_handle); -+// DECLARE_TYPE_FOR_GPU(blasLtHandle_t, cublasLtHandle_t, rocblas_handle); - - #undef DECLARE_TYPE_FOR_GPU - -diff --git a/paddle/phi/core/platform/device_context.h b/paddle/phi/core/platform/device_context.h -index 2d02eb370b..8a7233e34e 100644 ---- a/paddle/phi/core/platform/device_context.h -+++ b/paddle/phi/core/platform/device_context.h -@@ -25,8 +25,8 @@ limitations under the License. */ - #include "paddle/phi/core/platform/device/gpu/gpu_types.h" - #include "paddle/phi/core/platform/device_type.h" - #ifdef PADDLE_WITH_CUDA --#include "paddle/phi/backends/dynload/cublas.h" --#include "paddle/phi/backends/dynload/cublasLt.h" -+#include "kernels/funcs/blas/cublas.h" -+#include "kernels/funcs/blas/cublasLt.h" - #include "paddle/phi/backends/dynload/cudnn.h" - #include "paddle/phi/backends/dynload/cusolver.h" - #include "paddle/phi/backends/dynload/cusparse.h" -diff --git a/paddle/phi/kernels/cpu/index_select_impl.h b/paddle/phi/kernels/cpu/index_select_impl.h -index d69eb67d6f..1d8b6e9375 100644 ---- a/paddle/phi/kernels/cpu/index_select_impl.h -+++ b/paddle/phi/kernels/cpu/index_select_impl.h -@@ -18,7 +18,7 @@ - - #include "paddle/phi/core/dense_tensor.h" - #include "paddle/phi/core/tensor_utils.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/eigen/common.h" - #include "paddle/phi/kernels/funcs/math_function.h" + #else + dev_ctx_.CublasCall([&](cublasHandle_t handle) { +@@ -2111,7 +2127,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_C_64F, + N, +- CUDA_C_64F); ++ CUBLAS_COMPUTE_64F); + #else + PADDLE_THROW(common::errors::Unimplemented( + "GEMM_EX_64 is not supported on cuda < 12.3")); +@@ -2136,7 +2152,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_C_64F, + static_cast(N), +- CUDA_C_64F); ++ CUBLAS_COMPUTE_64F); + #else // CUDA_VERSION >= 8000 + // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm + dev_ctx_.CublasCall([&](cublasHandle_t handle) { +@@ -2272,7 +2288,7 @@ inline void Blas::GEMM(bool transA, + C, + CUDA_R_16F, + ldc, +- CUDA_R_32F, ++ CUBLAS_COMPUTE_32F, + algo)); + }); + } +@@ -2334,7 +2350,7 @@ inline void Blas::GEMM(bool transA, + C, + CUDA_R_16BF, + ldc, +- CUDA_R_32F, ++ CUBLAS_COMPUTE_32F, + algo)); + }); + #else +@@ -3129,7 +3145,7 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CUDA_R_16F, + ldc, + batchCount, +- CUDA_R_32F); ++ CUBLAS_COMPUTE_32F); + } + template <> +@@ -3197,7 +3213,7 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CUDA_R_16BF, + ldc, + batchCount, +- CUDA_R_32F, ++ CUBLAS_COMPUTE_32F, + algo)); + }); + #else +diff --git a/paddle/phi/kernels/funcs/blas/blaslt_gemm_search.h b/paddle/phi/kernels/funcs/blas/blaslt_gemm_search.h +index e63b3d2f6e..95d7e6f204 100644 +--- a/paddle/phi/kernels/funcs/blas/blaslt_gemm_search.h ++++ b/paddle/phi/kernels/funcs/blas/blaslt_gemm_search.h +@@ -628,7 +628,13 @@ class CublasLtAlgoCache { + infile >> cublaslt_version; + VLOG(1) << "cublaslt_version " << cublaslt_version; + +- if (dynload::cublasLtGetCudartVersion() != cublaslt_version) { ++ // if (dynload::cublasLtGetCudartVersion() != cublaslt_version) { ++ // LOG(INFO) << algo_caches_file_ ++ // << " is not compatible with current cublaslt_version " ++ // << real_cublaslt_version; ++ // return; ++ // } ++ if (3000 != cublaslt_version) { + LOG(INFO) << algo_caches_file_ + << " is not compatible with current cublaslt_version " + << real_cublaslt_version; +@@ -655,7 +661,8 @@ class CublasLtAlgoCache { + if (dev == 0) { + std::ofstream outfile; + outfile.open(algo_caches_file_, std::ios::out | std::ios::trunc); +- outfile << dynload::cublasLtGetCudartVersion() << std::endl; ++ // outfile << dynload::cublasLtGetCudartVersion() << std::endl; ++ outfile << 3000 << std::endl; + + for (const auto& [seed, algo] : algo_caches_) { + outfile << seed << " "; +diff --git a/paddle/phi/kernels/funcs/cublaslt.h b/paddle/phi/kernels/funcs/cublaslt.h +index fbbf57c25a..f690db59e9 100644 +--- a/paddle/phi/kernels/funcs/cublaslt.h ++++ b/paddle/phi/kernels/funcs/cublaslt.h +@@ -42,19 +42,11 @@ class CublasLtHelper { + CublasLtHelper(int m, int k, int n, cublasLtHandle_t handle) + : handle_(handle), alpha_(1), beta_(0), m_(m), k_(k), n_(n) { + cublasStatus_t status; +-#if CUBLAS_VER_MAJOR < 11 +- cudaDataType_t cudaComputeType = CUDA_R_32I; +-#else + cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; +-#endif + + // matmul desc +-#if CUBLAS_VER_MAJOR < 11 +- status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType); +-#else + status = dyl::cublasLtMatmulDescCreate( + &matmul_desc_, cudaComputeType, CUDA_R_32I); +-#endif + + PADDLE_ENFORCE_EQ( + status, diff --git a/paddle/phi/kernels/funcs/embedding_grad.h b/paddle/phi/kernels/funcs/embedding_grad.h index 461e6e2474..48a64ae9ce 100644 --- a/paddle/phi/kernels/funcs/embedding_grad.h @@ -453,38 +736,6 @@ index 461e6e2474..48a64ae9ce 100644 #endif dim3 threads(kWarpSize, kBlockDimY); dim3 grids(static_cast((D + kWarpSize - 1) / kWarpSize)); -diff --git a/paddle/phi/kernels/funcs/fc_functor.cu b/paddle/phi/kernels/funcs/fc_functor.cu -index cb35feee32..64f5bd24ac 100644 ---- a/paddle/phi/kernels/funcs/fc_functor.cu -+++ b/paddle/phi/kernels/funcs/fc_functor.cu -@@ -16,12 +16,12 @@ limitations under the License. */ - - #include "paddle/phi/backends/all_context.h" - #include "paddle/phi/kernels/funcs/aligned_vector.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/fc_functor.h" - - #include "paddle/phi/backends/gpu/gpu_launch_config.h" - #include "paddle/phi/core/dense_tensor.h" --#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" -+// #include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" - #include "paddle/phi/kernels/funcs/quant_dequant.h" - #include "paddle/phi/kernels/matmul_kernel.h" - -diff --git a/paddle/phi/kernels/funcs/gru_compute.cu b/paddle/phi/kernels/funcs/gru_compute.cu -index 88663ec880..98b93072a3 100644 ---- a/paddle/phi/kernels/funcs/gru_compute.cu -+++ b/paddle/phi/kernels/funcs/gru_compute.cu -@@ -12,7 +12,7 @@ limitations under the License. */ - #include "paddle/phi/kernels/funcs/gru_compute.h" - - #include "paddle/phi/backends/gpu/gpu_context.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/detail/gru_gpu_kernel.h" - #include "paddle/phi/kernels/funcs/detail/gru_kernel.h" - diff --git a/paddle/phi/kernels/funcs/layer_norm_impl.cu.h b/paddle/phi/kernels/funcs/layer_norm_impl.cu.h index 4eae698648..5c047723ea 100644 --- a/paddle/phi/kernels/funcs/layer_norm_impl.cu.h @@ -503,19 +754,6 @@ index 4eae698648..5c047723ea 100644 #endif return block_dim >= kMaxBlockDim ? kMaxBlockDim : lwarpSize; } -diff --git a/paddle/phi/kernels/funcs/math/context_project.h b/paddle/phi/kernels/funcs/math/context_project.h -index 15e1a4a3c3..e4780538d7 100644 ---- a/paddle/phi/kernels/funcs/math/context_project.h -+++ b/paddle/phi/kernels/funcs/math/context_project.h -@@ -18,7 +18,7 @@ - #include - - #include "paddle/phi/core/tensor_utils.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/im2col.h" - - namespace phi { diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h index e5361b836e..5ad238df08 100644 --- a/paddle/phi/kernels/funcs/math_cuda_utils.h @@ -559,51 +797,6 @@ index e5361b836e..5ad238df08 100644 return val; } -diff --git a/paddle/phi/kernels/funcs/matrix_inverse.cu b/paddle/phi/kernels/funcs/matrix_inverse.cu -index e101224970..a52eb6096f 100644 ---- a/paddle/phi/kernels/funcs/matrix_inverse.cu -+++ b/paddle/phi/kernels/funcs/matrix_inverse.cu -@@ -15,11 +15,13 @@ limitations under the License. */ - #include "paddle/phi/kernels/funcs/matrix_inverse.h" - - #include "paddle/phi/common/memory_utils.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - - namespace phi { - namespace funcs { - -+ -+ - template - void MatrixInverseFunctor::operator()(const Context& dev_ctx, - const DenseTensor& a, -diff --git a/paddle/phi/kernels/funcs/matrix_solve.cu b/paddle/phi/kernels/funcs/matrix_solve.cu -index 558d363b39..05da04b517 100644 ---- a/paddle/phi/kernels/funcs/matrix_solve.cu -+++ b/paddle/phi/kernels/funcs/matrix_solve.cu -@@ -16,7 +16,7 @@ limitations under the License. */ - #include "paddle/phi/backends/gpu/cuda/cudnn_workspace_helper.h" - #include "paddle/phi/common/memory_utils.h" - #include "paddle/phi/core/tensor_utils.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/math_function.h" - #include "paddle/phi/kernels/funcs/scatter.cu.h" - -diff --git a/paddle/phi/kernels/funcs/multihead_matmul_functor.cu b/paddle/phi/kernels/funcs/multihead_matmul_functor.cu -index 047f52bd91..a05b34d3ba 100644 ---- a/paddle/phi/kernels/funcs/multihead_matmul_functor.cu -+++ b/paddle/phi/kernels/funcs/multihead_matmul_functor.cu -@@ -27,7 +27,7 @@ namespace cub = hipcub; - - #include "paddle/phi/kernels/funcs/multihead_matmul_functor.h" - --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/math_cuda_utils.h" - - namespace phi { diff --git a/paddle/phi/kernels/funcs/top_k_function_cuda.h b/paddle/phi/kernels/funcs/top_k_function_cuda.h index e30d440ff3..108edda7ca 100644 --- a/paddle/phi/kernels/funcs/top_k_function_cuda.h @@ -873,31 +1066,17 @@ index e30d440ff3..108edda7ca 100644 } // namespace funcs } // namespace phi +// -diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h b/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h -index 32db61532f..0220316bc3 100644 ---- a/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h -+++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h -@@ -15,7 +15,7 @@ - #pragma once - - #if defined(PADDLE_WITH_CUDA) --#include "paddle/phi/backends/dynload/cublasLt.h" -+// #include "paddle/phi/backends/dynload/cublasLt.h" - #endif - - #include "glog/logging.h" diff --git a/paddle/phi/kernels/fusion/gpu/fused_layernorm_residual_dropout_bias.h b/paddle/phi/kernels/fusion/gpu/fused_layernorm_residual_dropout_bias.h -index 9d4bb18d55..ea42cc10a9 100644 +index 9d4bb18d55..80405c2b78 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_layernorm_residual_dropout_bias.h +++ b/paddle/phi/kernels/fusion/gpu/fused_layernorm_residual_dropout_bias.h -@@ -638,9 +638,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( +@@ -638,9 +638,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( RandVec(&state, rand); #pragma unroll for (int jt = 0; jt < VecSize; jt++) { -#ifndef PADDLE_WITH_HIP -#pragma unroll -#endif -+// #pragma unroll mask_vec[it][jt] = static_cast(rand[jt] >= dropout_prob); } } @@ -928,7 +1107,7 @@ index b2d15a59f8..f64582e85a 100644 namespace phi { namespace fusion { diff --git a/paddle/phi/kernels/gpu/depthwise_conv.h b/paddle/phi/kernels/gpu/depthwise_conv.h -index f0cca0f701..02ea957240 100644 +index 2edac5eba5..4f265e3db7 100644 --- a/paddle/phi/kernels/gpu/depthwise_conv.h +++ b/paddle/phi/kernels/gpu/depthwise_conv.h @@ -29,8 +29,8 @@ namespace cub = hipcub; @@ -942,19 +1121,6 @@ index f0cca0f701..02ea957240 100644 namespace phi { // To determine use cudnn or not. -diff --git a/paddle/phi/kernels/gpu/dot_kernel.cu b/paddle/phi/kernels/gpu/dot_kernel.cu -index af27ac89ab..ee0edc6b8e 100644 ---- a/paddle/phi/kernels/gpu/dot_kernel.cu -+++ b/paddle/phi/kernels/gpu/dot_kernel.cu -@@ -15,7 +15,7 @@ - #include "paddle/phi/kernels/dot_kernel.h" - #include "paddle/phi/backends/gpu/gpu_context.h" - #include "paddle/phi/core/kernel_registry.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/eigen/common.h" - - #include "paddle/phi/kernels/full_kernel.h" diff --git a/paddle/phi/kernels/gpu/gelu_funcs.h b/paddle/phi/kernels/gpu/gelu_funcs.h index 29fa252e96..4ae72b0935 100644 --- a/paddle/phi/kernels/gpu/gelu_funcs.h @@ -1007,7 +1173,7 @@ index 63c35dd4ee..15da9aea45 100644 namespace phi { diff --git a/paddle/phi/kernels/gpu/lstsq_kernel.cu b/paddle/phi/kernels/gpu/lstsq_kernel.cu -index 1bdbe1564c..f753b54bc6 100644 +index c7f27b2924..4cf6204ac7 100644 --- a/paddle/phi/kernels/gpu/lstsq_kernel.cu +++ b/paddle/phi/kernels/gpu/lstsq_kernel.cu @@ -21,7 +21,7 @@ @@ -1019,84 +1185,6 @@ index 1bdbe1564c..f753b54bc6 100644 #include "paddle/phi/kernels/impl/qr_kernel_impl.h" #include "paddle/phi/kernels/impl/tril_triu_kernel_impl.h" #include "paddle/phi/kernels/lstsq_kernel.h" -diff --git a/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h b/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h -index 9bc5326c90..79b57a8203 100644 ---- a/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h -+++ b/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h -@@ -21,7 +21,7 @@ limitations under the License. */ - #include "paddle/phi/common/amp_type_traits.h" - #include "paddle/phi/kernels/addmm_grad_kernel.h" - #include "paddle/phi/kernels/full_kernel.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/eigen/common.h" - #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" - #include "paddle/phi/kernels/funcs/for_range.h" -diff --git a/paddle/phi/kernels/impl/baddbmm_grad_kernel_impl.h b/paddle/phi/kernels/impl/baddbmm_grad_kernel_impl.h -index cf80666b4e..ca76e055fb 100644 ---- a/paddle/phi/kernels/impl/baddbmm_grad_kernel_impl.h -+++ b/paddle/phi/kernels/impl/baddbmm_grad_kernel_impl.h -@@ -19,7 +19,7 @@ limitations under the License. */ - - #include "paddle/phi/common/amp_type_traits.h" - #include "paddle/phi/kernels/baddbmm_grad_kernel.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/eigen/common.h" - #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" - #include "paddle/phi/kernels/funcs/for_range.h" -diff --git a/paddle/phi/kernels/impl/baddbmm_kernel_impl.h b/paddle/phi/kernels/impl/baddbmm_kernel_impl.h -index 2789cb59a2..b91b076f7f 100644 ---- a/paddle/phi/kernels/impl/baddbmm_kernel_impl.h -+++ b/paddle/phi/kernels/impl/baddbmm_kernel_impl.h -@@ -20,7 +20,7 @@ limitations under the License. */ - - #include "paddle/phi/common/amp_type_traits.h" - #include "paddle/phi/kernels/baddbmm_kernel.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/eigen/common.h" - #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" - -diff --git a/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h b/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h -index 9a21c23666..86413d1577 100644 ---- a/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h -+++ b/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h -@@ -19,7 +19,7 @@ - #include "paddle/phi/kernels/conv_transpose_grad_kernel.h" - #include "paddle/phi/kernels/cpu/conv_util.h" - #include "paddle/phi/kernels/full_kernel.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/concat_and_split_functor.h" - #include "paddle/phi/kernels/funcs/im2col.h" - #include "paddle/phi/kernels/funcs/slice.h" -diff --git a/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h b/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h -index 4459a931da..837c8682b8 100644 ---- a/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h -+++ b/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h -@@ -18,7 +18,7 @@ - #include "paddle/phi/core/dense_tensor.h" - #include "paddle/phi/kernels/empty_kernel.h" - #include "paddle/phi/kernels/full_kernel.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/deformable_conv_functor.h" - - namespace phi { -diff --git a/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h b/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h -index ad9e9197dd..5478d9817d 100644 ---- a/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h -+++ b/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h -@@ -18,7 +18,7 @@ - #include "paddle/phi/core/dense_tensor.h" - #include "paddle/phi/kernels/empty_kernel.h" - #include "paddle/phi/kernels/full_kernel.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/deformable_conv_functor.h" - #include "paddle/phi/kernels/transpose_kernel.h" - #include "paddle/utils/optional.h" diff --git a/paddle/phi/kernels/impl/gammaincc_kernel_impl.h b/paddle/phi/kernels/impl/gammaincc_kernel_impl.h index e6b3960f6d..564125f1f6 100644 --- a/paddle/phi/kernels/impl/gammaincc_kernel_impl.h @@ -1113,9 +1201,20 @@ index e6b3960f6d..564125f1f6 100644 if ((x <= T{0}) || (a <= T{0})) return (T{1.0}); diff --git a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h -index 410fb3c560..009ce03440 100644 +index 410fb3c560..7d173d46f5 100644 --- a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h +@@ -20,8 +20,8 @@ + namespace phi { + template + HOSTDEVICE T digamma_positive_domain(T x) { +- static T c = T{8.5}; +- static T euler_mascheroni = T{0.57721566490153286060}; ++ const static T c = T{8.5}; ++ const static T euler_mascheroni = T{0.57721566490153286060}; + T r; + T value; + T x2; @@ -54,7 +54,7 @@ HOSTDEVICE T digamma_positive_domain(T x) { template @@ -1125,67 +1224,3 @@ index 410fb3c560..009ce03440 100644 if (x == T{0.0}) { T inf = std::numeric_limits::infinity(); -diff --git a/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h b/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h -index 5ebbc8d2db..c7b6c338e2 100644 ---- a/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h -+++ b/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h -@@ -15,8 +15,9 @@ limitations under the License. */ - #include - #include - #include "paddle/phi/common/datatype_traits.h" --#include "paddle/phi/kernels/funcs/cublaslt.h" --#include "paddle/phi/kernels/funcs/quant_dequant.h" -+#include "kernels/funcs/blas/cublaslt.h" -+#include "kernels/funcs/quant_dequant.h" -+#include "kernels/metax_kernel/metax_context.h" - - #pragma once - -@@ -668,7 +669,7 @@ void LLMGemm(const phi::GPUContext& dev_ctx, - - { - auto helper = -- std::make_unique(m, k, n, dev_ctx.cublaslt_handle()); -+ std::make_unique(m, k, n, GetBlasLtHandle()); - helper->GEMM(quant_input.data(), - weight->data(), - int_out.data(), -diff --git a/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h b/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h -index 1f319c4ae3..9186eb6906 100644 ---- a/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h -+++ b/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h -@@ -15,7 +15,7 @@ limitations under the License. */ - #pragma once - - #include "paddle/phi/core/dense_tensor.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/matrix_inverse.h" - - namespace phi { -diff --git a/paddle/phi/kernels/impl/matrix_power_kernel_impl.h b/paddle/phi/kernels/impl/matrix_power_kernel_impl.h -index 6f03f76eeb..5fe2c3e7dc 100644 ---- a/paddle/phi/kernels/impl/matrix_power_kernel_impl.h -+++ b/paddle/phi/kernels/impl/matrix_power_kernel_impl.h -@@ -15,7 +15,7 @@ limitations under the License. */ - #pragma once - - #include "paddle/phi/core/dense_tensor.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/for_range.h" - #include "paddle/phi/kernels/funcs/matrix_inverse.h" - -diff --git a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h -index 4099d8b506..baef2cd643 100644 ---- a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h -+++ b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h -@@ -14,7 +14,7 @@ - - #pragma once - --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/eigen/common.h" - #include "paddle/phi/kernels/funcs/math_function.h" - diff --git a/backends/metax_gpu/runtime/runtime.cc b/backends/metax_gpu/runtime/runtime.cc index 36fbd88c2ea..edbe937e7ba 100644 --- a/backends/metax_gpu/runtime/runtime.cc +++ b/backends/metax_gpu/runtime/runtime.cc @@ -36,12 +36,12 @@ #include #include "glog/logging.h" -#include "kernels/funcs/blas/cublasLt.h" #include "paddle/fluid/platform/profiler/cuda_tracer.h" #include "paddle/fluid/platform/profiler/cupti_data_process.h" #include "paddle/phi/api/profiler/trace_event_collector.h" #include "paddle/phi/backends/device_base.h" #include "paddle/phi/backends/device_ext.h" +#include "paddle/phi/backends/dynload/cublasLt.h" #include "paddle/phi/backends/dynload/cupti.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/allocator.h" diff --git a/backends/metax_gpu/tests/ignore.txt b/backends/metax_gpu/tests/ignore.txt index be0357e5319..2b0fae559e6 100644 --- a/backends/metax_gpu/tests/ignore.txt +++ b/backends/metax_gpu/tests/ignore.txt @@ -24,9 +24,9 @@ test_conv3d_layer test_conv3d_transpose_part2_op test_fused_conv2d_add_act_op test_swiglu_metax -test_set_value_op -test_pad_op test_squared_l2_norm_op -test_concat_op test_dygraph_spectral_norm test_bincount_op +test_adamw_op +test_einsum_op +test_complex_matmul