diff --git a/Paddle b/Paddle
index 5dbecdcb0e4..b51d1da36de 160000
--- a/Paddle
+++ b/Paddle
@@ -1 +1 @@
-Subproject commit 5dbecdcb0e4ddd3488927f49082dfb66c794f9e7
+Subproject commit b51d1da36debb9faaa4197629c82c0fe907a94c9
diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt
index 6aecdc1f833..9e257e9507d 100755
--- a/backends/metax_gpu/CMakeLists.txt
+++ b/backends/metax_gpu/CMakeLists.txt
@@ -109,6 +109,10 @@ file(
   CUDA_SRCS
   # backends
   ${PADDLE_SOURCE_DIR}/paddle/phi/backends/gpu/cuda/cuda_info.cc
+  ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/dynamic_loader.cc
+  ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cublas.cc
+  ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cublasLt.cc
+  ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cudnn.cc
   ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cuda_driver.cc
   ${PADDLE_SOURCE_DIR}/paddle/phi/backends/gpu/cuda/cuda_graph.cc
   # Core
@@ -698,7 +702,6 @@ file(
   kernels/gpudnn/*.cu
   kernels/cuda_kernels/*.cc
   kernels/cuda_kernels/*.cu
-  kernels/funcs/blas/*.cc
   kernels/ernie_core/*.cu)

 set(CUSTOM_DEVICE_SRCS ${CUDA_SRCS} ${CC_SRCS} ${ERNIE_CORE_SRCS})
@@ -746,11 +749,28 @@ target_compile_definitions(
   PUBLIC PADDLE_WITH_CUDA=1
          PADDLE_WITH_CUSTOM_DEVICE=1
          mcblasContext=cublasContext
+         cublasLtContext=mcblasLtContext
          GPUContext=CustomContext
          KPSContext=CustomContext
          STREAM_TYPE=cudaStream_t
          EVENT_TYPE=cudaEvent_t
-         EIGEN_USE_GPU=1)
+         EIGEN_USE_GPU=1
+         CUDA_LIB_NAME="libmcruntime.so"
+         BLAS_LIB_NAME="libmcblas.so"
+         BLASLT_LIB_NAME="libmcblasLt.so"
+         DNN_LIB_NAME="libmcdnn.so"
+         PTI_LIB_NAME="libmcpti.so"
+         RAND_LIB_NAME="libcurand.so"
+         JPEG_LIB_NAME="libnvjpeg.so"
+         SOLVER_LIB_NAME="libmcsolver.so"
+         SPARSE_LIB_NAME="libmcsparse.so"
+         RTC_LIB_NAME="libmcruntime.so"
+         FLASHATTN_LIB_NAME="libmcFlashAttn.so"
+         FLASHATTNV3_LIB_NAME="libflashattnv3.so"
+         CCL_LIB_NAME="libmccl.so"
+         FFT_LIB_NAME="libcufft.so"
+         SPARSELT_LIB_NAME="libcusparseLt.so"
+         CUPTI_LIB_PATH="/root/cu-bridge/CUDA_DIR/extras/CUPTI/lib64")

 # packing wheel package
 configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in
diff --git a/backends/metax_gpu/common/flags_declare.cc b/backends/metax_gpu/common/flags_declare.cc
index 6b497cf9fdf..fb656878033 100644
--- a/backends/metax_gpu/common/flags_declare.cc
+++ b/backends/metax_gpu/common/flags_declare.cc
@@ -37,6 +37,27 @@ */
 static constexpr int kDefaultConvWorkspaceSizeLimitMB = 512;

+/**
+ * CUDA related FLAG
+ * Name: FLAGS_cublaslt_exhaustive_search_times
+ * Since Version: 2.3.0
+ * Value Range: int64_t, default=0
+ * Example:
+ * Note: Represents times of exhaustive search to evaluate performance of
+ *       cuBlasLt matmul algorithm (with/without epilogue). Set this flag
+ *       with value > 0 to enable exhaustive search. Default is 0, which means
+ *       getting algorithms via heuristic search. There are two search methods
+ *       in cuBlasLt, heuristic search and exhaustive search. Exhaustive search
+ *       attempts all cuBlasLt algorithms to select the fastest, which is very
+ *       time-consuming, and the selected algorithm will be cached for a given
+ *       layer specification. Once you change the layer specifications
+ *       (such as M, N and K), it will re-search again.
+ */ +PHI_DEFINE_EXPORTED_int64( + cublaslt_exhaustive_search_times, + 0, + "The times of exhaustive search for cuBlasLt matmul with/without " + " epilogue algorithms, default is 0, means disabling exhaustive search."); PHI_DEFINE_EXPORTED_bool( cudnn_exhaustive_search, diff --git a/backends/metax_gpu/kernels/cuda_kernels/activation_grad_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/activation_grad_kernel_register.cu index 6c46ef10c0f..f5ee4ec25f8 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/activation_grad_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/activation_grad_kernel_register.cu @@ -15,8 +15,6 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" -#include "paddle/phi/common/bfloat16.h" -#include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/activation_grad_kernel.h" #include "paddle/phi/kernels/full_kernel.h" @@ -103,6 +101,21 @@ void ActivationGradGPUImpl(const Context& dev_ctx, dev_ctx, &x, nullptr, &dout, dx, functor); \ } +#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_DOUBLE_ATTRS_DEPX( \ + name, functor_class, attr) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& dout, \ + double attr, \ + DenseTensor* dx) { \ + funcs::functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr; \ + ActivationGradGPUImpl>( \ + dev_ctx, &x, nullptr, &dout, dx, functor); \ + } + #define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX( \ name, functor_class, attr1, attr2) \ template \ @@ -119,6 +132,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx, ActivationGradGPUImpl>( \ dev_ctx, &x, nullptr, &dout, dx, functor); \ } + #define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_DOUBLE_ATTRS_DEPX( \ name, functor_class, attr1, attr2) \ template \ @@ -135,6 +149,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx, ActivationGradGPUImpl>( \ dev_ctx, &x, nullptr, &dout, dx, functor); \ } + #define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(name, functor_class) \ template \ void name##GradKernel(const Context& dev_ctx, \ @@ -161,6 +176,21 @@ void ActivationGradGPUImpl(const Context& dev_ctx, dev_ctx, nullptr, &out, &dout, dx, functor); \ } +#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_DOUBLE_ATTRS_DEPOUT( \ + name, functor_class, attr) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& out, \ + const DenseTensor& dout, \ + double attr, \ + DenseTensor* dx) { \ + funcs::functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr; \ + ActivationGradGPUImpl>( \ + dev_ctx, nullptr, &out, &dout, dx, functor); \ + } + #define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT( \ name, functor_class, attr1, attr2) \ template \ @@ -224,9 +254,9 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Log10, CudaLog10GradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Log1p, CudaLog1pGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Swish, CudaSwishGradFunctor); -DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, - CudaLeakyReluGradFunctor, - alpha); +DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_DOUBLE_ATTRS_DEPX(LeakyRelu, + CudaLeakyReluGradFunctor, + alpha); DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, CudaSoftShrinkGradFunctor, lambda); @@ -240,9 +270,9 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish, 
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, CudaCELUGradFunctor, alpha); -DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(LogitCUDA, - CudaLogitGradFunctor, - eps); +DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_DOUBLE_ATTRS_DEPOUT(LogitCUDA, + CudaLogitGradFunctor, + eps); DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(HardTanh, CudaHardTanhGradFunctor, @@ -266,6 +296,7 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(ThresholdedRelu, CudaThresholdedReluGradFunctor, threshold, value); + template void SiluGradKernel(const Context& dev_ctx, const DenseTensor& x, @@ -390,14 +421,14 @@ PD_CUSTOM_KERNEL_REGISTER(relu_grad, phi::ReluGradKernel, float, double, - phi::dtype::float16) {} + phi::float16) {} PD_CUSTOM_KERNEL_REGISTER(relu_double_grad, metax_gpu, ALL_LAYOUT, phi::ReluDoubleGradKernel, float, double, - phi::dtype::float16) {} + phi::float16) {} #else PD_CUSTOM_KERNEL_REGISTER(relu_grad, metax_gpu, @@ -405,16 +436,16 @@ PD_CUSTOM_KERNEL_REGISTER(relu_grad, phi::ReluGradKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} PD_CUSTOM_KERNEL_REGISTER(relu_double_grad, metax_gpu, ALL_LAYOUT, phi::ReluDoubleGradKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} #endif #define PD_REGISTER_ACTIVATION_GRAD_KERNEL(name, func) \ @@ -424,8 +455,8 @@ PD_CUSTOM_KERNEL_REGISTER(relu_double_grad, phi::func, \ float, \ double, \ - phi::dtype::float16, \ - phi::dtype::bfloat16) {} + phi::float16, \ + phi::bfloat16) {} #define PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(name, func) \ PD_CUSTOM_KERNEL_REGISTER(name, \ @@ -434,10 +465,10 @@ PD_CUSTOM_KERNEL_REGISTER(relu_double_grad, phi::func, \ float, \ double, \ - phi::dtype::float16, \ - phi::dtype::bfloat16, \ - phi::dtype::complex, \ - phi::dtype::complex) {} + phi::float16, \ + phi::bfloat16, \ + phi::complex64, \ + phi::complex128) {} PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sin_grad, SinGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(cos_grad, CosGradKernel) @@ -483,10 +514,10 @@ PD_CUSTOM_KERNEL_REGISTER(exp_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_REGISTER_ACTIVATION_GRAD_KERNEL(softshrink_grad, SoftShrinkGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel) @@ -502,10 +533,10 @@ PD_CUSTOM_KERNEL_REGISTER(expm1_grad, phi::Expm1GradKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(square_grad, metax_gpu, @@ -515,10 +546,10 @@ PD_CUSTOM_KERNEL_REGISTER(square_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(square_double_grad, metax_gpu, ALL_LAYOUT, @@ -527,10 +558,10 @@ PD_CUSTOM_KERNEL_REGISTER(square_double_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(sin_double_grad, metax_gpu, @@ -540,10 +571,10 @@ PD_CUSTOM_KERNEL_REGISTER(sin_double_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, 
- phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(sin_triple_grad, metax_gpu, @@ -553,10 +584,10 @@ PD_CUSTOM_KERNEL_REGISTER(sin_triple_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(cos_double_grad, metax_gpu, @@ -566,10 +597,10 @@ PD_CUSTOM_KERNEL_REGISTER(cos_double_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(cos_triple_grad, metax_gpu, @@ -579,10 +610,10 @@ PD_CUSTOM_KERNEL_REGISTER(cos_triple_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softsign_grad, SoftsignGradKernel) @@ -604,10 +635,10 @@ PD_CUSTOM_KERNEL_REGISTER(log_double_grad, phi::LogDoubleGradKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad, HardSwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel) @@ -622,8 +653,8 @@ PD_CUSTOM_KERNEL_REGISTER(rint_grad, int64_t, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} PD_CUSTOM_KERNEL_REGISTER(round_grad, metax_gpu, ALL_LAYOUT, @@ -632,10 +663,10 @@ PD_CUSTOM_KERNEL_REGISTER(round_grad, int64_t, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(pow_grad, metax_gpu, ALL_LAYOUT, @@ -644,10 +675,10 @@ PD_CUSTOM_KERNEL_REGISTER(pow_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(pow_double_grad, metax_gpu, ALL_LAYOUT, @@ -656,10 +687,10 @@ PD_CUSTOM_KERNEL_REGISTER(pow_double_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(pow_triple_grad, metax_gpu, ALL_LAYOUT, @@ -668,10 +699,10 @@ PD_CUSTOM_KERNEL_REGISTER(pow_triple_grad, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(ceil_grad, metax_gpu, ALL_LAYOUT, @@ -683,8 +714,8 @@ PD_CUSTOM_KERNEL_REGISTER(ceil_grad, int16_t, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} PD_CUSTOM_KERNEL_REGISTER(floor_grad, metax_gpu, ALL_LAYOUT, @@ -696,5 +727,5 @@ PD_CUSTOM_KERNEL_REGISTER(floor_grad, int16_t, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} diff --git a/backends/metax_gpu/kernels/cuda_kernels/activation_kernel_register.cu 
b/backends/metax_gpu/kernels/cuda_kernels/activation_kernel_register.cu index 363932cfc28..d91e4afd25e 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/activation_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/activation_kernel_register.cu @@ -14,8 +14,6 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" -#include "paddle/phi/common/bfloat16.h" -#include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/activation_kernel.h" #include "paddle/phi/kernels/full_kernel.h" @@ -75,6 +73,19 @@ void ActivationGPUImpl(const Context& dev_ctx, dev_ctx, x, out, functor); \ } +#define DEFINE_GPU_ACT_KERNEL_WITH_ONE_DOUBLE_ATTRS(name, functor_class, attr) \ + template \ + void name##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + double attr, \ + DenseTensor* out) { \ + funcs::functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr; \ + ActivationGPUImpl>( \ + dev_ctx, x, out, functor); \ + } + #define DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS( \ name, functor_class, attr1, attr2) \ template \ @@ -90,6 +101,7 @@ void ActivationGPUImpl(const Context& dev_ctx, ActivationGPUImpl>( \ dev_ctx, x, out, functor); \ } + #define DEFINE_GPU_ACT_KERNEL_WITH_TWO_DOUBLE_ATTRS( \ name, functor_class, attr1, attr2) \ template \ @@ -105,6 +117,7 @@ void ActivationGPUImpl(const Context& dev_ctx, ActivationGPUImpl>( \ dev_ctx, x, out, functor); \ } + DEFINE_GPU_ACTIVATION_KERNEL(Cos, CudaCosFunctor) DEFINE_GPU_ACTIVATION_KERNEL(Tan, CudaTanFunctor) DEFINE_GPU_ACTIVATION_KERNEL(Acos, CudaAcosFunctor) @@ -138,8 +151,10 @@ DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log1p, CudaLog1pFunctor) DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Exp, CudaExpFunctor) DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Expm1, CudaExpm1Functor) -DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha) -DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LogitCUDA, CudaLogitFunctor, eps) +DEFINE_GPU_ACT_KERNEL_WITH_ONE_DOUBLE_ATTRS(LeakyRelu, + CudaLeakyReluFunctor, + alpha) +DEFINE_GPU_ACT_KERNEL_WITH_ONE_DOUBLE_ATTRS(LogitCUDA, CudaLogitFunctor, eps) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, CudaHardShrinkFunctor, threshold) @@ -286,13 +301,9 @@ void PowKernel(const Context& dev_ctx, } // namespace phi #ifdef PADDLE_WITH_HIP -PD_CUSTOM_KERNEL_REGISTER(relu, - metax_gpu, - ALL_LAYOUT, - phi::ReluKernel, - float, - double, - phi::dtype::float16) {} +PD_CUSTOM_KERNEL_REGISTER( + relu, metax_gpu, ALL_LAYOUT, phi::ReluKernel, float, double, phi::float16) { +} #else PD_CUSTOM_KERNEL_REGISTER(relu, metax_gpu, @@ -300,8 +311,8 @@ PD_CUSTOM_KERNEL_REGISTER(relu, phi::ReluKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} #endif #define PD_REGISTER_ACTIVATION_KERNEL(name, func) \ @@ -311,8 +322,8 @@ PD_CUSTOM_KERNEL_REGISTER(relu, phi::func, \ float, \ double, \ - phi::dtype::float16, \ - phi::dtype::bfloat16) {} + phi::float16, \ + phi::bfloat16) {} #define PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(name, func) \ PD_CUSTOM_KERNEL_REGISTER(name, \ @@ -321,10 +332,10 @@ PD_CUSTOM_KERNEL_REGISTER(relu, phi::func, \ float, \ double, \ - phi::dtype::float16, \ - phi::dtype::bfloat16, \ - phi::dtype::complex, \ - phi::dtype::complex) {} + phi::float16, \ + phi::bfloat16, \ + phi::complex64, \ + phi::complex128) {} PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sin, 
SinKernel) PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(cos, CosKernel) @@ -357,10 +368,10 @@ PD_CUSTOM_KERNEL_REGISTER(exp, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(expm1, metax_gpu, ALL_LAYOUT, @@ -369,10 +380,10 @@ PD_CUSTOM_KERNEL_REGISTER(expm1, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(square, metax_gpu, ALL_LAYOUT, @@ -381,10 +392,10 @@ PD_CUSTOM_KERNEL_REGISTER(square, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel) PD_REGISTER_ACTIVATION_KERNEL(softshrink, SoftShrinkKernel) @@ -409,8 +420,8 @@ PD_CUSTOM_KERNEL_REGISTER(rint, int64_t, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} PD_CUSTOM_KERNEL_REGISTER(round, metax_gpu, ALL_LAYOUT, @@ -419,10 +430,10 @@ PD_CUSTOM_KERNEL_REGISTER(round, int64_t, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(log, metax_gpu, ALL_LAYOUT, @@ -431,10 +442,10 @@ PD_CUSTOM_KERNEL_REGISTER(log, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(log2, metax_gpu, ALL_LAYOUT, @@ -443,10 +454,10 @@ PD_CUSTOM_KERNEL_REGISTER(log2, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(log10, metax_gpu, ALL_LAYOUT, @@ -455,10 +466,10 @@ PD_CUSTOM_KERNEL_REGISTER(log10, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(log1p, metax_gpu, ALL_LAYOUT, @@ -467,10 +478,10 @@ PD_CUSTOM_KERNEL_REGISTER(log1p, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(pow, metax_gpu, ALL_LAYOUT, @@ -479,10 +490,10 @@ PD_CUSTOM_KERNEL_REGISTER(pow, double, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(ceil, metax_gpu, ALL_LAYOUT, @@ -494,8 +505,8 @@ PD_CUSTOM_KERNEL_REGISTER(ceil, int16_t, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} PD_CUSTOM_KERNEL_REGISTER(floor, metax_gpu, ALL_LAYOUT, @@ -507,5 +518,5 @@ PD_CUSTOM_KERNEL_REGISTER(floor, int16_t, int, int64_t, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} diff --git a/backends/metax_gpu/kernels/cuda_kernels/argsort_kernel_register.cu 
b/backends/metax_gpu/kernels/cuda_kernels/argsort_kernel_register.cu index 8fb331eeedd..20ea33834e6 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/argsort_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/argsort_kernel_register.cu @@ -26,11 +26,11 @@ namespace cub = hipcub; #endif -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/primitive/functor_primitives.h" #include "paddle/phi/kernels/transpose_kernel.h" diff --git a/backends/metax_gpu/kernels/cuda_kernels/batch_fc_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/batch_fc_kernel_register.cu index caccb01f71d..0e82304d31d 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/batch_fc_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/batch_fc_kernel_register.cu @@ -14,10 +14,10 @@ #include -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_gemm_epilogue_grad_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/fused_gemm_epilogue_grad_kernel_register.cu new file mode 100644 index 00000000000..2e8d33b964c --- /dev/null +++ b/backends/metax_gpu/kernels/cuda_kernels/fused_gemm_epilogue_grad_kernel_register.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" +#include "paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_grad_kernel.cu" //NOLINT + +PD_CUSTOM_KERNEL_REGISTER(fused_gemm_epilogue_grad, + metax_gpu, + ALL_LAYOUT, + phi::fusion::FusedGemmEpilogueGradKernel, + float, + double, + phi::float16, + phi::bfloat16) {} diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_gemm_epilogue_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/fused_gemm_epilogue_kernel_register.cu new file mode 100644 index 00000000000..9be5794c54f --- /dev/null +++ b/backends/metax_gpu/kernels/cuda_kernels/fused_gemm_epilogue_kernel_register.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" +#include "paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_kernel.cu" //NOLINT + +PD_CUSTOM_KERNEL_REGISTER(fused_gemm_epilogue, + metax_gpu, + ALL_LAYOUT, + phi::fusion::FusedGemmEpilogueKernel, + float, + double, + phi::float16, + phi::bfloat16) {} diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_linear_param_grad_add_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/fused_linear_param_grad_add_kernel_register.cu new file mode 100644 index 00000000000..c88f94625b7 --- /dev/null +++ b/backends/metax_gpu/kernels/cuda_kernels/fused_linear_param_grad_add_kernel_register.cu @@ -0,0 +1,24 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/fusion/gpu/fused_linear_param_grad_add_kernel.cu" //NOLINT +PD_CUSTOM_KERNEL_REGISTER(fused_linear_param_grad_add, + metax_gpu, + ALL_LAYOUT, + phi::fusion::FusedLinearParamGradAdd, + float, + double, + phi::float16, + phi::bfloat16) {} diff --git a/backends/metax_gpu/kernels/cuda_kernels/matmul_grad_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/matmul_grad_kernel_register.cu index f9eef9908ab..bb3b07d24d0 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/matmul_grad_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/matmul_grad_kernel_register.cu @@ -13,9 +13,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "../impl/matmul_grad_kernel_impl.h" #include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h" #include "paddle/phi/kernels/matmul_grad_kernel.h" PD_CUSTOM_KERNEL_REGISTER(matmul_grad, diff --git a/backends/metax_gpu/kernels/cuda_kernels/matmul_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/matmul_kernel_register.cu index 57c3a85b1ea..750cf2a9f36 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/matmul_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/matmul_kernel_register.cu @@ -18,7 +18,7 @@ limitations under the License. 
*/ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" -#include "kernels/impl/matmul_kernel_impl.h" +#include "paddle/phi/kernels/impl/matmul_kernel_impl.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) diff --git a/backends/metax_gpu/kernels/cuda_kernels/multihead_matmul_kernel.cu b/backends/metax_gpu/kernels/cuda_kernels/multihead_matmul_kernel.cu index 151c929e41c..998854140fc 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/multihead_matmul_kernel.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/multihead_matmul_kernel.cu @@ -15,11 +15,11 @@ #include #include -#include "kernels/funcs/blas/blas.h" #include "paddle/common/errors.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/multihead_matmul_functor.h" namespace phi { diff --git a/backends/metax_gpu/kernels/cuda_kernels/pad_grad_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/pad_grad_kernel_register.cu index 38b89fce698..f87f589a424 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/pad_grad_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/pad_grad_kernel_register.cu @@ -20,6 +20,8 @@ PD_CUSTOM_KERNEL_REGISTER(pad_grad, ALL_LAYOUT, phi::PadGradKernel, float, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex) {} + double, + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} diff --git a/backends/metax_gpu/kernels/dynload/cupti_lib_path.h b/backends/metax_gpu/kernels/dynload/cupti_lib_path.h deleted file mode 100644 index 6082fffd60e..00000000000 --- a/backends/metax_gpu/kernels/dynload/cupti_lib_path.h +++ /dev/null @@ -1,19 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#define CUPTI_LIB_PATH "/root/cu-bridge/CUDA_DIR/extras/CUPTI/lib64" diff --git a/backends/metax_gpu/kernels/dynload/dynamic_loader.cc b/backends/metax_gpu/kernels/dynload/dynamic_loader.cc deleted file mode 100644 index a23b7fa2aff..00000000000 --- a/backends/metax_gpu/kernels/dynload/dynamic_loader.cc +++ /dev/null @@ -1,938 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. */ -// #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "kernels/dynload/dynamic_loader.h" - -#include - -#include -#include -#include -#include -// #include "paddle/phi/backends/dynload/cupti_lib_path.h" -#include "./dynload/cupti_lib_path.h" -#include "paddle/phi/common/port.h" -#include "paddle/phi/core/enforce.h" - -#if defined(_WIN32) -#include -#endif - -// TODO(wilber): The phi computing library requires a component to manage flags -// (maybe not use gflags). -#include "glog/logging.h" -#include "paddle/common/flags.h" - -COMMON_DECLARE_string(cudnn_dir); -COMMON_DECLARE_string(cuda_dir); -COMMON_DECLARE_string(cublas_dir); -COMMON_DECLARE_string(nccl_dir); -COMMON_DECLARE_string(cupti_dir); -COMMON_DECLARE_string(tensorrt_dir); -COMMON_DECLARE_string(mklml_dir); -COMMON_DECLARE_string(lapack_dir); -COMMON_DECLARE_string(mkl_dir); -COMMON_DECLARE_string(op_dir); -COMMON_DECLARE_string(cusparselt_dir); -COMMON_DECLARE_string(curand_dir); -COMMON_DECLARE_string(cusolver_dir); -COMMON_DECLARE_string(cusparse_dir); -COMMON_DECLARE_string(win_cuda_bin_dir); -#ifdef PADDLE_WITH_HIP - -PHI_DEFINE_string(miopen_dir, - "", - "Specify path for loading libMIOpen.so. For instance, " - "/opt/rocm/miopen/lib. If empty [default], dlopen " - "will search miopen from LD_LIBRARY_PATH"); - -PHI_DEFINE_string(rocm_dir, - "", - "Specify path for loading rocm library, such as librocblas, " - "libmiopen, libhipsparse. For instance, /opt/rocm/lib. " - "If default, dlopen will search rocm from LD_LIBRARY_PATH"); - -PHI_DEFINE_string(rccl_dir, - "", - "Specify path for loading rccl library, such as librccl.so. " - "For instance, /opt/rocm/rccl/lib. If default, " - "dlopen will search rccl from LD_LIBRARY_PATH"); -#endif - -// #ifdef PADDLE_WITH_FLAGCX -// COMMON_DECLARE_string(flagcx_dir); -// #endif - -// PHI_DEFINE_EXPORTED_string( -// flagcx_dir, // NOLINT -// "", -// "Specify path for loading libflagcx.so. For instance, " -// "For instance, /usr/local/flagcx/lib. 
If default, " -// "dlopen will search flagcx from LD_LIBRARY_PATH"); - -#ifdef PADDLE_WITH_XPU -PD_DEFINE_string(xpti_dir, "", "Specify path for loading libxpti.so."); -#endif - -namespace phi::dynload { - -struct PathNode { - PathNode() = default; - std::string path = ""; -}; - -static constexpr char cupti_lib_path[] = CUPTI_LIB_PATH; // NOLINT - -// NOTE: In order to adapt to the default installation path of cuda -#if defined(_WIN32) && defined(PADDLE_WITH_CUDA) -static constexpr char cuda_lib_path[] = CUDA_TOOLKIT_ROOT_DIR "/bin"; -#else -static constexpr char cuda_lib_path[] = "/usr/local/cuda/lib64"; // NOLINT -#endif - -static PathNode s_py_site_pkg_path; - -#if defined(_WIN32) && defined(PADDLE_WITH_CUDA) -static constexpr char* win_cudnn_lib = "cudnn64_" CUDNN_MAJOR_VERSION ".dll"; -static constexpr char* win_cublas_lib = - "cublas64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;cublas64_" CUDA_VERSION_MAJOR ".dll"; -#if CUDA_VERSION >= 11000 -static constexpr char* win_curand_lib = - "curand64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;curand64_" CUDA_VERSION_MAJOR ".dll;curand64_10.dll"; -static constexpr char* win_nvjpeg_lib = - "nvjpeg64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;nvjpeg64_" CUDA_VERSION_MAJOR ".dll;nvjpeg64_10.dll"; -static constexpr char* win_cusolver_lib = - "cusolver64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;cusolver64_" CUDA_VERSION_MAJOR - ".dll;cusolver64_11.dll;cusolver64_10.dll"; -static constexpr char* win_cusparse_lib = - "cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll;cusparse64_10.dll"; -static constexpr char* win_cufft_lib = - "cufft64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;cufft64_" CUDA_VERSION_MAJOR ".dll;cufft64_11.dll;cufft64_10.dll"; -#else -static constexpr char* win_curand_lib = - "curand64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;curand64_" CUDA_VERSION_MAJOR ".dll"; -static constexpr char* win_nvjpeg_lib = - "nvjpeg64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;nvjpeg64_" CUDA_VERSION_MAJOR ".dll"; -static constexpr char* win_cusolver_lib = - "cusolver64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;cusolver64_" CUDA_VERSION_MAJOR ".dll"; -static constexpr char* win_cusparse_lib = - "cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll"; -static constexpr char* win_cufft_lib = - "cufft64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR - ".dll;cufft64_" CUDA_VERSION_MAJOR ".dll"; -#endif // CUDA_VERSION -#endif - -static inline std::string join(const std::string& part1, - const std::string& part2) { -// directory separator -#if defined(_WIN32) - const char sep = '\\'; -#else - const char sep = '/'; -#endif - if (!part2.empty() && part2.front() == sep) { - return part2; - } - std::string ret; - ret.reserve(part1.size() + part2.size() + 1); - ret = part1; - if (!ret.empty() && ret.back() != sep) { - ret += sep; - } - ret += part2; - return ret; -} - -static inline std::vector split( - const std::string& str, const std::string separator = " ") { - std::vector str_list; - std::string::size_type firstPos = 0; - firstPos = str.find_first_not_of(separator, 0); - std::string::size_type lastPos = 0; - lastPos = str.find_first_of(separator, firstPos); - while (std::string::npos != firstPos && std::string::npos != lastPos) { - str_list.push_back(str.substr(firstPos, lastPos - firstPos)); - firstPos = str.find_first_not_of(separator, lastPos); - lastPos = str.find_first_of(separator, firstPos); - } - if (std::string::npos == 
lastPos) { - str_list.push_back(str.substr(firstPos, lastPos - firstPos)); - } - return str_list; -} - -void SetPaddleLibPath(const std::string& py_site_pkg_path) { - s_py_site_pkg_path.path = py_site_pkg_path; - VLOG(3) << "Set paddle lib path : " << py_site_pkg_path; -} - -static inline void* GetDsoHandleFromSpecificPath(const std::string& spec_path, - const std::string& dso_name, - int dynload_flags) { - void* dso_handle = nullptr; - if (!spec_path.empty()) { - // search xxx.so from custom path - VLOG(3) << "Try to find library: " << dso_name - << " from specific path: " << spec_path; - std::string dso_path = join(spec_path, dso_name); - dso_handle = dlopen(dso_path.c_str(), dynload_flags); - } - return dso_handle; -} - -static inline std::string FindLibAbsolutePath(const std::string& directory, - const std::string& filename) { - DIR* dir = opendir(directory.c_str()); - struct dirent* ent; - - if (dir != nullptr) { - while ((ent = readdir(dir)) != nullptr) { - if (ent->d_type == DT_REG || ent->d_type == DT_LNK) { - if (filename == std::string(ent->d_name)) { - closedir(dir); - return join(directory, ent->d_name); - } - } else if (ent->d_type == DT_DIR) { - if (strcmp(ent->d_name, ".") != 0 && strcmp(ent->d_name, "..") != 0) { - std::string res = - FindLibAbsolutePath(join(directory, ent->d_name) + "/", filename); - if (!res.empty()) { - closedir(dir); - return res; - } - } - } - } - closedir(dir); - } - return ""; -} - -static inline void* GetDsoHandleFromDefaultPath(const std::string& dso_path, - int dynload_flags) { - // default search from LD_LIBRARY_PATH/DYLD_LIBRARY_PATH - // and /usr/local/lib path - void* dso_handle = dlopen(dso_path.c_str(), dynload_flags); - VLOG(3) << "Try to find library: " << dso_path - << " from default system path."; - -// TODO(chenweihang): This path is used to search which libs? -// DYLD_LIBRARY_PATH is disabled after Mac OS 10.11 to -// bring System Integrity Projection (SIP), if dso_handle -// is null, search from default package path in Mac OS. -#if defined(__APPLE__) || defined(__OSX__) -#if defined(__arm__) || defined(__aarch64__) - if (nullptr == dso_handle) { - dso_handle = - dlopen(FindLibAbsolutePath("/opt/homebrew/Cellar/", dso_path).c_str(), - dynload_flags); - } -#else - if (nullptr == dso_handle) { - dso_handle = - dlopen(FindLibAbsolutePath("/usr/local/cuda/lib/", dso_path).c_str(), - dynload_flags); - } -#endif -#endif - - return dso_handle; -} - -/* - * We define three priorities for dynamic library search: - * - * First: Search for path specified by the user - * Second: Search the stheystem default path - * Third: Search for a special path corresponding to - * a specific library to adapt to changes and easy to expand. 
- */ - -static inline void* GetDsoHandleFromSearchPath( - const std::string& config_path, - const std::string& dso_name, - bool throw_on_error = true, - const std::vector& extra_paths = std::vector(), - const std::string& warning_msg = std::string()) { -#if !defined(_WIN32) - int dynload_flags = RTLD_LAZY | RTLD_LOCAL; -#else - int dynload_flags = 0; -#endif // !_WIN32 -#if defined(_WIN32) - std::vector cuda_bin_search_path = { - L"cublas", - L"cuda_nvrtc", - L"cuda_runtime", - L"cudnn", - L"cufft", - L"curand", - L"cusolver", - L"cusparse", - L"nvjitlink", - }; - for (auto search_path : cuda_bin_search_path) { - std::wstring_convert> converter; - std::wstring win_path_wstring = - converter.from_bytes(FLAGS_win_cuda_bin_dir); - search_path = win_path_wstring + L"\\" + search_path + L"\\bin"; - AddDllDirectory(search_path.c_str()); - } -#endif - std::vector dso_names = split(dso_name, ";"); - void* dso_handle = nullptr; - for (auto const& dso : dso_names) { - // 1. search in user config path by FLAGS - dso_handle = GetDsoHandleFromSpecificPath(config_path, dso, dynload_flags); - // 2. search in system default path - if (nullptr == dso_handle) { - dso_handle = GetDsoHandleFromDefaultPath(dso, dynload_flags); - } - // 3. search in extra paths - if (nullptr == dso_handle) { - for (auto const& path : extra_paths) { - VLOG(3) << "extra_paths: " << path; - dso_handle = GetDsoHandleFromSpecificPath(path, dso, dynload_flags); - } - } - if (nullptr != dso_handle) break; - } - - // 4. [If Failed for All dso_names] logging warning if exists - if (nullptr == dso_handle && !warning_msg.empty()) { - LOG(WARNING) << warning_msg; - } - - // 5. [If Failed for All dso_names] logging or throw error info - if (nullptr == dso_handle) { - auto error_msg = - "The third-party dynamic library (%s) that Paddle depends on is not " - "configured correctly. (error code is %s)\n" - " Suggestions:\n" - " 1. Check if the third-party dynamic library (e.g. CUDA, CUDNN) " - "is installed correctly and its version is matched with paddlepaddle " - "you installed.\n" - " 2. 
Configure third-party dynamic library environment variables as " - "follows:\n" - " - Linux: set LD_LIBRARY_PATH by `export LD_LIBRARY_PATH=...`\n" - " - Windows: set PATH by `set PATH=XXX;%%PATH%%`\n" - " - Mac: set DYLD_LIBRARY_PATH by `export DYLD_LIBRARY_PATH=...` " - "[Note: After Mac OS 10.11, using the DYLD_LIBRARY_PATH is " - "impossible unless System Integrity Protection (SIP) is disabled.]"; -#if !defined(_WIN32) - auto errorno = dlerror(); -#else - auto errorno = GetLastError(); -#endif // !_WIN32 - if (throw_on_error) { - // NOTE: Special error report case, no need to change its format - PADDLE_THROW( - common::errors::PreconditionNotMet(error_msg, dso_name, errorno)); - } else { - LOG(WARNING) << paddle::string::Sprintf(error_msg, dso_name, errorno); - } - } - - return dso_handle; -} - -void* GetCublasDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.dylib"); -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "cublas64_11.dll"); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cublas_lib, true, {cuda_lib_path}); -#endif - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "cublas64_12.dll"); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cublas_lib, true, {cuda_lib_path}); -#endif - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 13, paddle " - "temporarily no longer supports"); - return nullptr; - } - -#elif defined(__linux__) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublas.so.11"); -#else - // VLOG(0) << "dynload:libmcblas.so: CUDA_VERSION=11000-12000 start" ; - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcblas.so"); - // VLOG(0) << "dynload:libmcblas.so: CUDA_VERSION=11000-12000 end" ; -#endif - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublas.so.12"); -#else - // VLOG(0) << "dynload:libmcblas.so: CUDA_VERSION=12000-13000 start" ; - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcblas.so"); - // VLOG(0) << "dynload:libmcblas.so: CUDA_VERSION=12000-13000 end" ; -#endif - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 13, paddle " - "temporarily no longer supports"); - return nullptr; - } -#elif defined(PADDLE_WITH_HIP) - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "librocblas.so"); -#else - // VLOG(0) << "dynload:libmcblas.so: CUDA_VERSION=else start" ; - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcblas.so"); - // VLOG(0) << "dynload:libmcblas.so: CUDA_VERSION=else end" ; -// return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcblas.so"); -#endif -} - -void* GetCublasLtDsoHandle() { -// APIs available after CUDA 10.1 -#if defined(__linux__) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublasLt.so.11"); -#else - // return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublasLt.so"); - return 
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcblasLt.so"); -#endif - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublasLt.so.12"); -#else - // return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublasLt.so"); - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcblasLt.so"); -#endif - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 13, paddle " - "temporarily no longer supports"); - return nullptr; - } -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "cublasLt64_11.dll"); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cublas_lib, true, {cuda_lib_path}); -#endif - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "cublasLt64_12.dll"); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cublas_lib, true, {cuda_lib_path}); -#endif - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 12, paddle " - "temporarily no longer supports"); - return nullptr; - } -#elif !defined(__linux__) && defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10010 - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublasLt.so"); -#elif defined(PADDLE_WITH_HIP) - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhipblaslt.so"); -#else - std::string warning_msg( - "Your CUDA_VERSION less 10.1, not support CublasLt. " - "If you want to use CublasLt, please upgrade CUDA and rebuild " - "PaddlePaddle."); - return nullptr; -#endif -} - -void* GetCUDNNDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - std::string mac_warn_meg( - "Note: [Recommend] copy cudnn into /usr/local/cuda/ \n " - "For instance, sudo tar -xzf " - "cudnn-7.5-osx-x64-v5.0-ga.tgz -C /usr/local \n sudo " - "chmod a+r /usr/local/cuda/include/cudnn.h " - "/usr/local/cuda/lib/libcudnn*"); - return GetDsoHandleFromSearchPath( - FLAGS_cudnn_dir, "libcudnn.dylib", false, {}, mac_warn_meg); -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) - std::string win_warn_meg( - "Note: [Recommend] copy cudnn into CUDA installation directory. 
\n " - "For instance, download cudnn-10.0-windows10-x64-v7.6.5.32.zip from " - "NVIDIA's official website, \n" - "then, unzip it and copy it into C:\\Program Files\\NVIDIA GPU Computing " - "Toolkit\\CUDA\\v10.0\n" - "You should do this according to your CUDA installation directory and " - "CUDNN version."); - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12030) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, "cudnn64_8.dll", true, {cuda_lib_path}, win_warn_meg); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cudnn_lib, true, {cuda_lib_path}, win_warn_meg); -#endif - } else if (CUDA_VERSION >= 12030) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, "cudnn64_9.dll", true, {cuda_lib_path}, win_warn_meg); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cudnn_lib, true, {cuda_lib_path}, win_warn_meg); -#endif - } -#elif defined(PADDLE_WITH_HIP) - return GetDsoHandleFromSearchPath(FLAGS_miopen_dir, "libMIOpen.so", false); -#else -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - if (CUDA_VERSION >= 12030) { - return GetDsoHandleFromSearchPath( - FLAGS_cudnn_dir, "libcudnn.so.9", false, {cuda_lib_path}); - } else { - return GetDsoHandleFromSearchPath( - FLAGS_cudnn_dir, "libcudnn.so.8", false, {cuda_lib_path}); - } -#else - return GetDsoHandleFromSearchPath( - FLAGS_cudnn_dir, "libmcdnn.so", false, {cuda_lib_path}); -#endif -#endif -} - -void* GetCUPTIDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath( - FLAGS_cupti_dir, "libcupti.dylib", false, {cupti_lib_path}); -#elif defined(__linux__) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath( - FLAGS_cupti_dir, "libcupti.so.11.8", false, {cupti_lib_path}); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cupti_dir, "libmcpti.so", false, {cupti_lib_path}); -#endif - - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath( - FLAGS_cupti_dir, "libcupti.so.12", false, {cupti_lib_path}); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cupti_dir, "libmcpti.so", false, {cupti_lib_path}); -#endif - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 13, paddle " - "temporarily no longer supports"); - return nullptr; - } -#else - return GetDsoHandleFromSearchPath( - FLAGS_cupti_dir, "libmcpti.so", false, {cupti_lib_path}); -#endif -} - -void* GetCurandDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.dylib"); -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, "curand64_10.dll", true, {cuda_lib_path}); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_curand_lib, true, {cuda_lib_path}); -#endif -#elif defined(PADDLE_WITH_HIP) - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhiprand.so"); -#else -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_curand_dir, "libcurand.so.10"); -#else - return GetDsoHandleFromSearchPath(FLAGS_curand_dir, "libcurand.so"); -#endif - -#endif -} - -#ifdef PADDLE_WITH_HIP -void* GetROCFFTDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "librocfft.dylib"); -#else - 
return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhipfft.so"); -#endif -} -#endif - -void* GetNvjpegDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvjpeg.dylib"); -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_nvjpeg_lib, true, {cuda_lib_path}); -#else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvjpeg.so"); -#endif -} - -void* GetCusolverDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcusolver.dylib"); -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, "cusolver64_11.dll", true, {cuda_lib_path}); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cusolver_lib, true, {cuda_lib_path}); -#endif -#elif defined(PADDLE_WITH_HIP) - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "librocsolver.so"); -#else -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcusolver.so.11"); -#else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcsolver.so"); -#endif -#endif -} - -void* GetCusparseDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcusparse.dylib"); -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "cusparse64_11.dll"); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cusparse_lib, true, {cuda_lib_path}); -#endif - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "cusparse64_12.dll"); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cusparse_lib, true, {cuda_lib_path}); -#endif - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 13, paddle " - "temporarily no longer supports"); - return nullptr; - } -#elif defined(__linux__) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cusparse_dir, "libcusparse.so.11"); -#else - return GetDsoHandleFromSearchPath(FLAGS_cusparse_dir, "libmcsparse.so"); -#endif - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cusparse_dir, "libcusparse.so.12"); -#else - return GetDsoHandleFromSearchPath(FLAGS_cusparse_dir, "libmcsparse.so"); -#endif - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 12, paddle " - "temporarily no longer."); - return nullptr; - } -#elif defined(PADDLE_WITH_HIP) - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "librocsparse.so"); -#else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcsparse.so"); -#endif -} - -void* GetNVRTCDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvrtc.dylib", false); -#elif defined(PADDLE_WITH_HIP) - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libamdhip64.so", false); -#else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcruntime.so", false); -#endif -} - -void* GetCUDADsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return 
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.dylib", false); -#elif defined(PADDLE_WITH_HIP) - return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libamdhip64.so", false); -#elif defined(_WIN32) - char system32_dir[MAX_PATH]; - GetSystemDirectory(system32_dir, MAX_PATH); - return GetDsoHandleFromSearchPath(system32_dir, "nvcuda.dll"); -#else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libmcruntime.so", false); -#endif -} - -void* GetWarpCTCDsoHandle() { - std::string warpctc_dir = ""; - if (!s_py_site_pkg_path.path.empty()) { - warpctc_dir = s_py_site_pkg_path.path; - } -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(warpctc_dir, "libwarpctc.dylib"); -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(warpctc_dir, "warpctc.dll"); -#else - return GetDsoHandleFromSearchPath(warpctc_dir, "libwarpctc.so"); -#endif -} - -void* GetWarpRNNTDsoHandle() { - std::string warprnnt_dir = ""; - if (!s_py_site_pkg_path.path.empty()) { - warprnnt_dir = s_py_site_pkg_path.path; - } -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(warprnnt_dir, "libwarprnnt.dylib"); -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(warprnnt_dir, "warprnnt.dll"); -#else - return GetDsoHandleFromSearchPath(warprnnt_dir, "libwarprnnt.so"); -#endif -} - -void* GetFlashAttnDsoHandle() { - std::string flashattn_dir = ""; - if (!s_py_site_pkg_path.path.empty()) { - flashattn_dir = s_py_site_pkg_path.path; - } -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattn.dylib"); -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(flashattn_dir, "flashattn.dll"); -#else - return GetDsoHandleFromSearchPath(flashattn_dir, "libmcFlashAttn.so"); -#endif -} - -void* GetFlashAttnV3DsoHandle() { - std::string flashattn_dir = ""; - if (!s_py_site_pkg_path.path.empty()) { - flashattn_dir = s_py_site_pkg_path.path; - } -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattnv3.dylib"); -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(flashattn_dir, "flashattnv3.dll"); -#else - return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattnv3.so"); -#endif -} - -void* GetAfsApiDsoHandle() { - std::string afsapi_dir = ""; - if (!s_py_site_pkg_path.path.empty()) { - afsapi_dir = s_py_site_pkg_path.path; - } -#if defined(__APPLE__) || defined(__OSX__) || defined(_WIN32) - return NULL; -#else - return GetDsoHandleFromSearchPath(afsapi_dir, "libafs-api-so.so"); -#endif -} - -void* GetNCCLDsoHandle() { -#ifdef PADDLE_WITH_HIP - std::string warning_msg( - "You may need to install 'rccl' from ROCM official website: " - "https://rocmdocs.amd.com/en/latest/Installation_Guide/" - "Installation-Guide.html before install PaddlePaddle."); -#else - std::string warning_msg( - "You may need to install 'nccl2' from NVIDIA official website: " - "https://developer.nvidia.com/nccl/nccl-download " - "before install PaddlePaddle."); -#endif - -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath( - FLAGS_nccl_dir, "libnccl.dylib", true, {}, warning_msg); -#elif defined(PADDLE_WITH_HIP) && defined(PADDLE_WITH_RCCL) - return GetDsoHandleFromSearchPath( - FLAGS_rccl_dir, "librccl.so", true, {}, warning_msg); -#else -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath( - FLAGS_nccl_dir, "libnccl.so;libnccl.so.2", true, {}, warning_msg); -#else - return GetDsoHandleFromSearchPath( - FLAGS_nccl_dir, "libmccl.so", true, 
{}, warning_msg); -#endif - -#endif -} - -// void* GetFLAGCXDsoHandle() { -// #ifdef PADDLE_WITH_FLAGCX -// return GetDsoHandleFromSearchPath(FLAGS_flagcx_dir, "libflagcx.so"); -// #else -// return nullptr; -// #endif -// } - -void* GetTensorRtDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_tensorrt_dir, "libnvinfer.dylib"); -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(FLAGS_mklml_dir, "nvinfer.dll"); -#else - return GetDsoHandleFromSearchPath(FLAGS_tensorrt_dir, "libnvinfer.so"); -#endif -} - -void* GetMKLMLDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_mklml_dir, "libmklml_intel.dylib"); -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(FLAGS_mklml_dir, "mklml.dll"); -#else - return GetDsoHandleFromSearchPath(FLAGS_mklml_dir, "libmklml_intel.so"); -#endif -} - -void* GetLAPACKDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) -#if defined(__arm__) || defined(__aarch64__) - return GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapack.dylib"); -#else - return GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapack.3.dylib"); -#endif -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapack.dll"); -#else - return GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapack.so.3"); -#endif -} - -void* GetOpDsoHandle(const std::string& dso_name) { - return GetDsoHandleFromSearchPath(FLAGS_op_dir, dso_name); -} - -void* GetNvtxDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - PADDLE_THROW(common::errors::Unimplemented("Nvtx do not support Apple.")); -#elif defined(_WIN32) - PADDLE_THROW(common::errors::Unimplemented("Nvtx do not support Windows.")); -#elif !defined(PADDLE_WITH_CUDA) - PADDLE_THROW( - common::errors::Unimplemented("Nvtx do not support without CUDA.")); -#else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvToolsExt.so"); -#endif -} - -void* GetCUFFTDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.dylib"); -#elif defined(__linux__) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.so.10"); -#else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.so"); -#endif - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.so.11"); - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 13, paddle " - "temporarily no longer."); - return nullptr; - } -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) - if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "cufft64_10.dll"); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cufft_lib, true, {cuda_lib_path}); -#endif - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { -#ifdef PADDLE_WITH_PIP_CUDA_LIBRARIES - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "cufft64_11.dll"); -#else - return GetDsoHandleFromSearchPath( - FLAGS_cuda_dir, win_cufft_lib, true, {cuda_lib_path}); -#endif - } else { - std::string warning_msg( - "Your CUDA_VERSION is less than 11 or greater than 13, paddle " - "temporarily no longer supports"); - return nullptr; - } -#else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.so"); -#endif -} - 
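Each of the deleted loader entry points above follows the same pattern: pick a platform-specific shared-object name (on this backend the CUDA names map to their MACA counterparts, e.g. libmccl.so, libmcFlashAttn.so, libmcruntime.so) and resolve it through GetDsoHandleFromSearchPath, falling back to the system loader path. A minimal sketch of that resolution step, assuming ordinary POSIX dlopen semantics; OpenFromSearchPath is a hypothetical stand-in and is not part of this patch:

#include <dlfcn.h>
#include <string>

// Simplified illustration of the search-path lookup the deleted functions rely on.
static void* OpenFromSearchPath(const std::string& dir, const std::string& soname) {
  if (!dir.empty()) {
    // Try the explicitly configured directory first (e.g. FLAGS_cuda_dir).
    std::string candidate = dir + "/" + soname;
    if (void* handle = dlopen(candidate.c_str(), RTLD_LAZY | RTLD_GLOBAL)) {
      return handle;
    }
  }
  // Fall back to the dynamic loader's default search path (LD_LIBRARY_PATH, ldconfig).
  return dlopen(soname.c_str(), RTLD_LAZY | RTLD_GLOBAL);
}

With the backend-local dynamic_loader.cc/.h removed below, these lookups are expected to go through Paddle's upstream phi::dynload implementation instead.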
-void* GetMKLRTDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "libmkl_rt.dylib"); -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "mkl_rt.dll"); -#else - return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "libmkl_rt.so"); -#endif -} - -void* GetCusparseLtDsoHandle() { -// APIs available after CUDA 11.2 -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11020 && 0 - return GetDsoHandleFromSearchPath(FLAGS_cusparselt_dir, "libcusparseLt.so"); -#else - std::string warning_msg( - "Your CUDA_VERSION less 11.2, not support cusparseLt. " - "If you want to use cusparseLt, please upgrade CUDA and rebuild " - "PaddlePaddle."); - return nullptr; -#endif -} - -void* GetXPTIDsoHandle() { -#ifdef PADDLE_WITH_XPTI - return GetDsoHandleFromSearchPath(FLAGS_xpti_dir, "libxpti.so"); -#else - return nullptr; -#endif -} -} // namespace phi::dynload diff --git a/backends/metax_gpu/kernels/dynload/dynamic_loader.h b/backends/metax_gpu/kernels/dynload/dynamic_loader.h deleted file mode 100644 index a5d3d0ff76c..00000000000 --- a/backends/metax_gpu/kernels/dynload/dynamic_loader.h +++ /dev/null @@ -1,61 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include - -#include "paddle/utils/test_macros.h" -namespace phi { -namespace dynload { - -#ifndef _WIN32 -#define DECLARE_TYPE(__name, ...) decltype(__name(__VA_ARGS__)) -#else -#define DECLARE_TYPE(__name, ...) 
decltype(auto) -#endif - -void* GetCublasDsoHandle(); -void* GetCublasLtDsoHandle(); -TEST_API void* GetCUDNNDsoHandle(); -void* GetCUPTIDsoHandle(); -void* GetCurandDsoHandle(); -void* GetNvjpegDsoHandle(); -void* GetCusolverDsoHandle(); -void* GetCusparseDsoHandle(); -void* GetNVRTCDsoHandle(); -void* GetCUDADsoHandle(); -void* GetWarpCTCDsoHandle(); -void* GetWarpRNNTDsoHandle(); -void* GetFlashAttnDsoHandle(); -void* GetFlashAttnV3DsoHandle(); -void* GetNCCLDsoHandle(); -// void* GetFLAGCXDsoHandle(); -void* GetTensorRtDsoHandle(); -void* GetMKLMLDsoHandle(); -void* GetLAPACKDsoHandle(); -void* GetOpDsoHandle(const std::string& dso_name); -void* GetNvtxDsoHandle(); -void* GetCUFFTDsoHandle(); -void* GetMKLRTDsoHandle(); -void* GetROCFFTDsoHandle(); -void* GetCusparseLtDsoHandle(); -void* GetXPTIDsoHandle(); -void* GetAfsApiDsoHandle(); - -void SetPaddleLibPath(const std::string&); - -} // namespace dynload -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/affine_grid_utils.h b/backends/metax_gpu/kernels/funcs/affine_grid_utils.h index c137d9ad468..b973d75a9be 100644 --- a/backends/metax_gpu/kernels/funcs/affine_grid_utils.h +++ b/backends/metax_gpu/kernels/funcs/affine_grid_utils.h @@ -14,8 +14,8 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/backends/metax_gpu/kernels/funcs/blas/blas.cc b/backends/metax_gpu/kernels/funcs/blas/blas.cc deleted file mode 100644 index 098a0400552..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/blas.cc +++ /dev/null @@ -1,59 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
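The affine_grid_utils.h hunk above shows the pattern applied throughout this patch: the backend-local kernels/funcs/blas copy (deleted below) is dropped and call sites include the upstream paddle/phi/kernels/funcs/blas/blas.h instead, so only the include line changes. A typical call site, sketched against the GetBlas/MatMul interface declared in the header being removed (dev_ctx and the tensors are assumed to be supplied by the surrounding kernel):

#include "paddle/phi/kernels/funcs/blas/blas.h"

template <typename T>
void MatMulExample(const phi::GPUContext& dev_ctx,
                   const phi::DenseTensor& x,
                   const phi::DenseTensor& y,
                   phi::DenseTensor* out) {
  auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx);
  // out = 1.0 * x * y + 0.0 * out, neither operand transposed.
  blas.MatMul(x, false, y, false, static_cast<T>(1.0), out, static_cast<T>(0.0));
}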
- -// clang-format off -#include "funcs/blas/blas.h" // NOLINT -#include "paddle/phi/core/enforce.h" -// clang-format on -namespace phi { -namespace funcs { -MatDescriptor CreateMatrixDescriptor(const DDim &tensor_dim, - int num_flatten_cols, - bool trans) { - PADDLE_ENFORCE_GT( - tensor_dim.size(), - 1, - phi::errors::InvalidArgument("The tensor dim size should be greater " - "than 1, but reveived dim size is %d", - tensor_dim.size())); - MatDescriptor retv; - if (num_flatten_cols > 1) { - auto flatten_dim = common::flatten_to_2d(tensor_dim, num_flatten_cols); - retv.height_ = flatten_dim[0]; - retv.width_ = flatten_dim[1]; - } else { - if (tensor_dim.size() == 2) { - retv.height_ = tensor_dim[0]; - retv.width_ = tensor_dim[1]; - } else { - auto dim_vec = common::vectorize(tensor_dim); - retv.batch_size_ = 1; - for (size_t i = 0; i < dim_vec.size() - 2; ++i) { - retv.batch_size_ *= dim_vec[i]; - } - retv.height_ = dim_vec[dim_vec.size() - 2]; - retv.width_ = dim_vec[dim_vec.size() - 1]; - retv.stride_ = retv.height_ * retv.width_; - } - } - if (trans) { - std::swap(retv.width_, retv.height_); - } - retv.trans_ = trans; - return retv; -} -} // namespace funcs -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/blas.h b/backends/metax_gpu/kernels/funcs/blas/blas.h deleted file mode 100644 index 75ea8c921e2..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/blas.h +++ /dev/null @@ -1,631 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" - -#ifdef PADDLE_WITH_MKLML -#include "paddle/phi/backends/dynload/mklml.h" -#endif - -#ifdef PADDLE_WITH_LIBXSMM -#include -#endif - -#if defined(PADDLE_USE_OPENBLAS) || defined(PADDLE_USE_REFERENCE_CBLAS) -#include -#endif -// #include "paddle/phi/core/enforce_metax.h" -namespace phi { -namespace funcs { - -/** - * Matrix Descriptor of a memory buffer. - * - * It is used for Blas::MatMul. MatMul operator can be batched. - * if Mat A is [BatchSize, H, W], Mat B is [BatchSize, H, W]. It will be a - * `batch_size` times of GEMM. The batched GEMM could be faster base on the - * implementation of the blas library. The batch size could be zero. If any - * matrix of `matmul` has a batch size, there will be a batched GEMM, too. e.g., - * Mat A is [BatchSize, H1, W2], and Mat B [H2, W2], The result matrix wil be - * [BatchSize, H1, W2] - * - * The boolean flag, `trans`, describe the memory is the transpose of matrix or - * not. If the trans is true, the last two dims of matrix are transposed. The - * memory layout of the matrix is [Width, Height] or [BatchSize, Width, Height]. - * - * The MatDescriptor is not only the dimension or shape of a matrix, it also - * contains the layout, stride of matrix. It is clearer to have a structure than - * reuse `DDim`. 
- */ -struct MatDescriptor { - int64_t height_; - int64_t width_; - int64_t stride_{0}; - int64_t batch_size_{0}; - bool trans_; -}; - -/** - * Create Matrix Descriptor from a tensor dim, num_flatten_cols, and transpose - * flag - * - * @param tensor_dim: The dimension of the tensor. The rank of this dimension - * must larger than 1. - * - * @param num_flatten_cols: Reshape a tensor to a matrix. The matrix's first - * dimension(column length) will be the product of tensor's first `num_col_dims` - * dimensions. If num_flatten_cols is zero, the first N-2 dimension will be the - * batch_size of descriptor. - * - * @param trans: True if the matrix is transposed. - */ -extern MatDescriptor CreateMatrixDescriptor(const DDim& tensor_dim, - int num_flatten_cols, - bool trans); - -template -class Blas { - public: - explicit Blas(const DeviceContext& context) : dev_ctx_(context) {} - - template - void GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - T alpha, - const T* A, - const T* B, - T beta, - T* C) const; - - template - void GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - U alpha, - const T* A, - const T* B, - U beta, - T* C) const; - - template - void GEMM(bool transA, - bool transB, - int M, - int N, - int K, - T alpha, - const T* A, - int lda, - const T* B, - int ldb, - T beta, - T* C, - int ldc) const; - - template - void GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - T alpha, - const T* A, - int lda, - const T* B, - int ldb, - T beta, - T* C, - int ldc) const; - -#ifdef PADDLE_WITH_MKLML // @{ Group MKLML: class Blas - template - T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, - const int M, - const int N, - const int K) const; - - template - void GEMM_PACK(const CBLAS_IDENTIFIER id, - const CBLAS_TRANSPOSE trans, - int M, - int N, - int K, - const T alpha, - const T* src, - const int ld, - T* dst) const; - - template - void GEMM_COMPUTE(int transA, - int transB, - int M, - int N, - int K, - const T* A, - const int lda, - const T* B, - const int ldb, - T beta, - T* C, - const int ldc) const; - - template - void GEMM_FREE(T* data) const; - - template - void CSRMM(const char* transa, - const int* m, - const int* n, - const int* k, - const T* alpha, - const char* matdescra, - const T* val, - const int* indx, - const int* pntrb, - const int* pntre, - const T* b, - const int* ldb, - const T* beta, - T* c, - const int* ldc) const; - -#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) - template - void MatMulWithHead(const phi::DenseTensor& mat_a, - const MatDescriptor& dim_a, - const phi::DenseTensor& mat_b, - const MatDescriptor& dim_b, - T alpha, - int head_number, - phi::DenseTensor* mat_out, - T beta, - bool mat_y_split_vertical) const; -#endif -#endif // @} End Group MKLML: class Blas - - template - void MatMul(const int M, - const int N, - const int K, - const T* A, - const T* B, - T* C) const; - - template - void MatMul(const phi::DenseTensor& mat_a, - bool trans_a, - const phi::DenseTensor& mat_b, - bool trans_b, - T alpha, - phi::DenseTensor* mat_out, - T beta) const; - - template - void MatMul(const phi::DenseTensor& mat_a, - bool trans_a, - const phi::DenseTensor& mat_b, - bool trans_b, - phi::DenseTensor* mat_out) const { - MatMul(mat_a, - trans_a, - mat_b, - trans_b, - static_cast(1.0), - mat_out, - static_cast(0.0)); - } - - template - void MatMul(const phi::DenseTensor& mat_a, - const phi::DenseTensor& mat_b, - phi::DenseTensor* mat_out) const { - 
this->template MatMul(mat_a, false, mat_b, false, mat_out); - } - - template - void AXPY(int n, T alpha, const T* x, T* y) const; - - template - void VADD(int n, const T* x, const T* y, T* z) const; - - template - void VSUB(int n, const T* x, const T* y, T* z) const; - - template - void VMUL(int n, const T* x, const T* y, T* z) const; - - template - void VDIV(int n, const T* x, const T* y, T* z) const; - - template - void VCOPY(int n, const T* x, T* y) const; - - template - void VEXP(int n, const T* x, T* y) const; - - template - void VSQUARE(int n, const T* x, T* y) const; - - template - void VPOW(int n, const T* x, T alpha, T* y) const; - - template - void GEMV(bool trans_a, - int M, - int N, - T alpha, - const T* A, - const T* B, - T beta, - T* C) const; - - template - T DOT(int n, const T* x, const T* y) const; - - template - void CUDOT( - int n, const T* x, int incx, const T* y, int incy, T* result) const; - template - void SCAL(int n, const T a, T* x) const; - - template - T ASUM(int n, T* x, int inc) const; - - template - void BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - T alpha, - const T* A, - const T* B, - T beta, - T* C, - int64_t batchCount, - int64_t strideA, - int64_t strideB) const; - - template - void BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - U alpha, - const T* A, - const T* B, - U beta, - T* C, - int64_t batchCount, - int64_t strideA, - int64_t strideB) const; - - template - void BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - T alpha, - const T** A, - const T** B, - T beta, - T** C, - int batchCount) const; - -#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ - !defined(PADDLE_WITH_HIP) - template - void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int W1, - int H1, - int W2, - int H2, - T alpha, - const T* A, - const T* B, - T beta, - T* C, - int batchCount, - int64_t strideA, - int64_t strideB, - int64_t head_number, - bool split_b_vertical) const; -#endif - - template - void MatMul(const phi::DenseTensor& mat_a, - const MatDescriptor& dim_a, - const phi::DenseTensor& mat_b, - const MatDescriptor& dim_b, - T alpha, - phi::DenseTensor* mat_out, - T beta) const; - - template - void MatMul(const T* mat_a, - const MatDescriptor& dim_a, - const T* mat_b, - const MatDescriptor& dim_b, - T alpha, - T* mat_out, - T beta) const; - - template - void VINV(int n, const T* a, T* y) const; - - template - void VMERF(int n, const T* a, T* y, int64_t mode) const; - - template - void TRSM(CBLAS_SIDE side, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, - int M, - int N, - T alpha, - const T* A, - int lda, - T* B, - int ldb) const; - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - template - void BatchedGETRF(int n, T** a, int* ipiv, int* info, int batch_size) const; - - template - void BatchedGETRI(int n, - const T** a, - const int* ipiv, - T** a_inv, - int* info, - int batch_size) const; - - template - void BatchedMatInv( - int n, const T** a, T** a_inv, int* info, int batch_size) const; - - // cuBlas solve - template - void BatchedGETRS(CBLAS_TRANSPOSE trans, - int n, - int nrhs, - const T** a, - int lda, - int* ipiv, - T** b, - int ldb, - int* info, - int batch_size) const; - - // cuBlas triangular_solve - template - void BatchedTRSM(CBLAS_SIDE side, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, - int M, - int N, - T alpha, - const T** 
a, - int lda, - T** b, - int ldb, - int batch_size) const; -#endif - - private: - const DeviceContext& dev_ctx_; -}; - -template -class BlasT : private Blas { - public: - using Blas::Blas; - - template - void GEMM(ARGS... args) const { - Base()->template GEMM(args...); - } - -#ifdef PADDLE_WITH_MKLML // @{ Group MKLML: class BlasT - template - T* GEMM_ALLOC(ARGS... args) const { - return Base()->template GEMM_ALLOC(args...); - } - - template - void GEMM_PACK(ARGS... args) const { - Base()->template GEMM_PACK(args...); - } - - template - void GEMM_COMPUTE(ARGS... args) const { - Base()->template GEMM_COMPUTE(args...); - } - - template - void GEMM_FREE(ARGS... args) const { - Base()->template GEMM_FREE(args...); - } - - template - void CSRMM(ARGS... args) const { - Base()->template CSRMM(args...); - } - -#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) - template - void MatMulWithHead(ARGS... args) const { - Base()->template MatMulWithHead(args...); - } -#endif -#endif // @} End Group MKLML: class BlasT - - template - void MatMul(ARGS... args) const { - Base()->template MatMul(args...); - } - - template - void AXPY(ARGS... args) const { - Base()->template AXPY(args...); - } - - template - void VADD(ARGS... args) const { - Base()->template VADD(args...); - } - - template - void VSUB(ARGS... args) const { - Base()->template VSUB(args...); - } - - template - void VMUL(ARGS... args) const { - Base()->template VMUL(args...); - } - - template - void VDIV(ARGS... args) const { - Base()->template VDIV(args...); - } - - template - void VCOPY(ARGS... args) const { - Base()->template VCOPY(args...); - } - - template - void VEXP(ARGS... args) const { - Base()->template VEXP(args...); - } - - template - void VSQUARE(ARGS... args) const { - Base()->template VSQUARE(args...); - } - - template - void VPOW(ARGS... args) const { - Base()->template VPOW(args...); - } - - template - void GEMV(ARGS... args) const { - Base()->template GEMV(args...); - } - - template - T DOT(ARGS... args) const { - return Base()->template DOT(args...); - } - template - void CUDOT(ARGS... args) const { - Base()->template CUDOT(args...); - } - template - void SCAL(ARGS... args) const { - Base()->template SCAL(args...); - } - - template - T ASUM(ARGS... args) const { - return Base()->template ASUM(args...); - } - - template - void BatchedGEMM(ARGS... args) const { - Base()->template BatchedGEMM(args...); - } - - template - void VINV(ARGS... args) const { - Base()->template VINV(args...); - } - - template - void VMERF(ARGS... args) const { - Base()->template VMERF(args...); - } - - template - void TRSM(ARGS... args) const { - Base()->template TRSM(args...); - } - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - template - void BatchedGETRF(ARGS... args) const { - Base()->template BatchedGETRF(args...); - } - - template - void BatchedGETRI(ARGS... args) const { - Base()->template BatchedGETRI(args...); - } - - template - void BatchedMatInv(ARGS... args) const { - Base()->template BatchedMatInv(args...); - } - - // solve - template - void BatchedGETRS(ARGS... args) const { - Base()->template BatchedGETRS(args...); - } - - // triangular_solve - template - void BatchedTRSM(ARGS... 
args) const { - Base()->template BatchedTRSM(args...); - } -#endif - - private: - const Blas* Base() const { - return static_cast*>(this); - } -}; - -template -inline BlasT GetBlas(const DeviceContext& dev_ctx) { - return BlasT(dev_ctx); -} - -} // namespace funcs -} // namespace phi -// clang-format off -#include "./blas_impl.h" -#ifdef PADDLE_WITH_CUDA -#include "./blas_impl.cu.h" -#endif -#ifdef PADDLE_WITH_HIP -#include "paddle/phi/kernels/funcs/blas/blas_impl.hip.h" -#endif -// clang-format on diff --git a/backends/metax_gpu/kernels/funcs/blas/blas_impl.cu.h b/backends/metax_gpu/kernels/funcs/blas/blas_impl.cu.h deleted file mode 100644 index ae4baa52613..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/blas_impl.cu.h +++ /dev/null @@ -1,3027 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#if defined(__NVCC__) -#include -#endif -#include "./cublas.h" -#include "glog/logging.h" -#include "paddle/common/flags.h" -#include "paddle/phi/backends/gpu/gpu_context.h" -// #include "paddle/phi/core/flags.h" -#include "paddle/phi/core/enforce.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -#define INT_MAX_VALUE 2147483647 - -PHI_DECLARE_bool(enable_cublas_tensor_op_math); -PHI_DECLARE_bool(gemm_use_half_precision_compute_type); - -namespace phi { -namespace funcs { -template -struct CUBlas; - -template <> -struct CUBlas { - template - static void GEMM(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemm(args...)); - } - - template - static void AXPY(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSaxpy(args...)); - } - - template - static void SCAL(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSscal(args...)); - } - - template - static void VCOPY(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasScopy(args...)); - } - - template - static void GEMV(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemv(args...)); - } - - template - static void GEMM_BATCH(ARGS... args) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemmBatched(args...)); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "SgemmBatched is not supported on cuda <= 7.5")); -#endif - } - - template - static void GEMM_STRIDED_BATCH(ARGS... args) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasSgemmStridedBatched(args...)); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "SgemmStridedBatched is not supported on cuda <= 7.5")); -#endif - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. 
- // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(phi::GPUContext *dev_ctx, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, - const void *A, - cudaDataType_t Atype, - int lda, - const void *B, - cudaDataType_t Btype, - int ldb, - const float *beta, - void *C, - cudaDataType_t Ctype, - int ldc) { -// Because the gcc 4.8 doesn't expand template parameter pack that -// appears in a lambda-expression, I can not use template parameter pack -// here. -#if CUDA_VERSION >= 8000 - VLOG(5) << "use_tensor_op_math: " - << (dev_ctx->tensor_core_available() ? "True" : "False"); - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemmEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A, - Atype, - lda, - B, - Btype, - ldb, - beta, - C, - Ctype, - ldc)); - }); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "cublasSgemmEx is not supported on cuda <= 7.5")); -#endif - } - - template - static void TRSM(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasStrsm(args...)); - } - - template - static void GETRF_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetrfBatched(args...)); - } - - template - static void GETRI_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetriBatched(args...)); - } - - template - static void MATINV_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSmatinvBatched(args...)); - } - - template - static void GETRS_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetrsBatched(args...)); - } - - template - static void TRSM_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasStrsmBatched(args...)); - } -}; - -template <> -struct CUBlas { - template - static void GEMM(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemm(args...)); - } - - template - static void AXPY(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDaxpy(args...)); - } - - template - static void SCAL(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDscal(args...)); - } - - template - static void VCOPY(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDcopy(args...)); - } - - template - static void GEMV(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemv(args...)); - } - - template - static void GEMM_BATCH(ARGS... args) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemmBatched(args...)); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "DgemmBatched is not supported on cuda <= 7.5")); -#endif - } - - template - static void GEMM_STRIDED_BATCH(ARGS... args) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasDgemmStridedBatched(args...)); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "DgemmStridedBatched is not supported on cuda <= 7.5")); -#endif - } - - template - static void GEMM_EX(ARGS... args UNUSED) { - PADDLE_THROW( - phi::errors::Unimplemented("Currently there are not cublasDgemmEx.")); - } - - template - static void TRSM(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDtrsm(args...)); - } - - template - static void GETRF_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetrfBatched(args...)); - } - - template - static void GETRI_BATCH(ARGS... 
args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetriBatched(args...)); - } - - template - static void MATINV_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDmatinvBatched(args...)); - } - - template - static void GETRS_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetrsBatched(args...)); - } - - template - static void TRSM_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDtrsmBatched(args...)); - } -}; - -template <> -struct CUBlas { - using float16 = phi::dtype::float16; - - static void GEMM(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float16 *alpha, - const float16 *A, - int lda, - const float16 *B, - int ldb, - const float16 *beta, - float16 *C, - int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasHgemm(handle, - transa, - transb, - m, - n, - k, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb, - reinterpret_cast(beta), - reinterpret_cast(C), - ldc)); - } - -#if defined(__NVCC__) - static void GEMM_BATCH(phi::GPUContext *dev_ctx, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, - const float16 **A, - cudaDataType_t Atype, - int lda, - const float16 **B, - cudaDataType_t Btype, - int ldb, - const float *beta, - float16 **C, - cudaDataType_t Ctype, - int ldc, - int batchCount, - cublasComputeType_t computeType) { -#if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -#if CUDA_VERSION >= 9000 - bool use_tensor_op_math = dev_ctx->tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); -#endif // CUDA_VERSION >= 9000 - thrust::device_vector A_ptr(A, A + batchCount); - thrust::device_vector B_ptr(B, B + batchCount); - thrust::device_vector C_ptr(C, C + batchCount); - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmBatchedEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A_ptr.data().get(), - Atype, - lda, - B_ptr.data().get(), - Btype, - ldb, - beta, - C_ptr.data().get(), - Ctype, - ldc, - batchCount, - computeType, - algo)); - }); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "cublasGemmBatchedEx is not supported on cuda <= 7.5")); -#endif - } -#endif - - static void GEMM_STRIDED_BATCH(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float16 *alpha, - const float16 *A, - int lda, - long long int strideA, // NOLINT - const float16 *B, // NOLINT - int ldb, - long long int strideB, // NOLINT - const float16 *beta, - float16 *C, - int ldc, - long long int strideC, // NOLINT - int batchCount) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasHgemmStridedBatched( - handle, - transa, - transb, - m, - n, - k, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - strideA, - reinterpret_cast(B), - ldb, - strideB, - reinterpret_cast(beta), - reinterpret_cast(C), - ldc, - strideC, - batchCount)); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "HgemmStridedBatched is not supported on cuda <= 7.5")); -#endif - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. 
- // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(phi::GPUContext *dev_ctx, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const void *alpha, - const void *A, - cudaDataType_t Atype, - int lda, - const void *B, - cudaDataType_t Btype, - int ldb, - const void *beta, - void *C, - cudaDataType_t Ctype, - int ldc, - cublasComputeType_t computeType) { -#if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -#if CUDA_VERSION >= 9000 - bool use_tensor_op_math = dev_ctx->tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); -#endif // CUDA_VERSION >= 9000 - - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A, - Atype, - lda, - B, - Btype, - ldb, - beta, - C, - Ctype, - ldc, - computeType, - algo)); - }); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "cublasGemmEx is not supported on cuda <= 7.5")); -#endif - } -}; - -template <> -struct CUBlas> { - static void GEMV(cublasHandle_t handle, - cublasOperation_t transa, - int m, - int n, - const phi::dtype::complex *alpha, - const phi::dtype::complex *A, - int lda, - const phi::dtype::complex *B, - int ldb, - const phi::dtype::complex *beta, - phi::dtype::complex *C, - int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemv( - handle, - transa, - m, - n, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb, - reinterpret_cast(beta), - reinterpret_cast(C), - ldc)); - } - - static void AXPY(cublasHandle_t handle, - int n, - const phi::dtype::complex *alpha, - const phi::dtype::complex *X, - const int incX, - phi::dtype::complex *Y, - const int incY) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCaxpy( - handle, - n, - reinterpret_cast(alpha), - reinterpret_cast(X), - incX, - reinterpret_cast(Y), - incY)); - } - - static void GEMM_STRIDED_BATCH(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const phi::dtype::complex *alpha, - const phi::dtype::complex *A, - int lda, - long long int strideA, // NOLINT - const phi::dtype::complex *B, // NOLINT - int ldb, - long long int strideB, // NOLINT - const phi::dtype::complex *beta, - phi::dtype::complex *C, - int ldc, - long long int strideC, // NOLINT - int batchCount) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemmStridedBatched( - handle, - transa, - transb, - m, - n, - k, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - strideA, - reinterpret_cast(B), - ldb, - strideB, - reinterpret_cast(beta), - reinterpret_cast(C), - ldc, - strideC, - batchCount)); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "CgemmStridedBatched is not supported on cuda <= 7.5")); -#endif - } - - static void GEMM(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const phi::dtype::complex *alpha, - const phi::dtype::complex *A, - int lda, - const phi::dtype::complex *B, - int ldb, - const phi::dtype::complex *beta, - phi::dtype::complex *C, - int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemm( - handle, - transa, - transb, - m, - n, - k, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb, - reinterpret_cast(beta), - 
reinterpret_cast(C), - ldc)); - } - - static void TRSM(cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t transa, - cublasDiagType_t diag, - int m, - int n, - const phi::dtype::complex *alpha, - const phi::dtype::complex *A, - int lda, - phi::dtype::complex *B, - int ldb) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCtrsm( - handle, - side, - uplo, - transa, - diag, - m, - n, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb)); - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. - // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(phi::GPUContext *dev_ctx, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const void *alpha, - const void *A, - cudaDataType_t Atype, - int lda, - const void *B, - cudaDataType_t Btype, - int ldb, - const void *beta, - void *C, - cudaDataType_t Ctype, - int ldc, - cublasComputeType_t computeType) { -#if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -#if CUDA_VERSION >= 9000 - bool use_tensor_op_math = dev_ctx->tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); -#endif // CUDA_VERSION >= 9000 - - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A, - Atype, - lda, - B, - Btype, - ldb, - beta, - C, - Ctype, - ldc, - computeType, - algo)); - }); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "cublasGemmEx is not supported on cuda <= 7.5")); -#endif - } - - static void TRSM_BATCH(cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t transa, - cublasDiagType_t diag, - int m, - int n, - const phi::dtype::complex *alpha, - const phi::dtype::complex **A, - int lda, - phi::dtype::complex **B, - int ldb, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCtrsmBatched( - handle, - side, - uplo, - transa, - diag, - m, - n, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb, - batch_size)); - } - // ****************************************************************新增模版定义********************* - - static void GETRF_BATCH(cublasHandle_t handle, - int n, - phi::dtype::complex **A, - int lda, - int *ipiv, - int *info, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetrfBatched( - handle, - n, - reinterpret_cast(A), - lda, - ipiv, - info, - batch_size)); - } - - static void GETRI_BATCH(cublasHandle_t handle, - int n, - const phi::dtype::complex **A, - int lda, - const int *ipiv, - phi::dtype::complex **Ainv, - int ldc, - int *info, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetriBatched( - handle, - n, - reinterpret_cast(A), - lda, - ipiv, - reinterpret_cast(Ainv), - ldc, - info, - batch_size)); - } - - static void MATINV_BATCH(cublasHandle_t handle, - int n, - const phi::dtype::complex **A, - int lda, - phi::dtype::complex **Ainv, - int lda_inv, - int *info, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCmatinvBatched( - handle, - n, - reinterpret_cast(A), - lda, - reinterpret_cast(Ainv), - lda_inv, - info, - batch_size)); - } - // ****************************************************************新增模版定义********************* -}; - -template <> -struct 
CUBlas> { - static void GEMV(cublasHandle_t handle, - cublasOperation_t transa, - int m, - int n, - const phi::dtype::complex *alpha, - const phi::dtype::complex *A, - int lda, - const phi::dtype::complex *B, - int ldb, - const phi::dtype::complex *beta, - phi::dtype::complex *C, - int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemv( - handle, - transa, - m, - n, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb, - reinterpret_cast(beta), - reinterpret_cast(C), - ldc)); - } - - static void AXPY(cublasHandle_t handle, - int n, - const phi::dtype::complex *alpha, - const phi::dtype::complex *X, - const int incX, - phi::dtype::complex *Y, - const int incY) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZaxpy( - handle, - n, - reinterpret_cast(alpha), - reinterpret_cast(X), - incX, - reinterpret_cast(Y), - incY)); - } - - static void GEMM_STRIDED_BATCH( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const phi::dtype::complex *alpha, - const phi::dtype::complex *A, - int lda, - long long int strideA, // NOLINT - const phi::dtype::complex *B, // NOLINT - int ldb, - long long int strideB, // NOLINT - const phi::dtype::complex *beta, - phi::dtype::complex *C, - int ldc, - long long int strideC, // NOLINT - int batchCount) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemmStridedBatched( - handle, - transa, - transb, - m, - n, - k, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - strideA, - reinterpret_cast(B), - ldb, - strideB, - reinterpret_cast(beta), - reinterpret_cast(C), - ldc, - strideC, - batchCount)); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "CgemmStridedBatched is not supported on cuda <= 7.5")); -#endif - } - - static void GEMM(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const phi::dtype::complex *alpha, - const phi::dtype::complex *A, - int lda, - const phi::dtype::complex *B, - int ldb, - const phi::dtype::complex *beta, - phi::dtype::complex *C, - int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemm( - handle, - transa, - transb, - m, - n, - k, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb, - reinterpret_cast(beta), - reinterpret_cast(C), - ldc)); - } - - static void TRSM(cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t transa, - cublasDiagType_t diag, - int m, - int n, - const phi::dtype::complex *alpha, - const phi::dtype::complex *A, - int lda, - phi::dtype::complex *B, - int ldb) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZtrsm( - handle, - side, - uplo, - transa, - diag, - m, - n, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb)); - } - - static void TRSM_BATCH(cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t transa, - cublasDiagType_t diag, - int m, - int n, - const phi::dtype::complex *alpha, - const phi::dtype::complex **A, - int lda, - phi::dtype::complex **B, - int ldb, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZtrsmBatched( - handle, - side, - uplo, - transa, - diag, - m, - n, - reinterpret_cast(alpha), - reinterpret_cast(A), - lda, - reinterpret_cast(B), - ldb, - batch_size)); - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. 
- // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(phi::GPUContext *dev_ctx, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const void *alpha, - const void *A, - cudaDataType_t Atype, - int lda, - const void *B, - cudaDataType_t Btype, - int ldb, - const void *beta, - void *C, - cudaDataType_t Ctype, - int ldc, - cublasComputeType_t computeType) { -#if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -#if CUDA_VERSION >= 9000 - bool use_tensor_op_math = dev_ctx->tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); -#endif // CUDA_VERSION >= 9000 - - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A, - Atype, - lda, - B, - Btype, - ldb, - beta, - C, - Ctype, - ldc, - computeType, - algo)); - }); -#else - PADDLE_THROW(phi::errors::Unimplemented( - "cublasGemmEx is not supported on cuda <= 7.5")); -#endif - } - // &*******************************************新增模版定义************************* - static void GETRF_BATCH(cublasHandle_t handle, - int n, - phi::dtype::complex **A, - int lda, - int *ipiv, - int *info, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetrfBatched( - handle, - n, - reinterpret_cast(A), - lda, - ipiv, - info, - batch_size)); - } - - static void GETRI_BATCH(cublasHandle_t handle, - int n, - const phi::dtype::complex **A, - int lda, - const int *ipiv, - phi::dtype::complex **Ainv, - int ldc, - int *info, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetriBatched( - handle, - n, - reinterpret_cast(A), - lda, - ipiv, - reinterpret_cast(Ainv), - ldc, - info, - batch_size)); - } - - static void MATINV_BATCH(cublasHandle_t handle, - int n, - const phi::dtype::complex **A, - int lda, - phi::dtype::complex **Ainv, - int lda_inv, - int *info, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZmatinvBatched( - handle, - n, - reinterpret_cast(A), - lda, - reinterpret_cast(Ainv), - lda_inv, - info, - batch_size)); - } - // &*******************************************新增模版定义************************* -}; - -inline void CheckGEMMNSize(int64_t N) { - constexpr int64_t kMaxN = 1073741823; - if (N > kMaxN) { - PADDLE_THROW(common::errors::Unimplemented( - "cublas GEMM does not support N > %ld. Got N = %ld. ", kMaxN, N)); - } -} - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - T alpha, - const T *A, - const T *B, - T beta, - T *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; - -#if CUDA_VERSION >= 8000 - if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { - auto &cuda_ctx = const_cast(dev_ctx_); - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented( - "CUBlas::GEMM_EX_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "GEMM_EX_64 is not supported on cuda < 12.3")); -#endif - } else { - CheckGEMMNSize(N); - CUBlas::GEMM_EX(&cuda_ctx, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - CUDA_R_32F, - ldb, - A, - CUDA_R_32F, - lda, - &beta, - C, - CUDA_R_32F, - N); - } - } else { -#endif // CUDA_VERSION >= 8000 - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { - PADDLE_THROW(common::errors::Unimplemented( - "GEMM_EX_64 is not supported on cuda < 12.3")); - } else { - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - ldb, - A, - lda, - &beta, - C, - N); - }); - } - -#if CUDA_VERSION >= 8000 - } -#endif // CUDA_VERSION >= 8000 -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - phi::dtype::float16 alpha, - const phi::dtype::float16 *A, - const phi::dtype::float16 *B, - phi::dtype::float16 beta, - phi::dtype::float16 *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - dev_ctx_.GetComputeCapability(), - 53, - phi::errors::InvalidArgument( - "cublas fp16 gemm requires GPU compute capability >= 53," - "but received %d", - dev_ctx_.GetComputeCapability())); - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - -#if CUDA_VERSION >= 8000 - // cublasHgemm does true FP16 computation which is slow for non-Volta - // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: - // input/output in fp16, computation in fp32, which can also be accelerated - // using tensor cores in volta GPUs. - auto &cuda_ctx = const_cast(dev_ctx_); - CUBlas::GEMM_EX(&cuda_ctx, - cuTransB, - cuTransA, - N, - M, - K, - &h_alpha, - B, - CUDA_R_16F, - ldb, - A, - CUDA_R_16F, - lda, - &h_beta, - C, - CUDA_R_16F, - N, - CUBLAS_COMPUTE_32F); -#else - // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - N, - M, - K, - &h_alpha, - h_B, - ldb, - h_A, - lda, - &h_beta, - h_C, - N); - }); -#endif // CUDA_VERSION >= 8000 -} - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - U alpha, - const T *A, - const T *B, - U beta, - T *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; - - T t_alpha = static_cast(alpha); - T t_beta = static_cast(beta); - -#if CUDA_VERSION >= 8000 - if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { - auto &cuda_ctx = const_cast(dev_ctx_); - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented("GEMM_EX_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "GEMM_EX_64 is not supported on cuda < 12.3")); -#endif - } else { - CheckGEMMNSize(N); - CUBlas::GEMM_EX(&cuda_ctx, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &t_alpha, - B, - CUDA_R_32F, - static_cast(ldb), - A, - CUDA_R_32F, - static_cast(lda), - &t_beta, - C, - CUDA_R_32F, - static_cast(N)); - } - } else { -#endif // CUDA_VERSION >= 8000 - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { - PADDLE_THROW(common::errors::Unimplemented( - "GEMM_EX_64 is not supported on cuda < 12.3")); - } else { - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &t_alpha, - B, - static_cast(ldb), - A, - static_cast(lda), - &t_beta, - C, - static_cast(N)); - }); - } - -#if CUDA_VERSION >= 8000 - } -#endif // CUDA_VERSION >= 8000 -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - float alpha, - const phi::dtype::float16 *A, - const phi::dtype::float16 *B, - float beta, - phi::dtype::float16 *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - // PADDLE_ENFORCE_GE( - // dev_ctx_.GetComputeCapability(), - // 53, - // common::errors::InvalidArgument( - // "cublas fp16 gemm requires GPU compute capability >= 53," - // "but received %d", - // dev_ctx_.GetComputeCapability())); - - float h_alpha = alpha; - float h_beta = beta; - -#if CUDA_VERSION >= 8000 - auto &cuda_ctx = const_cast(dev_ctx_); -#endif - // cublasHgemm does true FP16 computation which is slow for non-Volta - // GPUs. So use cublasGemmEx instead which does pseudo FP16 computation: - // input/output in fp16, computation in fp32, which can also be accelerated - // using tensor cores in volta GPUs. 
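Put differently, the FP16 path above keeps half-precision storage but accumulates in single precision. The essential cublasGemmEx configuration is sketched here; handle, the buffers, and the m/n/k/ld variables are assumed from the surrounding function:

// Pseudo-FP16 GEMM: A, B and C hold half-precision data, the multiply-accumulate
// runs in FP32, and tensor cores are used when the device supports them.
// Operands and dimensions are passed swapped because cuBLAS is column-major.
cublasGemmEx(handle,
             CUBLAS_OP_N, CUBLAS_OP_N,
             n, m, k,
             &h_alpha,                 // float
             B, CUDA_R_16F, ldb,
             A, CUDA_R_16F, lda,
             &h_beta,                  // float
             C, CUDA_R_16F, n,
             CUBLAS_COMPUTE_32F,       // FP32 accumulation
             CUBLAS_GEMM_DEFAULT);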
- if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented("GEMM_EX_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "GEMM_EX_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { -#if CUDA_VERSION >= 8000 - CheckGEMMNSize(N); - CUBlas::GEMM_EX(&cuda_ctx, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - B, - CUDA_R_16F, - static_cast(ldb), - A, - CUDA_R_16F, - static_cast(lda), - &h_beta, - C, - CUDA_R_16F, - static_cast(N), - CUBLAS_COMPUTE_32F); -#else - // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - h_B, - static_cast(ldb), - h_A, - static_cast(lda), - &h_beta, - h_C, - static_cast(N)); - }); -#endif // CUDA_VERSION >= 8000 - } -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - phi::dtype::bfloat16 alpha, - const phi::dtype::bfloat16 *A, - const phi::dtype::bfloat16 *B, - phi::dtype::bfloat16 beta, - phi::dtype::bfloat16 *C) const { -#if CUDA_VERSION >= 11000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - PADDLE_ENFORCE_GE( - dev_ctx_.GetComputeCapability(), - 80, - phi::errors::InvalidArgument( - "cublas bf16 gemm requires GPU compute capability >= 80," - "but received %d", - dev_ctx_.GetComputeCapability())); - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; - bool use_tensor_op_math = dev_ctx_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? 
"True" : "False"); - - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW( - common::errors::Unimplemented("cublasGemmEx_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmEx_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { - CheckGEMMNSize(N); - dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, - cuTransB, - cuTransA, - N, - M, - K, - &h_alpha, - B, - CUDA_R_16BF, - ldb, - A, - CUDA_R_16BF, - lda, - &h_beta, - C, - CUDA_R_16BF, - N, - CUBLAS_COMPUTE_32F, - algo)); - }); - } -#else - // raise error - PADDLE_THROW(phi::errors::Unimplemented( - "cublasGemmEx with bfloat16 is not supported on cuda <= 11")); - -#endif // CUDA_VERSION >= 11000 -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - phi::dtype::complex alpha, - const phi::dtype::complex *A, - const phi::dtype::complex *B, - phi::dtype::complex beta, - phi::dtype::complex *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - dev_ctx_.GetComputeCapability(), - 53, - phi::errors::InvalidArgument( - "cublas complex64 gemm requires GPU compute capability >= 53," - "but received %d", - dev_ctx_.GetComputeCapability())); - - thrust::complex c_alpha = - thrust::complex(alpha.real, alpha.imag); - thrust::complex c_beta = thrust::complex(beta.real, beta.imag); - -#if CUDA_VERSION >= 8000 - auto &cuda_ctx = const_cast(dev_ctx_); -#endif - - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented("GEMM_EX_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "GEMM_EX_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { -#if CUDA_VERSION >= 8000 - CheckGEMMNSize(N); - CUBlas>::GEMM_EX(&cuda_ctx, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &c_alpha, - B, - CUDA_C_32F, - static_cast(ldb), - A, - CUDA_C_32F, - static_cast(lda), - &c_beta, - C, - CUDA_C_32F, - static_cast(N), - CUBLAS_COMPUTE_32F); -#else - // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas>::GEMM(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &c_alpha, - h_B, - static_cast(ldb), - h_A, - static_cast(lda), - &c_beta, - h_C, - static_cast(N)); - }); -#endif // CUDA_VERSION >= 8000 - } -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - phi::dtype::complex alpha, - const phi::dtype::complex *A, - const phi::dtype::complex *B, - phi::dtype::complex beta, - phi::dtype::complex *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. 
- int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - dev_ctx_.GetComputeCapability(), - 53, - phi::errors::InvalidArgument( - "cublas complex128 gemm requires GPU compute capability >= 53," - "but received %d", - dev_ctx_.GetComputeCapability())); - - thrust::complex c_alpha = - thrust::complex(alpha.real, alpha.imag); - thrust::complex c_beta = - thrust::complex(beta.real, beta.imag); - -#if CUDA_VERSION >= 8000 - auto &cuda_ctx = const_cast(dev_ctx_); -#endif - - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented("GEMM_EX_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "GEMM_EX_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { -#if CUDA_VERSION >= 8000 - CheckGEMMNSize(N); - CUBlas>::GEMM_EX(&cuda_ctx, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &c_alpha, - B, - CUDA_C_64F, - static_cast(ldb), - A, - CUDA_C_64F, - static_cast(lda), - &c_beta, - C, - CUDA_C_64F, - static_cast(N), - CUBLAS_COMPUTE_64F); -#else - // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas>::GEMM(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &c_alpha, - h_B, - static_cast(ldb), - h_A, - static_cast(lda), - &c_beta, - h_C, - static_cast(N)); - }); -#endif // CUDA_VERSION >= 8000 - } -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - float alpha, - const phi::dtype::bfloat16 *A, - const phi::dtype::bfloat16 *B, - float beta, - phi::dtype::bfloat16 *C) const { -#if CUDA_VERSION >= 11000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // PADDLE_ENFORCE_GE( - // dev_ctx_.GetComputeCapability(), - // 80, - // common::errors::InvalidArgument( - // "cublas bf16 gemm requires GPU compute capability >= 80," - // "but received %d", - // dev_ctx_.GetComputeCapability())); - - float h_alpha = alpha; - float h_beta = beta; - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; - bool use_tensor_op_math = dev_ctx_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? 
"True" : "False"); - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW( - common::errors::Unimplemented("cublasGemmEx_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmEx_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { - CheckGEMMNSize(N); - dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmEx(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - B, - CUDA_R_16BF, - static_cast(ldb), - A, - CUDA_R_16BF, - static_cast(lda), - &h_beta, - C, - CUDA_R_16BF, - static_cast(N), - CUDA_R_32F, - algo)); - }); - } -#else - // raise error - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmEx with bfloat16 is not supported on cuda <= 11")); - -#endif // CUDA_VERSION >= 11000 -} - -template <> -template -void Blas::GEMM(bool transA, - bool transB, - int M, - int N, - int K, - T alpha, - const T *A, - int lda, - const T *B, - int ldb, - T beta, - T *C, - int ldc) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - -#if CUDA_VERSION >= 8000 - if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { - auto &cuda_ctx = const_cast(dev_ctx_); - CUBlas::GEMM_EX(&cuda_ctx, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - CUDA_R_32F, - ldb, - A, - CUDA_R_32F, - lda, - &beta, - C, - CUDA_R_32F, - ldc); - } else { -#endif // CUDA_VERSION >= 8000 - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - ldb, - A, - lda, - &beta, - C, - ldc); - }); - -#if CUDA_VERSION >= 8000 - } -#endif // CUDA_VERSION >= 8000 -} - -template <> -template <> -inline void Blas::GEMM(bool transA, - bool transB, - int M, - int N, - int K, - phi::dtype::float16 alpha, - const phi::dtype::float16 *A, - int lda, - const phi::dtype::float16 *B, - int ldb, - phi::dtype::float16 beta, - phi::dtype::float16 *C, - int ldc) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - ldb, - A, - lda, - &beta, - C, - ldc); - }); -} - -template <> -template <> -inline void Blas::GEMM(bool transA, - bool transB, - int M, - int N, - int K, - phi::dtype::bfloat16 alpha, - const phi::dtype::bfloat16 *A, - int lda, - const phi::dtype::bfloat16 *B, - int ldb, - phi::dtype::bfloat16 beta, - phi::dtype::bfloat16 *C, - int ldc) const { -#if CUDA_VERSION >= 11000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t cuTransB = transB ? 
CUBLAS_OP_T : CUBLAS_OP_N; - - // PADDLE_ENFORCE_GE( - // dev_ctx_.GetComputeCapability(), - // 80, - // phi::errors::InvalidArgument( - // "cublas bf16 gemm requires GPU compute capability >= 80," - // "but received %d", - // dev_ctx_.GetComputeCapability())); - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; - bool use_tensor_op_math = dev_ctx_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); - - dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, - cuTransB, - cuTransA, - N, - M, - K, - &h_alpha, - B, - CUDA_R_16BF, - ldb, - A, - CUDA_R_16BF, - lda, - &h_beta, - C, - CUDA_R_16BF, - ldc, - CUBLAS_COMPUTE_32F, - algo)); - }); -#else - // raise error - PADDLE_THROW(phi::errors::Unimplemented( - "cublasGemmEx with bfloat16 is not supported on cuda <= 11")); - -#endif // CUDA_VERSION >= 11000 -} - -template <> -template -void Blas::AXPY(int n, T alpha, const T *x, T *y) const { - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::AXPY(handle, n, &alpha, x, 1, y, 1); - }); -} - -template <> -template -void Blas::SCAL(int n, const T alpha, T *x) const { - dev_ctx_.CublasCall( - [&](cublasHandle_t handle) { CUBlas::SCAL(handle, n, &alpha, x, 1); }); -} - -template <> -template -void Blas::VCOPY(int n, const T *x, T *y) const { - dev_ctx_.CublasCall( - [&](cublasHandle_t handle) { CUBlas::VCOPY(handle, n, x, 1, y, 1); }); -} - -template <> -template -void Blas::GEMV(bool trans_a, - int M, - int N, - T alpha, - const T *A, - const T *B, - T beta, - T *C) const { - cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1); - }); -} - -template <> -template <> -inline void Blas::GEMV(bool trans_a, - int M, - int N, - phi::dtype::float16 alpha, - const phi::dtype::float16 *A, - const phi::dtype::float16 *B, - phi::dtype::float16 beta, - phi::dtype::float16 *C) const { - // Because cublas doesn't support half gemv, we use cublasHgemm to achieve it. - if (trans_a) { - this->template GEMM( - CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C); - } else { - this->template GEMM( - CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C); - } -} - -template <> -template <> -inline void Blas::GEMV(bool trans_a, - int M, - int N, - phi::dtype::bfloat16 alpha, - const phi::dtype::bfloat16 *A, - const phi::dtype::bfloat16 *B, - phi::dtype::bfloat16 beta, - phi::dtype::bfloat16 *C) const { - // Because cublas doesn't support bfloat gemv, we use cublasHgemm to achieve - // it. - if (trans_a) { - this->template GEMM( - CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C); - } else { - this->template GEMM( - CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C); - } -} - -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - T alpha, - const T *A, - const T *B, - T beta, - T *C, - int64_t batchCount, - int64_t strideA, - int64_t strideB) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? 
N : K; - int64_t ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t strideC = M * N; - -#if CUDA_VERSION >= 9010 - if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same::value)) || - std::is_same::value) { - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = dev_ctx_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); - VLOG(4) << "use_half_precision_compute_type: " - << FLAGS_gemm_use_half_precision_compute_type; - - auto fp = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; -#if CUDA_VERSION >= 11000 - auto compute_type = CUBLAS_COMPUTE_32F; -#else - auto compute_type = CUDA_R_32F; -#endif - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - void *a = static_cast(&h_alpha); - void *b = static_cast(&h_beta); - // set ComputeType as CUDA_R_32F for fp16, for better accuracy - if (FLAGS_gemm_use_half_precision_compute_type == true && - std::is_same::value) { - a = static_cast(&alpha); - b = static_cast(&beta); -#if CUDA_VERSION >= 11000 - compute_type = CUBLAS_COMPUTE_16F; -#else - compute_type = CUDA_R_16F; -#endif - } - - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { - dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmStridedBatchedEx(handle, - cuTransB, - cuTransA, - N, - M, - K, - a, - B, - fp, - ldb, - strideB, - A, - fp, - lda, - strideA, - b, - C, - fp, - ldc, - strideC, - batchCount, - compute_type, - algo)); - }); - } - } else { -#endif // CUDA_VERSION >= 9010 - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM_STRIDED_BATCH(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &alpha, - B, - static_cast(ldb), - strideB, - A, - static_cast(lda), - strideA, - &beta, - C, - ldc, - strideC, - static_cast(batchCount)); - }); - -#if CUDA_VERSION >= 9010 - } -#endif // CUDA_VERSION >= 9010 -} - -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - U alpha, - const T *A, - const T *B, - U beta, - T *C, - int64_t batchCount, - int64_t strideA, - int64_t strideB) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - int64_t ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t strideC = M * N; -#if CUDA_VERSION >= 9010 - if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same::value)) || - std::is_same::value) { - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = dev_ctx_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); - VLOG(4) << "use_half_precision_compute_type: " - << FLAGS_gemm_use_half_precision_compute_type; - - auto fp = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; -#if CUDA_VERSION >= 11000 - auto compute_type = CUBLAS_COMPUTE_32F; -#else - auto compute_type = CUDA_R_32F; -#endif - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - void *a = static_cast(&h_alpha); - void *b = static_cast(&h_beta); - // set ComputeType as CUDA_R_32F for fp16, for better accuracy - if (FLAGS_gemm_use_half_precision_compute_type == true && - std::is_same::value) { - a = static_cast(&alpha); - b = static_cast(&beta); -#if CUDA_VERSION >= 11000 - compute_type = CUBLAS_COMPUTE_16F; -#else - compute_type = CUDA_R_16F; -#endif - } - - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE || - batchCount > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { - dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmStridedBatchedEx( - handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - a, - B, - fp, - static_cast(ldb), - strideB, - A, - fp, - static_cast(lda), - strideA, - b, - C, - fp, - static_cast(ldc), - strideC, - static_cast(batchCount), - compute_type, - algo)); - }); - } - } else { -#endif // CUDA_VERSION >= 9010 - T h_alpha = static_cast(alpha); - T h_beta = static_cast(beta); - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM_STRIDED_BATCH(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - B, - static_cast(ldb), - strideB, - A, - static_cast(lda), - strideA, - &h_beta, - C, - static_cast(ldc), - strideC, - static_cast(batchCount)); - }); - -#if CUDA_VERSION >= 9010 - } -#endif // CUDA_VERSION >= 9010 -} - -template <> -template <> -inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - phi::dtype::bfloat16 alpha, - const phi::dtype::bfloat16 *A, - const phi::dtype::bfloat16 *B, - phi::dtype::bfloat16 beta, - phi::dtype::bfloat16 *C, - int64_t batchCount, - int64_t strideA, - int64_t strideB) const { -#if CUDA_VERSION >= 11000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int64_t lda = (transA == CblasNoTrans) ? K : M; - int64_t ldb = (transB == CblasNoTrans) ? N : K; - int64_t ldc = N; - - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t strideC = M * N; - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = dev_ctx_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE || - batchCount > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { - dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmStridedBatchedEx(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - B, - CUDA_R_16BF, - static_cast(ldb), - strideB, - A, - CUDA_R_16BF, - static_cast(lda), - strideA, - &h_beta, - C, - CUDA_R_16BF, - static_cast(ldc), - strideC, - static_cast(batchCount), - CUBLAS_COMPUTE_32F, - algo)); - }); - } -#else - // raise error - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= " - "11")); -#endif // CUDA_VERSION >= 11000 -} - -template <> -template <> -inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - float alpha, - const phi::dtype::bfloat16 *A, - const phi::dtype::bfloat16 *B, - float beta, - phi::dtype::bfloat16 *C, - int64_t batchCount, - int64_t strideA, - int64_t strideB) const { -#if CUDA_VERSION >= 11000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t strideC = M * N; - - float h_alpha = alpha; - float h_beta = beta; - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = dev_ctx_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? 
"True" : "False"); - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE || - batchCount > INT_MAX_VALUE) { -#if CUDA_VERSION >= 12030 && defined(__linux__) - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx_64 is not complete")); -#else - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx_64 is not supported on cuda < 12.3")); -#endif // CUDA_VERSION >= 12030 - } else { - dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmStridedBatchedEx(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - B, - CUDA_R_16BF, - static_cast(ldb), - strideB, - A, - CUDA_R_16BF, - static_cast(lda), - strideA, - &h_beta, - C, - CUDA_R_16BF, - static_cast(ldc), - strideC, - static_cast(batchCount), - CUBLAS_COMPUTE_32F, - algo)); - }); - } -#else - // raise error - PADDLE_THROW(common::errors::Unimplemented( - "cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= " - "11")); -#endif // CUDA_VERSION >= 11000 -} - -// /*** -// * Uknow bug, parameters dislocation when calling BatchedGEMM. -// * Reference: paddle github PR #45530 and #55612 -// */ -// template <> -// template <> -// inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, -// CBLAS_TRANSPOSE transB, -// int M, -// int N, -// int K, -// float16 alpha, -// const float16 *A, -// const float16 *B, -// float16 beta, -// float16 *C, -// int batchCount, -// int64_t strideA, -// int64_t strideB) const { -// // Note that cublas follows fortran order, so the order is different from -// // the cblas convention. -// int lda = (transA == CblasNoTrans) ? K : M; -// int ldb = (transB == CblasNoTrans) ? N : K; -// int ldc = N; -// cublasOperation_t cuTransA = -// (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; -// cublasOperation_t cuTransB = -// (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; -// const int64_t strideC = M * N; - -// #if CUDA_VERSION >= 9010 -// if ((FLAGS_enable_cublas_tensor_op_math && -// (std::is_same::value)) || -// std::is_same::value) { -// cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -// bool use_tensor_op_math = dev_ctx_.tensor_core_available(); -// if (use_tensor_op_math) { -// algo = CUBLAS_GEMM_DFALT_TENSOR_OP; -// } -// VLOG(5) << "use_tensor_op_math: " -// << (use_tensor_op_math ? "True" : "False"); -// VLOG(4) << "use_half_precision_compute_type: " -// << FLAGS_gemm_use_half_precision_compute_type; - -// auto fp = std::is_same::value ? 
CUDA_R_32F : CUDA_R_16F; -// #if CUDA_VERSION >= 11000 -// auto compute_type = CUBLAS_COMPUTE_32F; -// #else -// auto compute_type = CUDA_R_32F; -// #endif - -// float h_alpha = static_cast(alpha); -// float h_beta = static_cast(beta); -// void *a = static_cast(&h_alpha); -// void *b = static_cast(&h_beta); -// // set ComputeType as CUDA_R_32F for fp16, for better accuracy -// if (FLAGS_gemm_use_half_precision_compute_type == true && -// std::is_same::value) { -// a = static_cast(&alpha); -// b = static_cast(&beta); -// #if CUDA_VERSION >= 11000 -// compute_type = CUBLAS_COMPUTE_16F; -// #else -// compute_type = CUDA_R_16F; -// #endif -// } - -// dev_ctx_.TensorCoreCublasCallIfAvailable( -// [&](cublasHandle_t handle) { -// PADDLE_ENFORCE_GPU_SUCCESS( -// phi::dynload::cublasGemmStridedBatchedEx(handle, -// cuTransB, -// cuTransA, -// N, -// M, -// K, -// a, -// B, -// fp, -// ldb, -// strideB, -// A, -// fp, -// lda, -// strideA, -// b, -// C, -// fp, -// ldc, -// strideC, -// batchCount, -// compute_type, -// algo)); -// }); -// } else { -// #endif // CUDA_VERSION >= 9010 - -// dev_ctx_.CublasCall( -// [&](cublasHandle_t handle) { -// CUBlas::GEMM_STRIDED_BATCH(handle, -// cuTransB, -// cuTransA, -// N, -// M, -// K, -// &alpha, -// B, -// ldb, -// strideB, -// A, -// lda, -// strideA, -// &beta, -// C, -// ldc, -// strideC, -// batchCount); -// }, -// dev_ctx_.stream()); - -// #if CUDA_VERSION >= 9010 -// } -// #endif // CUDA_VERSION >= 9010 -// } - -// /*** -// * Uknow bug, parameters dislocation when calling BatchedGEMM. -// * Reference: paddle github PR #45530 and #55612 -// */ -// template <> -// template <> -// inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, -// CBLAS_TRANSPOSE transB, -// int M, -// int N, -// int K, -// double alpha, -// const double *A, -// const double *B, -// double beta, -// double *C, -// int batchCount, -// int64_t strideA, -// int64_t strideB) const { -// // Note that cublas follows fortran order, so the order is different from -// // the cblas convention. -// int lda = (transA == CblasNoTrans) ? K : M; -// int ldb = (transB == CblasNoTrans) ? N : K; -// int ldc = N; -// cublasOperation_t cuTransA = -// (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; -// cublasOperation_t cuTransB = -// (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; -// const int64_t strideC = M * N; -// dev_ctx_.CublasCall( -// [&](cublasHandle_t handle) { -// PADDLE_ENFORCE_GPU_SUCCESS( -// phi::dynload::cublasDgemmStridedBatched(handle, -// cuTransB, -// cuTransA, -// N, -// M, -// K, -// &alpha, -// B, -// ldb, -// strideB, -// A, -// lda, -// strideA, -// &beta, -// C, -// ldc, -// strideC, -// batchCount)); -// }, -// dev_ctx_.stream()); -// } - -// template <> -// template <> -// inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, -// CBLAS_TRANSPOSE transB, -// int M, -// int N, -// int K, -// phi::dtype::bfloat16 alpha, -// const phi::dtype::bfloat16 *A, -// const phi::dtype::bfloat16 *B, -// phi::dtype::bfloat16 beta, -// phi::dtype::bfloat16 *C, -// int batchCount, -// int64_t strideA, -// int64_t strideB) const { -// #if CUDA_VERSION >= 11000 -// // Note that cublas follows fortran order, so the order is different from -// // the cblas convention. -// int lda = (transA == CblasNoTrans) ? K : M; -// int ldb = (transB == CblasNoTrans) ? N : K; -// int ldc = N; -// cublasOperation_t cuTransA = -// (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; -// cublasOperation_t cuTransB = -// (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; -// const int64_t strideC = M * N; - -// float h_alpha = static_cast(alpha); -// float h_beta = static_cast(beta); - -// cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -// bool use_tensor_op_math = dev_ctx->tensor_core_available(); -// if (use_tensor_op_math) { -// algo = CUBLAS_GEMM_DFALT_TENSOR_OP; -// } -// VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : -// "False"); - -// dev_ctx_.TensorCoreCublasCallIfAvailable( -// [&](cublasHandle_t handle) { -// PADDLE_ENFORCE_GPU_SUCCESS( -// phi::dynload::cublasGemmStridedBatchedEx(handle, -// cuTransB, -// cuTransA, -// N, -// M, -// K, -// &h_alpha, -// B, -// CUDA_R_16BF, -// ldb, -// strideB, -// A, -// CUDA_R_16BF, -// lda, -// strideA, -// &h_beta, -// C, -// CUDA_R_16BF, -// ldc, -// strideC, -// batchCount, -// CUBLAS_COMPUTE_32F, -// algo)); -// }); -// #else -// // raise error -// PADDLE_THROW(phi::errors::Unimplemented( -// "cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= " -// "11")); -// #endif // CUDA_VERSION >= 11000 -// } - -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - T alpha, - const T **A, - const T **B, - T beta, - T **C, - int batchCount) const { - for (int k = 0; k < batchCount; ++k) { - this->template GEMM( - transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); - } -} - -#if defined(__NVCC__) -template <> -template <> -inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - double alpha, - const double **A, - const double **B, - double beta, - double **C, - int batchCount) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - thrust::device_vector A_ptr(A, A + batchCount); - thrust::device_vector B_ptr(B, B + batchCount); - thrust::device_vector C_ptr(C, C + batchCount); - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM_BATCH(handle, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B_ptr.data().get(), - ldb, - A_ptr.data().get(), - lda, - &beta, - C_ptr.data().get(), - ldc, - batchCount); - }); -} - -template <> -template <> -inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - float alpha, - const float **A, - const float **B, - float beta, - float **C, - int batchCount) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; - thrust::device_vector A_ptr(A, A + batchCount); - thrust::device_vector B_ptr(B, B + batchCount); - thrust::device_vector C_ptr(C, C + batchCount); - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM_BATCH(handle, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B_ptr.data().get(), - ldb, - A_ptr.data().get(), - lda, - &beta, - C_ptr.data().get(), - ldc, - batchCount); - }); -} - -template <> -template <> -inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - phi::dtype::float16 alpha, - const phi::dtype::float16 **A, - const phi::dtype::float16 **B, - phi::dtype::float16 beta, - phi::dtype::float16 **C, - int batchCount) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - PADDLE_ENFORCE_GE( - dev_ctx_.GetComputeCapability(), - 53, - phi::errors::InvalidArgument( - "cublas fp16 gemm requires GPU compute capability >= 53," - "but received %d", - dev_ctx_.GetComputeCapability())); - float f_alpha = static_cast(alpha); - float f_beta = static_cast(beta); - auto &cuda_ctx = const_cast(dev_ctx_); - CUBlas::GEMM_BATCH(&cuda_ctx, - cuTransB, - cuTransA, - N, - M, - K, - &f_alpha, - B, - CUDA_R_16F, - ldb, - A, - CUDA_R_16F, - lda, - &f_beta, - C, - CUDA_R_16F, - ldc, - batchCount, - CUBLAS_COMPUTE_32F); -} - -template <> -template <> -inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - phi::dtype::bfloat16 alpha, - const phi::dtype::bfloat16 **A, - const phi::dtype::bfloat16 **B, - phi::dtype::bfloat16 beta, - phi::dtype::bfloat16 **C, - int batchCount) const { -#if CUDA_VERSION >= 11000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // PADDLE_ENFORCE_GE( - // dev_ctx_.GetComputeCapability(), - // 80, - // phi::errors::InvalidArgument( - // "cublas bf16 gemm requires GPU compute capability >= 80," - // "but received %d", - // dev_ctx_.GetComputeCapability())); - - float f_alpha = static_cast(alpha); - float f_beta = static_cast(beta); - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; - bool use_tensor_op_math = dev_ctx_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? 
"True" : "False"); - - thrust::device_vector A_ptr(A, A + batchCount); - thrust::device_vector B_ptr(B, B + batchCount); - thrust::device_vector C_ptr(C, C + batchCount); - dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmBatchedEx(handle, - cuTransB, - cuTransA, - N, - M, - K, - &f_alpha, - B_ptr.data().get(), - CUDA_R_16BF, - ldb, - A_ptr.data().get(), - CUDA_R_16BF, - lda, - &f_beta, - C_ptr.data().get(), - CUDA_R_16BF, - ldc, - batchCount, - CUBLAS_COMPUTE_32F, - algo)); - }); -#else - // raise error - PADDLE_THROW(phi::errors::Unimplemented( - "cublasGemmBatchedEx with bfloat16 is not supported on cuda <= 11")); - -#endif // CUDA_VERSION >= 11000 -} -#endif - -template <> -template -void Blas::TRSM(CBLAS_SIDE side, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, - int M, - int N, - T alpha, - const T *A, - int lda, - T *B, - int ldb) const { - // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` - // where ' stands for transpose - cublasSideMode_t cuSide = - (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT; - cublasFillMode_t cuUplo = - (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; - // use CUBLAS_OP_C (conjugate transpose) for complex - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasDiagType_t cuDiag = - (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::TRSM( - handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A, lda, B, ldb); - }); -} - -template <> -template -void Blas::BatchedGETRF( - int n, T **a, int *ipiv, int *info, int batch_size) const { - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedGETRI(int n, - const T **a, - const int *ipiv, - T **a_inv, - int *info, - int batch_size) const { - PADDLE_ENFORCE_NE( - a_inv, - a, - phi::errors::InvalidArgument( - "cuBLAS fuction 'cublasgetrfBatched' cannot be executed " - "in-place. The memory space of output matrix (address: %p) cannot " - "overlap memory space of input matrix (address: %p).", - a_inv, - a)); - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedMatInv( - int n, const T **a, T **a_inv, int *info, int batch_size) const { - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedGETRS(CBLAS_TRANSPOSE trans, - int n, - int nrhs, - const T **a, - int lda, - int *ipiv, - T **b, - int ldb, - int *info, - int batch_size) const { - // use CUBLAS_OP_C (conjugate transpose) for complex - cublasOperation_t cuTrans = - (trans == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GETRS_BATCH( - handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedTRSM(CBLAS_SIDE side, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, - int M, - int N, - T alpha, - const T **A, - int lda, - T **B, - int ldb, - int batch_size) const { - // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` - // where ' stands for transpose - cublasSideMode_t cuSide = - (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT; - cublasFillMode_t cuUplo = - (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; - // use CUBLAS_OP_C (conjugate transpose) for complex - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasDiagType_t cuDiag = - (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; - - dev_ctx_.CublasCall([&](cublasHandle_t handle) { - CUBlas::TRSM_BATCH(handle, - cuSide, - cuUplo, - cuTransA, - cuDiag, - N, - M, - &alpha, - A, - lda, - B, - ldb, - batch_size); - }); -} - -} // namespace funcs -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/blas_impl.h b/backends/metax_gpu/kernels/funcs/blas/blas_impl.h deleted file mode 100644 index cb59d73bef8..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/blas_impl.h +++ /dev/null @@ -1,2003 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include -#include -#include -#include - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/common/bfloat16.h" -#include "paddle/phi/common/complex.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -#define INT_MAX_VALUE 2147483647 - -namespace phi { -namespace funcs { - -namespace detail { -template -static void axpy( - int n, const T alpha, const T *x, const int incx, T *y, const int incy) { - // Y = Y + alpha * X - while (n-- > 0) { - *y += alpha * *x; - y = y + incy; - x = x + incx; - } -} -} // namespace detail - -template -struct CBlas; - -template <> -struct CBlas { - template - static void VCOPY(ARGS... args) { - PADDLE_THROW(phi::errors::Unimplemented( - "Blas VCOPY do not supported on CPU, please check your code")); - } -}; - -template <> -struct CBlas { - template - static void VCOPY(ARGS... args) { - PADDLE_THROW(phi::errors::Unimplemented( - "Blas VCOPY do not supported on CPU, please check your code")); - } -}; - -template <> -struct CBlas { - template - static void AXPY(ARGS... args) { - detail::axpy(args...); - } - - template - static void VCOPY(ARGS... 
args UNUSED) { - PADDLE_THROW(phi::errors::Unimplemented( - "Blas VCOPY do not supported on CPU with bfloat16," - " please check your code")); - } - - template - static void VADD(int n, - const phi::dtype::bfloat16 *x, - const phi::dtype::bfloat16 *y, - phi::dtype::bfloat16 *z) { - for (int i = 0; i < n; ++i) { - z[i] = x[i] + y[i]; - } - } - - template - static void VMUL(int n, - const phi::dtype::bfloat16 *x, - const phi::dtype::bfloat16 *y, - phi::dtype::bfloat16 *z) { - for (int i = 0; i < n; ++i) { - z[i] = x[i] * y[i]; - } - } - - template - static void VSUB(int n, - const phi::dtype::bfloat16 *x, - const phi::dtype::bfloat16 *y, - phi::dtype::bfloat16 *z) { - for (int i = 0; i < n; ++i) { - z[i] = x[i] - y[i]; - } - } -}; - -#ifdef PADDLE_WITH_MKLML -template <> -struct CBlas { - template - static void GEMM(ARGS... args) { - phi::dynload::cblas_sgemm(args...); - } - - template - static float *GEMM_ALLOC(ARGS... args) { - return phi::dynload::cblas_sgemm_alloc(args...); - } - - template - static void GEMM_PACK(ARGS... args) { - phi::dynload::cblas_sgemm_pack(args...); - } - - template - static void GEMM_COMPUTE(ARGS... args) { - phi::dynload::cblas_sgemm_compute(args...); - } - - template - static void GEMM_FREE(ARGS... args) { - phi::dynload::cblas_sgemm_free(args...); - } - -#ifdef PADDLE_WITH_LIBXSMM - template - static void SMM_GEMM(ARGS... args) { - libxsmm_sgemm(args...); - } -#endif - - template - static void AXPY(ARGS... args) { - phi::dynload::cblas_saxpy(args...); - } - - template - static void VCOPY(ARGS... args) { - phi::dynload::cblas_scopy(args...); - } - - template - static void GEMV(ARGS... args) { - phi::dynload::cblas_sgemv(args...); - } - - template - static float DOT(ARGS... args) { - return phi::dynload::cblas_sdot(args...); - } - - template - static void SCAL(ARGS... args) { - phi::dynload::cblas_sscal(args...); - } - - template - static float ASUM(ARGS... args) { - return phi::dynload::cblas_sasum(args...); - } - - template - static void GEMM_BATCH(ARGS... args) { - phi::dynload::cblas_sgemm_batch(args...); - } - - template - static void VADD(ARGS... args) { - phi::dynload::vsAdd(args...); - } - - template - static void VSUB(ARGS... args) { - phi::dynload::vsSub(args...); - } - - template - static void VMUL(ARGS... args) { - phi::dynload::vsMul(args...); - } - - template - static void VDIV(ARGS... args) { - phi::dynload::vsDiv(args...); - } - - template - static void VEXP(ARGS... args) { - phi::dynload::vsExp(args...); - } - - template - static void VSQUARE(ARGS... args) { - phi::dynload::vsSqr(args...); - } - - template - static void VPOW(ARGS... args) { - phi::dynload::vsPowx(args...); - } - - template - static void VINV(ARGS... args) { - phi::dynload::vsInv(args...); - } - - template - static void VMERF(ARGS... args) { - phi::dynload::vmsErf(args...); - } -#if !defined(_WIN32) - template - static void CSRMM(ARGS... args) { - phi::dynload::mkl_scsrmm(args...); - } -#endif - - template - static void TRSM(ARGS... args) { - phi::dynload::cblas_strsm(args...); - } -}; - -template <> -struct CBlas { - template - static void GEMM(ARGS... args) { - phi::dynload::cblas_dgemm(args...); - } - - template - static double *GEMM_ALLOC(ARGS... args) { - return phi::dynload::cblas_dgemm_alloc(args...); - } - - template - static void GEMM_PACK(ARGS... args) { - phi::dynload::cblas_dgemm_pack(args...); - } - - template - static void GEMM_COMPUTE(ARGS... args) { - phi::dynload::cblas_dgemm_compute(args...); - } - - template - static void GEMM_FREE(ARGS... 
args) { - phi::dynload::cblas_dgemm_free(args...); - } - -#ifdef PADDLE_WITH_LIBXSMM - template - static void SMM_GEMM(ARGS... args) { - libxsmm_dgemm(args...); - } -#endif - - template - static void AXPY(ARGS... args) { - phi::dynload::cblas_daxpy(args...); - } - - template - static void VCOPY(ARGS... args) { - phi::dynload::cblas_dcopy(args...); - } - - template - static void GEMV(ARGS... args) { - phi::dynload::cblas_dgemv(args...); - } - - template - static double DOT(ARGS... args) { - return phi::dynload::cblas_ddot(args...); - } - - template - static void SCAL(ARGS... args) { - phi::dynload::cblas_dscal(args...); - } - - template - static double ASUM(ARGS... args) { - return phi::dynload::cblas_dasum(args...); - } - - template - static void GEMM_BATCH(ARGS... args) { - phi::dynload::cblas_dgemm_batch(args...); - } - - template - static void VADD(ARGS... args) { - phi::dynload::vdAdd(args...); - } - - template - static void VSUB(ARGS... args) { - phi::dynload::vdSub(args...); - } - - template - static void VMUL(ARGS... args) { - phi::dynload::vdMul(args...); - } - - template - static void VDIV(ARGS... args) { - phi::dynload::vdDiv(args...); - } - - template - static void VEXP(ARGS... args) { - phi::dynload::vdExp(args...); - } - - template - static void VSQUARE(ARGS... args) { - phi::dynload::vdSqr(args...); - } - - template - static void VPOW(ARGS... args) { - phi::dynload::vdPowx(args...); - } - - template - static void VINV(ARGS... args) { - phi::dynload::vdInv(args...); - } - - template - static void VMERF(ARGS... args) { - phi::dynload::vmdErf(args...); - } -#if !defined(_WIN32) - template - static void CSRMM(ARGS... args) { - phi::dynload::mkl_dcsrmm(args...); - } -#endif - - template - static void TRSM(ARGS... args) { - phi::dynload::cblas_dtrsm(args...); - } -}; - -template <> -struct CBlas> { - template - static void AXPY(int n, - const phi::dtype::complex alpha, - const phi::dtype::complex *X, - const int incX, - phi::dtype::complex *Y, - const int incY) { - phi::dynload::cblas_caxpy(n, &alpha, X, incX, Y, incY); - } - - template - static void VCOPY(ARGS... args) { - phi::dynload::cblas_ccopy(args...); - } - - // the libmklml_intel.so paddle used has no vcAdd, vcSub, - // vcMul, vcDiv apis before rebuild from source - // so replace with the raw operator methods - /* - template - static void VADD(ARGS... args) { - phi::dynload::vcAdd(args...); - } - - template - static void VSUB(ARGS... args) { - phi::dynload::vcSub(args...); - } - - template - static void VMUL(ARGS... args) { - phi::dynload::vcMul(args...); - } - - template - static void VDIV(ARGS... 
args) { - phi::dynload::vcDiv(args...); - } - */ - - template - static void VADD(int n, - const phi::dtype::complex *a, - const phi::dtype::complex *b, - phi::dtype::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] + b[i]; - } - } - - template - static void VSUB(int n, - const phi::dtype::complex *a, - const phi::dtype::complex *b, - phi::dtype::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] - b[i]; - } - } - - template - static void VMUL(int n, - const phi::dtype::complex *a, - const phi::dtype::complex *b, - phi::dtype::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] * b[i]; - } - } - template - static void VDIV(int n, - const phi::dtype::complex *a, - const phi::dtype::complex *b, - phi::dtype::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] / b[i]; - } - } - - template - static void GEMV(CBLAS_LAYOUT layout, - CBLAS_TRANSPOSE trans, - int M, - int N, - phi::dtype::complex alpha, - const phi::dtype::complex *A, - int lda, - const phi::dtype::complex *X, - int incx, - phi::dtype::complex beta, - phi::dtype::complex *Y, - int incy) { - const void *a_ = (const void *)(A); - const void *x_ = (const void *)(X); - void *y_ = static_cast(Y); - phi::dynload::cblas_cgemv( - layout, trans, M, N, &alpha, a_, lda, x_, incx, &beta, y_, incy); - } - - template - static void GEMM(CBLAS_LAYOUT layout, - CBLAS_TRANSPOSE trans_a, - CBLAS_TRANSPOSE trans_b, - int M, - int N, - int K, - phi::dtype::complex alpha, - const phi::dtype::complex *A, - int lda, - const phi::dtype::complex *B, - int ldb, - phi::dtype::complex beta, - phi::dtype::complex *C, - int ldc) { - const void *a_ = (const void *)(A); - const void *b_ = (const void *)(B); - void *c_ = static_cast(C); - phi::dynload::cblas_cgemm(layout, - trans_a, - trans_b, - M, - N, - K, - &alpha, - a_, - lda, - b_, - ldb, - &beta, - c_, - ldc); - } - - static void TRSM(CBLAS_LAYOUT layout, - CBLAS_SIDE side, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans_a, - CBLAS_DIAG diag, - int M, - int N, - phi::dtype::complex alpha, - const phi::dtype::complex *A, - int lda, - phi::dtype::complex *B, - int ldb) { - const void *a_ = (const void *)(A); - void *b_ = static_cast(B); - phi::dynload::cblas_ctrsm( - layout, side, uplo, trans_a, diag, M, N, &alpha, a_, lda, b_, ldb); - } - - template - static void GEMM_BATCH(CBLAS_LAYOUT layout, - CBLAS_TRANSPOSE *trans_a, - CBLAS_TRANSPOSE *trans_b, - int *M, - int *N, - int *K, - phi::dtype::complex *alpha, - const phi::dtype::complex **A, - const int *lda, - const phi::dtype::complex **B, - const int *ldb, - phi::dtype::complex *beta, - phi::dtype::complex **C, - const int *ldc, - int group_count, - int *group_size) { - const void **A_void = (const void **)(&(*A)); - const void **B_void = (const void **)(&(*B)); - void **C_void = reinterpret_cast(C); - - phi::dynload::cblas_cgemm_batch(layout, - trans_a, - trans_b, - M, - N, - K, - alpha, - A_void, - lda, - B_void, - ldb, - beta, - C_void, - ldc, - group_count, - group_size); - } - - template - static void GEMM_EX(ARGS... args) { - phi::dynload::cblas_cgemm_batch(args...); - } -}; - -template <> -struct CBlas> { - template - static void AXPY(int n, - const phi::dtype::complex alpha, - const phi::dtype::complex *X, - const int incX, - phi::dtype::complex *Y, - const int incY) { - phi::dynload::cblas_zaxpy(n, &alpha, X, incX, Y, incY); - } - - template - static void VCOPY(ARGS... 
args) { - phi::dynload::cblas_zcopy(args...); - } - - // the libmklml_intel.so paddle used has no vzAdd, vzSub, - // vzMul, vzDiv apis before rebuild from source - // so replace with the raw operator methods - /* - template - static void VADD(ARGS... args) { - phi::dynload::vzAdd(args...); - } - - template - static void VSUB(ARGS... args) { - phi::dynload::vzSub(args...); - } - - template - static void VMUL(ARGS... args) { - phi::dynload::vzMul(args...); - } - - template - static void VDIV(ARGS... args) { - phi::dynload::vzDiv(args...); - } - */ - - template - static void VADD(int n, - const phi::dtype::complex *a, - const phi::dtype::complex *b, - phi::dtype::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] + b[i]; - } - } - - template - static void VSUB(int n, - const phi::dtype::complex *a, - const phi::dtype::complex *b, - phi::dtype::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] - b[i]; - } - } - - template - static void VMUL(int n, - const phi::dtype::complex *a, - const phi::dtype::complex *b, - phi::dtype::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] * b[i]; - } - } - template - static void VDIV(int n, - const phi::dtype::complex *a, - const phi::dtype::complex *b, - phi::dtype::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] / b[i]; - } - } - - template - static void GEMV(CBLAS_LAYOUT layout, - CBLAS_TRANSPOSE trans, - int M, - int N, - phi::dtype::complex alpha, - const phi::dtype::complex *A, - int lda, - const phi::dtype::complex *X, - int incx, - phi::dtype::complex beta, - phi::dtype::complex *Y, - int incy) { - const void *a_ = (const void *)(A); - const void *x_ = (const void *)(X); - void *y_ = static_cast(Y); - phi::dynload::cblas_zgemv( - layout, trans, M, N, &alpha, a_, lda, x_, incx, &beta, y_, incy); - } - - template - static void GEMM(CBLAS_LAYOUT layout, - CBLAS_TRANSPOSE trans_a, - CBLAS_TRANSPOSE trans_b, - int M, - int N, - int K, - phi::dtype::complex alpha, - const phi::dtype::complex *A, - int lda, - const phi::dtype::complex *B, - int ldb, - phi::dtype::complex beta, - phi::dtype::complex *C, - int ldc) { - const void *a_ = (const void *)(A); - const void *b_ = (const void *)(B); - void *c_ = static_cast(C); - phi::dynload::cblas_zgemm(layout, - trans_a, - trans_b, - M, - N, - K, - &alpha, - a_, - lda, - b_, - ldb, - &beta, - c_, - ldc); - } - - static void TRSM(CBLAS_LAYOUT layout, - CBLAS_SIDE side, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans_a, - CBLAS_DIAG diag, - int M, - int N, - phi::dtype::complex alpha, - const phi::dtype::complex *A, - int lda, - phi::dtype::complex *B, - int ldb) { - const void *a_ = (const void *)(A); - void *b_ = static_cast(B); - phi::dynload::cblas_ztrsm( - layout, side, uplo, trans_a, diag, M, N, &alpha, a_, lda, b_, ldb); - } - - template - static void GEMM_BATCH(CBLAS_LAYOUT layout, - CBLAS_TRANSPOSE *trans_a, - CBLAS_TRANSPOSE *trans_b, - int *M, - int *N, - int *K, - phi::dtype::complex *alpha, - const phi::dtype::complex **A, - const int *lda, - const phi::dtype::complex **B, - const int *ldb, - phi::dtype::complex *beta, - phi::dtype::complex **C, - const int *ldc, - int group_count, - int *group_size) { - const void **A_void = (const void **)(&(*A)); - const void **B_void = (const void **)(&(*B)); - void **C_void = reinterpret_cast(C); - - phi::dynload::cblas_zgemm_batch(layout, - trans_a, - trans_b, - M, - N, - K, - alpha, - A_void, - lda, - B_void, - ldb, - beta, - C_void, - ldc, - group_count, - group_size); - } - - template - static void GEMM_EX(ARGS... 
args) { - phi::dynload::cblas_zgemm_batch(args...); - } -}; - -#else - -template <> -struct CBlas { - template - static void GEMM(ARGS... args) { - cblas_sgemm(args...); - } - - template - static void AXPY(ARGS... args) { - cblas_saxpy(args...); - } - - template - static void VCOPY(ARGS... args) { - cblas_scopy(args...); - } - - template - static void GEMV(ARGS... args) { - cblas_sgemv(args...); - } - - template - static void TRSM(ARGS... args) { - cblas_strsm(args...); - } -}; - -template <> -struct CBlas { - template - static void GEMM(ARGS... args) { - cblas_dgemm(args...); - } - - template - static void AXPY(ARGS... args) { - cblas_daxpy(args...); - } - - template - static void VCOPY(ARGS... args) { - cblas_dcopy(args...); - } - - template - static void GEMV(ARGS... args) { - cblas_dgemv(args...); - } - - template - static void TRSM(ARGS... args) { - cblas_dtrsm(args...); - } -}; - -template <> -struct CBlas> { - template - static void VCOPY(ARGS... args) { - cblas_ccopy(args...); - } - - template - static void AXPY(int n, - const phi::dtype::complex alpha, - const phi::dtype::complex *X, - const int incX, - phi::dtype::complex *Y, - const int incY) { - cblas_caxpy(n, &alpha, X, incX, Y, incY); - } - - template - static void GEMV(const CBLAS_LAYOUT layout, - const CBLAS_TRANSPOSE TransA, - const int M, - const int N, - const phi::dtype::complex alpha, - const phi::dtype::complex *A, - const int lda, - const phi::dtype::complex *X, - const int incX, - const phi::dtype::complex beta, - phi::dtype::complex *Y, - const int incY) { - cblas_cgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY); - } - - template - static void GEMM(const CBLAS_LAYOUT layout, - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const phi::dtype::complex alpha, - const phi::dtype::complex *A, - const int lda, - const phi::dtype::complex *B, - const int ldb, - const phi::dtype::complex beta, - phi::dtype::complex *C, - const int ldc) { - cblas_cgemm( - layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc); - } - - static void TRSM(const CBLAS_LAYOUT layout, - const CBLAS_SIDE side, - const CBLAS_UPLO uplo, - const CBLAS_TRANSPOSE transA, - const CBLAS_DIAG diag, - const int M, - const int N, - const phi::dtype::complex alpha, - const phi::dtype::complex *A, - const int lda, - phi::dtype::complex *B, - const int ldb) { - cblas_ctrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb); - } -}; - -template <> -struct CBlas> { - template - static void VCOPY(ARGS... 
args) { - cblas_zcopy(args...); - } - - template - static void AXPY(int n, - const phi::dtype::complex alpha, - const phi::dtype::complex *X, - const int incX, - phi::dtype::complex *Y, - const int incY) { - cblas_zaxpy(n, &alpha, X, incX, Y, incY); - } - - template - static void GEMV(const CBLAS_LAYOUT layout, - const CBLAS_TRANSPOSE TransA, - const int M, - const int N, - const phi::dtype::complex alpha, - const phi::dtype::complex *A, - const int lda, - const phi::dtype::complex *X, - const int incX, - const phi::dtype::complex beta, - phi::dtype::complex *Y, - const int incY) { - cblas_zgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY); - } - - template - static void GEMM(const CBLAS_LAYOUT layout, - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const phi::dtype::complex alpha, - const phi::dtype::complex *A, - const int lda, - const phi::dtype::complex *B, - const int ldb, - const phi::dtype::complex beta, - phi::dtype::complex *C, - const int ldc) { - cblas_zgemm( - layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc); - } - - static void TRSM(const CBLAS_LAYOUT layout, - const CBLAS_SIDE side, - const CBLAS_UPLO uplo, - const CBLAS_TRANSPOSE transA, - const CBLAS_DIAG diag, - const int M, - const int N, - const phi::dtype::complex alpha, - const phi::dtype::complex *A, - const int lda, - phi::dtype::complex *B, - const int ldb) { - cblas_ztrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb); - } -}; - -#endif - -template <> -struct CBlas { - static void GEMM(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 GEMM not supported on CPU, please check your code")); - } - - static void SMM_GEMM(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 SMM_GEMM not supported on CPU, please check your code")); - } - static void VMUL(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 VMUL not supported on CPU, please check your code")); - } - static void VEXP(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 VEXP not supported on CPU, please check your code")); - } - static void VSQUARE(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 VSQUARE not supported on CPU, please check your code")); - } - static void VPOW(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 VPOW not supported on CPU, please check your code")); - } - static void DOT(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 DOT not supported on CPU, please check your code")); - }; - static void SCAL(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 SCAL not supported on CPU, please check your code")); - }; - static void ASUM(...) { - PADDLE_THROW(phi::errors::Unimplemented( - "float16 ASUM not supported on CPU, please check your code")); - }; -#ifdef PADDLE_WITH_MKLML - static void GEMM_BATCH(...) 
{ - PADDLE_THROW(phi::errors::Unimplemented( - "float16 GEMM_BATCH not supported on CPU, please check your code")); - } -#endif -}; - -#ifdef PADDLE_WITH_MKLML -template <> -template -T *Blas::GEMM_ALLOC(const CBLAS_IDENTIFIER id, - const int M, - const int N, - const int K) const { - return CBlas::GEMM_ALLOC(id, M, N, K); -} - -template <> -template -void Blas::GEMM_PACK(const CBLAS_IDENTIFIER id, - const CBLAS_TRANSPOSE trans, - int M, - int N, - int K, - const T alpha, - const T *src, - const int ld, - T *dst) const { - CBlas::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst); -} - -template <> -template -void Blas::GEMM_COMPUTE(int transA, - int transB, - int M, - int N, - int K, - const T *A, - const int lda, - const T *B, - const int ldb, - T beta, - T *C, - const int ldc) const { - CBlas::GEMM_COMPUTE( - CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb, beta, C, ldc); -} - -template <> -template -void Blas::GEMM_FREE(T *data) const { - CBlas::GEMM_FREE(data); -} -#endif - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - T alpha, - const T *A, - const T *B, - T beta, - T *C) const { - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { - PADDLE_THROW( - common::errors::Unimplemented("GEMM not supported for large tensor " - "size on CPU, please check your code!")); - } - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - CBlas::GEMM(CblasRowMajor, - transA, - transB, - M, - N, - K, - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc); -} - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - U alpha, - const T *A, - const T *B, - U beta, - T *C) const { - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { - PADDLE_THROW( - common::errors::Unimplemented("GEMM not supported for large tensor " - "size on CPU, please check your code!")); - } - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - CBlas::GEMM(CblasRowMajor, - transA, - transB, - static_cast(M), - static_cast(N), - static_cast(K), - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc); -} - -template <> -template -void Blas::GEMM(bool transA, - bool transB, - int M, - int N, - int K, - T alpha, - const T *A, - int lda, - const T *B, - int ldb, - T beta, - T *C, - int ldc) const { - CBlas::GEMM(CblasRowMajor, - transA == false ? CblasNoTrans : CblasTrans, - transB == false ? 
CblasNoTrans : CblasTrans, - M, - N, - K, - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc); -} - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - T alpha, - const T *A, - int lda, - const T *B, - int ldb, - T beta, - T *C, - int ldc) const { - CBlas::GEMM(CblasRowMajor, - transA, - transB, - M, - N, - K, - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc); -} - -template -template -void Blas::MatMul(const phi::DenseTensor &mat_a, - bool trans_a, - const phi::DenseTensor &mat_b, - bool trans_b, - T alpha, - phi::DenseTensor *mat_out, - T beta) const { - const auto &dim_a = mat_a.dims(); - const auto &dim_b = mat_b.dims(); - const auto &dim_out = mat_out->dims(); - PADDLE_ENFORCE_EQ( - dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, - true, - phi::errors::InvalidArgument( - "The input and output of matmul should be matrix, the dim size must " - "be 2," - "but received dim size input_a:%d, input_b:%d, output:%d", - dim_a.size(), - dim_b.size(), - dim_out.size())); - PADDLE_ENFORCE_EQ( - mat_a.place() == mat_b.place() && mat_a.place() == mat_out->place(), - true, - phi::errors::InvalidArgument("The places of matrices in the matmul " - "should be same, please check your " - "code.")); - - int M = dim_out[0]; - int N = dim_out[1]; - int K = !trans_a ? dim_a[1] : dim_a[0]; - - CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = !trans_b ? CblasNoTrans : CblasTrans; - - this->GEMM(transA, - transB, - M, - N, - K, - alpha, - mat_a.data(), - mat_b.data(), - beta, - mat_out->data()); -} - -template <> -template -void Blas::AXPY(int n, T alpha, const T *x, T *y) const { - CBlas::AXPY(n, alpha, x, 1, y, 1); -} - -template <> -template -void Blas::VCOPY(int n, const T *x, T *y) const { - CBlas::VCOPY(n, x, 1, y, 1); -} - -template <> -template -void Blas::VADD(int n, const T *x, const T *y, T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VADD(n, x, y, z); -#else - if (x == z) { - this->template AXPY(n, (T)(1.), y, z); - } else { - this->template VCOPY(n, y, z); - this->template AXPY(n, (T)(1.), x, z); - } -#endif -} - -template <> -template -void Blas::VSUB(int n, const T *x, const T *y, T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VSUB(n, x, y, z); -#else - // try to find if openblas support vsub - for (int i = 0; i < n; ++i) { - z[i] = x[i] - y[i]; - } -#endif -} - -template <> -template -void Blas::VMUL(int n, const T *x, const T *y, T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VMUL(n, x, y, z); -#else - // try to find if openblas support vmul - for (int i = 0; i < n; ++i) { - z[i] = x[i] * y[i]; - } -#endif -} - -template <> -template -void Blas::VDIV(int n, const T *x, const T *y, T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VDIV(n, x, y, z); -#else - // try to find if openblas support vdiv - for (int i = 0; i < n; ++i) { - z[i] = x[i] / y[i]; - } -#endif -} - -template <> -template -void Blas::VEXP(int n, const T *x, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VEXP(n, x, y); -#else - // try to find if openblas support vexp - for (int i = 0; i < n; ++i) { - y[i] = std::exp(x[i]); - } -#endif -} - -template <> -template -void Blas::VSQUARE(int n, const T *x, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VSQUARE(n, x, y); -#else - for (int i = 0; i < n; ++i) { - y[i] = x[i] * x[i]; - } -#endif -} - -template <> -template -void Blas::VPOW(int n, const T *x, T a, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VPOW(n, x, a, y); -#else - for (int i = 0; i 
< n; ++i) { - y[i] = std::pow(x[i], a); - } -#endif -} - -template <> -template -T Blas::DOT(int n, const T *x, const T *y) const { -#ifdef PADDLE_WITH_MKLML - return CBlas::DOT(n, x, 1, y, 1); -#else - // try to find if openblas support cblas_dot - T sum = 0; - for (int i = 0; i < n; ++i) { - sum += x[i] * y[i]; - } - return sum; -#endif -} - -template <> -template -void Blas::SCAL(int n, const T a, T *x) const { -#ifdef PADDLE_WITH_MKLML - CBlas::SCAL(n, a, x, 1); -#else - // try to find if openblas support cblas_scal - for (int i = 0; i < n; ++i) { - x[i] = a * x[i]; - } -#endif -} - -template <> -template -T Blas::ASUM(int n, T *x, int inc) const { - auto sum = static_cast(0.0); -#ifdef PADDLE_WITH_MKLML - sum = CBlas::ASUM(n, x, inc); -#else - // TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum - for (int c = 0; c < n; ++c) { - sum += x[c]; - } -#endif - return sum; -} - -template <> -template -void Blas::GEMV(bool trans_a, - int M, - int N, - T alpha, - const T *A, - const T *B, - T beta, - T *C) const { - CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; - CBlas::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); -} - -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t M, - int64_t N, - int64_t K, - T alpha, - const T *A, - const T *B, - T beta, - T *C, - int64_t batchCount, - int64_t strideA, - int64_t strideB) const { - PADDLE_ENFORCE_NOT_NULL( - A, phi::errors::InvalidArgument("Pointer A should not be null.")); - PADDLE_ENFORCE_NOT_NULL( - B, phi::errors::InvalidArgument("Pointer B should not be null.")); - PADDLE_ENFORCE_NOT_NULL( - C, phi::errors::InvalidArgument("Pointer C should not be null.")); - - if (M > INT_MAX_VALUE || N > INT_MAX_VALUE || K > INT_MAX_VALUE) { - PADDLE_THROW( - common::errors::Unimplemented("CPU GEMM not supported for large tensor " - "size.")); - } - -#ifdef PADDLE_WITH_MKLML - if (batchCount > INT_MAX_VALUE) { - PADDLE_THROW(common::errors::Unimplemented( - "CPU GEMM not supported for large batch size in MKLML.")); - } - - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - auto a_array = std::vector(batchCount); - auto b_array = std::vector(batchCount); - auto c_array = std::vector(batchCount); - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA]; - b_array[k] = &B[k * strideB]; - c_array[k] = &C[k * M * N]; - } - - CBlas::GEMM_BATCH(CblasRowMajor, - &transA, - &transB, - reinterpret_cast(&M), - reinterpret_cast(&N), - reinterpret_cast(&K), - &alpha, - a_array.data(), - &lda, - b_array.data(), - &ldb, - &beta, - c_array.data(), - &ldc, - 1 /* group_count */, - reinterpret_cast(&batchCount)); -#else - for (int k = 0; k < batchCount; ++k) { - auto *Ak = &A[k * strideA]; - auto *Bk = &B[k * strideB]; - auto *Ck = &C[k * M * N]; - this->template GEMM(transA, - transB, - reinterpret_cast(M), - reinterpret_cast(N), - reinterpret_cast(K), - alpha, - Ak, - Bk, - beta, - Ck); - } -#endif -} - -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - T alpha, - const T **A, - const T **B, - T beta, - T **C, - int batchCount) const { -#ifdef PADDLE_WITH_MKLML - const int lda = (std::max)((transA == CblasNoTrans) ? K : M, 1); - const int ldb = (std::max)((transB == CblasNoTrans) ? 
N : K, 1); - const int ldc = (std::max)(N, 1); - CBlas::GEMM_BATCH(CblasRowMajor, - &transA, - &transB, - &M, - &N, - &K, - &alpha, - A, - &lda, - B, - &ldb, - &beta, - C, - &ldc, - 1 /* group_count */, - &batchCount); -#else - for (int k = 0; k < batchCount; ++k) { - this->template GEMM( - transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); - } -#endif -} - -#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ - !defined(PADDLE_WITH_HIP) // @{ Group Blas MKLML: BatchedGEMMWithHead -template <> -template -void Blas::BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int W1, - int H1, - int W2, - int H2, - T alpha, - const T *A, - const T *B, - T beta, - T *C, - int batchCount, - int64_t strideA, - int64_t strideB, - int64_t head_number, - bool split_b_vertical) const { - int lda = (transA == CblasNoTrans) ? W1 : H1; - int ldb = (transB == CblasNoTrans) ? W2 : H2; - auto a_array = std::vector(batchCount); - auto b_array = std::vector(batchCount); - auto c_array = std::vector(batchCount); - - if (split_b_vertical) { - int ldc = W2; - int sub_width = W2 / head_number; - - for (int i = 0; i < head_number; i++) { - int sub_matA_offset = (transA == CblasNoTrans) - ? i * (W1 / head_number) - : i * (W1 / head_number) * H1; - int sub_matB_offset = (transB == CblasNoTrans) - ? i * (W2 / head_number) - : i * (W2 / head_number) * H2; - int sub_matC_offset = i * W2 / head_number; - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA] + sub_matA_offset; - b_array[k] = &B[k * strideB] + sub_matB_offset; - c_array[k] = &C[k * H1 * W2] + sub_matC_offset; - } - - CBlas::GEMM_BATCH(CblasRowMajor, - &transA, - &transB, - &H1, - &sub_width, - &H2, - &alpha, - a_array.data(), - &lda, - b_array.data(), - &ldb, - &beta, - c_array.data(), - &ldc, - 1 /* group_count */, - &batchCount); - } - - } else { - PADDLE_ENFORCE_EQ( - W1, - H2, - phi::errors::InvalidArgument( - "The fisrt matrix width should be same as second matrix height," - "but received fisrt matrix width %d" - ", second matrix height %d", - W1, - H2)); - int ldc = W2 * head_number; - int sub_width = W1 / head_number; - - for (int i = 0; i < head_number; i++) { - int sub_matA_offset = (transA == CblasNoTrans) - ? i * (W1 / head_number) - : i * (W1 / head_number) * H1; - int sub_matB_offset = (transB == CblasNoTrans) - ? 
i * (W1 / head_number) * W2 - : i * (W1 / head_number); - int sub_matC_offset = i * W2; - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA] + sub_matA_offset; - b_array[k] = &B[k * strideB] + sub_matB_offset; - c_array[k] = &C[k * H1 * head_number * W2] + sub_matC_offset; - } - - CBlas::GEMM_BATCH(CblasRowMajor, - &transA, - &transB, - &H1, - &W2, - &sub_width, - &alpha, - a_array.data(), - &lda, - b_array.data(), - &ldb, - &beta, - c_array.data(), - &ldc, - 1 /* group_count */, - &batchCount); - } - } -} -#endif // @} End Group Blas MKLML: BatchedGEMMWithHead - -template -template -void Blas::MatMul( - const int M, const int N, const int K, const T *A, const T *B, T *C) const { - this->template GEMM(CblasRowMajor, - CblasNoTrans, - CblasNoTrans, - M, - N, - K, - static_cast(1), - A, - K, - B, - N, - static_cast(0), - C, - N); -} - -template <> -template -void Blas::MatMul( - const int M, const int N, const int K, const T *A, const T *B, T *C) const { -#ifdef PADDLE_WITH_LIBXSMM - // Refer to https://github.com/hfp/libxsmm/blob/master/README.md - // But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; - - // Since the matrix is very small, - // so the unit of calculation is already very fast, - // and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead, - // use xsmm directly. - // Note: SMM use ColMajor - const char transa = 'N'; - const char transb = 'N'; - const T alpha = static_cast(1); - const T beta = static_cast(0); - CBlas::SMM_GEMM( - &transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, C, &N); - return; -#endif - - CBlas::GEMM(CblasRowMajor, - CblasNoTrans, - CblasNoTrans, - M, - N, - K, - static_cast(1), - A, - K, - B, - N, - static_cast(0), - C, - N); -} - -template -template -void Blas::MatMul(const phi::DenseTensor &mat_a, - const MatDescriptor &dim_a, - const phi::DenseTensor &mat_b, - const MatDescriptor &dim_b, - T alpha, - phi::DenseTensor *mat_out, - T beta) const { - MatMul(mat_a.data(), - dim_a, - mat_b.data(), - dim_b, - alpha, - mat_out->data(), - beta); -} - -template -template -void Blas::MatMul(const T *mat_a, - const MatDescriptor &dim_a, - const T *mat_b, - const MatDescriptor &dim_b, - T alpha, - T *mat_out, - T beta) const { - PADDLE_ENFORCE_EQ( - dim_a.width_, - dim_b.height_, - phi::errors::InvalidArgument( - "The fisrt matrix width should be same as second matrix height," - "but received fisrt matrix width %d" - ", second matrix height %d", - dim_a.width_, - dim_b.height_)); - - CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans; - if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { - this->template GEMM(transA, - transB, - dim_a.height_, - dim_b.width_, - dim_a.width_, - alpha, - mat_a, - mat_b, - beta, - mat_out); - } else { - PADDLE_ENFORCE_EQ( - dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || - dim_b.batch_size_ == 0, - true, - phi::errors::InvalidArgument( - "dim_a.batch_size should be equal to dim_b.batch_size, or " - "one of dim_a.batch_size and dim_b.batch_size should be 0. " - "But got dim_a.batch_size = %d, dim_b.batch_size = %d.", - dim_a.batch_size_, - dim_b.batch_size_)); - this->template BatchedGEMM( - transA, - transB, - dim_a.height_, - dim_b.width_, - dim_a.width_, - alpha, - mat_a, - mat_b, - beta, - mat_out, - dim_a.batch_size_ == 0 ? 
dim_b.batch_size_ : dim_a.batch_size_, - dim_a.stride_, - dim_b.stride_); - } -} - -#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ - !defined(PADDLE_WITH_HIP) -// @{ Group Blas MKLML: MatMulWithHead -/* - * Multiple two matrixes with multiple heads - * - * A new parameter, i.e head_number is added compared to normal MatMul. - * The head_number describes the number of heads a matrix is vertically - * split. - * - * When user calls this API, the multiplication of two big matrixes is split - * into multiplication of several (head_number_) small matrixes. e.g. if Mat A - * is [3, 24] and Mat B is [24, 4], when multiple A and B with head_number as - * 4, Mat A will be split as 4 matrix of [3, 6] and Mat B will be - * (horizontally) split as 4 matrix of [6, 4]. The result of final matrix - * will be 4 matrix of [3, 4], i.e. [3, 16]. - * Another example is A is [3, 8], B is [2, 16], head_number is 4. In this - * case, A will be split as [3, 2], B will be (vertically) split as - * [2, 4]. The final result will be 4 matrix of 4 matrix of [3,4], i.e. [3, 16] - */ -template -template -void Blas::MatMulWithHead(const phi::DenseTensor &mat_a, - const MatDescriptor &dim_a, - const phi::DenseTensor &mat_b, - const MatDescriptor &dim_b, - T alpha, - int head_number, - phi::DenseTensor *mat_out, - T beta, - bool mat_b_split_vertical) const { - PADDLE_ENFORCE_EQ( - dim_a.width_ % head_number, - 0, - phi::errors::InvalidArgument( - "The first input width must be some times the head number" - "but received first input width %d" - ", head_number %d", - dim_a.width_, - head_number)); - PADDLE_ENFORCE_GE( - head_number, - 1, - phi::errors::InvalidArgument("The head number should be greater equal 1," - "but received head number %d", - head_number)); - PADDLE_ENFORCE_LE( - head_number, - dim_a.width_, - phi::errors::InvalidArgument( - "The head number should be less equal first input width," - "but received first input width %d" - ", head_number %d", - dim_a.width_, - head_number)); - CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans; - - if (mat_b_split_vertical) { - PADDLE_ENFORCE_EQ( - dim_b.height_, - dim_a.width_ / head_number, - phi::errors::InvalidArgument( - "The second input height should be equal than first input width," - "but received second input height %d, first input width %d", - dim_b.height_, - dim_a.width_ / head_number)); - PADDLE_ENFORCE_EQ( - dim_a.width_ % head_number, - 0, - phi::errors::InvalidArgument( - "The second input width should be some times the head number" - "but received second input width %d" - ", head_number %d", - dim_b.width_, - head_number)); - } - - if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { - int lda = !dim_a.trans_ ? dim_a.width_ : dim_a.height_; - int ldb = !dim_b.trans_ ? dim_b.width_ : dim_b.height_; - int sub_matA_offset; - int sub_matB_offset; - int sub_matC_offset; - int sub_mat_M = dim_a.height_; - int sub_mat_N; - int sub_mat_K; - int ldc; - - for (int i = 0; i < head_number; i++) { - sub_matA_offset = dim_a.trans_ - ? i * (dim_a.width_ / head_number) * dim_a.height_ - : i * (dim_a.width_ / head_number); - if (mat_b_split_vertical) { - sub_matB_offset = dim_b.trans_ - ? 
i * (dim_b.width_ / head_number) * dim_b.height_ - : i * (dim_b.width_ / head_number); - sub_matC_offset = i * dim_b.width_ / head_number; - - sub_mat_N = dim_b.width_ / head_number; - sub_mat_K = dim_b.height_; - - ldc = dim_b.width_; - } else { - sub_matB_offset = - dim_b.trans_ ? i * (dim_b.height_ / head_number) - : i * (dim_b.height_ / head_number) * dim_b.width_; - sub_matC_offset = i * dim_b.width_; - - sub_mat_N = dim_b.width_; - sub_mat_K = dim_a.width_ / head_number; - - ldc = head_number * dim_b.width_; - } - - this->template GEMM(transA, - transB, - sub_mat_M, - sub_mat_N, - sub_mat_K, - alpha, - mat_a.data() + sub_matA_offset, - lda, - mat_b.data() + sub_matB_offset, - ldb, - beta, - mat_out->data() + sub_matC_offset, - ldc); - } - } else { - PADDLE_ENFORCE_EQ( - (dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || - dim_b.batch_size_ == 0), - true, - phi::errors::InvalidArgument( - "The first input batch size should be equal than second input," - "either two input batch size is 0, but received first input batch " - "size" - " %d, second input batch size %d", - dim_a.batch_size_, - dim_b.batch_size_)); - - this->template BatchedGEMMWithHead( - transA, - transB, - dim_a.width_, - dim_a.height_, - dim_b.width_, - dim_b.height_, - alpha, - mat_a.data(), - mat_b.data(), - beta, - mat_out->data(), - dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, - dim_a.stride_, - dim_b.stride_, - head_number, - mat_b_split_vertical); - } -} -#endif // @} End Group Blas MKLML: MatMulWithHead - -template -template -void Blas::VINV(int n, const T *a, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VINV(n, a, y); -#else - for (int i = 0; i < n; ++i) { - y[i] = 1.0 / a[i]; - } -#endif -} - -template <> -template -void Blas::VMERF(int n, const T *a, T *y, int64_t mode) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VMERF(n, a, y, mode); -#else - for (int i = 0; i < n; ++i) { - y[i] = std::erf(a[i]); - } -#endif -} - -#ifdef PADDLE_WITH_MKLML -template <> -template -void Blas::CSRMM(const char *transa, - const int *m, - const int *n, - const int *k, - const T *alpha, - const char *matdescra, - const T *val, - const int *indx, - const int *pntrb, - const int *pntre, - const T *b, - const int *ldb, - const T *beta, - T *c, - const int *ldc) const { - CBlas::CSRMM(transa, - m, - n, - k, - alpha, - matdescra, - val, - indx, - pntrb, - pntre, - b, - ldb, - beta, - c, - ldc); -} -#endif - -template <> -template -void Blas::TRSM(CBLAS_SIDE side, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, - int M, - int N, - T alpha, - const T *A, - int lda, - T *B, - int ldb) const { - CBlas::TRSM( - CblasRowMajor, side, uplo, transA, diag, M, N, alpha, A, lda, B, ldb); -} - -} // namespace funcs -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/blaslt_gemm_search.h b/backends/metax_gpu/kernels/funcs/blas/blaslt_gemm_search.h deleted file mode 100644 index 6dcc56f8569..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/blaslt_gemm_search.h +++ /dev/null @@ -1,794 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
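The deleted MatMulWithHead comment above describes the head-split scheme only in prose, so a minimal reference may help. This is a sketch of the comment's first example (the non-vertically-split case) written as plain row-major loops; HeadSplitMatMulRef is a hypothetical name, not part of the removed implementation.

#include <cassert>
#include <vector>

// A is [M, K], B is [K, N], row-major. Head h multiplies
// A[:, h*K/heads : (h+1)*K/heads] by B[h*K/heads : (h+1)*K/heads, :] and
// writes its [M, N] block into output columns [h*N, (h+1)*N), so C ends up
// with shape [M, heads * N].
std::vector<float> HeadSplitMatMulRef(const std::vector<float>& A,
                                      const std::vector<float>& B,
                                      int M, int K, int N, int heads) {
  assert(K % heads == 0);
  const int sub_k = K / heads;
  std::vector<float> C(static_cast<size_t>(M) * heads * N, 0.f);
  for (int h = 0; h < heads; ++h) {
    for (int m = 0; m < M; ++m) {
      for (int n = 0; n < N; ++n) {
        float acc = 0.f;
        for (int k = 0; k < sub_k; ++k) {
          acc += A[m * K + h * sub_k + k] * B[(h * sub_k + k) * N + n];
        }
        C[m * (heads * N) + h * N + n] = acc;  // head h owns columns [h*N, (h+1)*N)
      }
    }
  }
  return C;
}

With M = 3, K = 24, N = 4 and heads = 4 this reproduces the [3, 16] result shape from the comment's first example; the deleted code reaches the same result through per-head GEMM / GEMM_BATCH calls with sub-matrix offsets instead of explicit loops.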
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include - -#include -#include -#include -#include - -#include "paddle/common/flags.h" -#include "paddle/phi/api/include/context_pool.h" -#include "paddle/phi/backends/dynload/cublasLt.h" -#include "paddle/phi/backends/gpu/gpu_info.h" -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/common/place.h" -#include "paddle/phi/core/allocator.h" -#include "paddle/phi/core/dense_tensor.h" - -COMMON_DECLARE_string(cublaslt_device_best_config); - -namespace phi { -namespace funcs { -namespace cublaslt_internal { - -const std::array split_k_candidates = {2, 3, 4, 5, 6, 8, 12, 16, 32}; - -struct CublasLtAlgoConfig { - int m; - int n; - int k; - int algo_id; - int swizzle; - int custom_option; - int tile; - int split_k_val; - int reduction_scheme; - int stages; -}; - -struct CublasLtAlgoSelectorParam { - float time{0.0}; - cublasLtMatmulAlgo_t algo; - CublasLtAlgoConfig algo_config; -}; - -inline bool compare_algo_time(const CublasLtAlgoSelectorParam& param_a, - const CublasLtAlgoSelectorParam& param_b) { - return (param_a.time < param_b.time); -} - -class CublasLtAlgoCache { - public: - static CublasLtAlgoCache& Instance() { - static CublasLtAlgoCache instance(100 /*search_times*/); - return instance; - } - - template - void RunAndMeasureAlgo(cublasLtHandle_t handle, - cublasLtMatmulDesc_t matmul_desc, - cublasLtMatrixLayout_t a_desc, - cublasLtMatrixLayout_t b_desc, - cublasLtMatrixLayout_t bias_desc, - cublasLtMatrixLayout_t c_desc, - void* alpha, - void* beta, - const InT* a, - const InT* b, - const OutT* bias, - OutT* c, - CublasLtAlgoSelectorParam& param, // NOLINT - cudaEvent_t& start_event, // NOLINT - cudaEvent_t& stop_event, // NOLINT - cudaStream_t stream) { - cublasStatus_t status; - cublasLtMatmulHeuristicResult_t heuristic_result; - status = dynload::cublasLtMatmulAlgoCheck(handle, - matmul_desc, - a_desc, - b_desc, - bias_desc, - c_desc, - ¶m.algo, - &heuristic_result); - PADDLE_ENFORCE_GPU_SUCCESS(status); - if (status != CUBLAS_STATUS_SUCCESS) { - param.time = std::numeric_limits::max(); - return; - } - size_t workspace_size = heuristic_result.workspaceSize; - auto workspace = phi::memory_utils::Alloc( - phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId()), - workspace_size, - phi::Stream(reinterpret_cast(stream))); - - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event, stream)); - int repeats = search_times_; - - for (int loop = 0; loop < repeats; loop++) { - status = dynload::cublasLtMatmul(handle, - matmul_desc, - alpha, - a, - a_desc, - b, - b_desc, - beta, - bias, - bias_desc, - c, - c_desc, - ¶m.algo, - workspace->ptr(), - workspace_size, - stream); - if (status != CUBLAS_STATUS_SUCCESS) { - param.time = std::numeric_limits::max(); - return; - } - } - - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(stop_event, stream)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); - - float time; - PADDLE_ENFORCE_GPU_SUCCESS( - cudaEventElapsedTime(&time, start_event, stop_event)); - - param.time = time / repeats; - } - - template - cublasLtMatmulAlgo_t* CublasLtAlgoSelect(cublasLtHandle_t handle, - int 
m, - int n, - int k, - int batch_count, - const InT* a, - const InT* b, - const OutT* bias, - OutT* c, - void* alpha, - void* beta, - cublasLtMatmulDesc_t matmul_desc, - cublasLtMatrixLayout_t a_desc, - cublasLtMatrixLayout_t b_desc, - cublasLtMatrixLayout_t bias_desc, - cublasLtMatrixLayout_t c_desc, - cublasComputeType_t compute_type, - cudaDataType_t scale_type, - cudaDataType_t a_type, - cudaDataType_t b_type, - cudaDataType_t bias_type, - cudaDataType_t c_type, - cudaStream_t stream) { - // If we don't have config file and we do not search, here return nullptr - if (!has_config_file_ && search_times_ <= 0) { - return nullptr; - } - - // VLOG(0) << "m n k: " << m << " " << n << " " << k; - - int64_t seed = 0; - std::hash hash_fn; - - HashMatmulDesc(matmul_desc, &seed, hash_fn); - HashMatrixLayoutDesc(a_desc, &seed, hash_fn); - HashMatrixLayoutDesc(b_desc, &seed, hash_fn); - HashMatrixLayoutDesc(bias_desc, &seed, hash_fn); - HashMatrixLayoutDesc(c_desc, &seed, hash_fn); - - { - std::lock_guard lock(cache_mutex_); - if (algo_caches_.count(seed)) { - VLOG(3) << "CublasLtAlgoSelect Found in cache"; - return &algo_caches_[seed]; - } - } - - if (search_configs_.empty()) { - std::ifstream infile; - std::string config_file_path = FLAGS_cublaslt_device_best_config; - infile.open(config_file_path.c_str()); - if (infile.is_open()) { - size_t workspace_size; - float time; - char comma; - while (!infile.eof()) { - CublasLtAlgoConfig search_config; - infile >> search_config.m >> comma >> search_config.k >> comma >> - search_config.n >> comma >> search_config.algo_id >> comma >> - search_config.swizzle >> comma >> search_config.custom_option >> - comma >> search_config.tile >> comma >> - search_config.split_k_val >> comma >> - search_config.reduction_scheme >> comma >> search_config.stages >> - comma >> workspace_size >> comma >> time; - search_configs_.push_back(search_config); - } - infile.close(); - VLOG(3) << "Loaded " << search_configs_.size() << " configs"; - } - } - if (!search_configs_.empty()) { - auto configure_algo = [&](const CublasLtAlgoConfig& search_config) - -> cublasLtMatmulAlgo_t* { - cublasLtMatmulAlgo_t algo; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoInit(handle, - compute_type, - scale_type, - b_type, - a_type, - c_type, - c_type, - search_config.algo_id, - &algo)); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, - &search_config.custom_option, - sizeof(search_config.custom_option))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_TILE_ID, - &search_config.tile, - sizeof(search_config.tile))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_SPLITK_NUM, - &search_config.split_k_val, - sizeof(search_config.split_k_val))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, - &search_config.swizzle, - sizeof(search_config.swizzle))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, - &search_config.reduction_scheme, - sizeof(search_config.reduction_scheme))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_STAGES_ID, - &search_config.stages, - sizeof(search_config.stages))); - std::lock_guard lock(cache_mutex_); - algo_caches_[seed] = algo; - return 
&algo_caches_[seed]; - }; - const CublasLtAlgoConfig* pre = nullptr; - for (size_t i = 0; i < search_configs_.size(); i++) { - if (search_configs_[i].n == n && search_configs_[i].k == k && - m <= search_configs_[i].m) { - return configure_algo(search_configs_[i]); - } else if (search_configs_[i].n == n && search_configs_[i].k == k && - m > search_configs_[i].m) { - if (pre == nullptr || pre->m < search_configs_[i].m) - pre = &search_configs_[i]; - } - } - if (pre != nullptr) { - // use max m in file - return configure_algo(*pre); - } - } - - // if we have cache but not found algo, and we don't want to search, - // here return nullptr - if (search_times_ <= 0) { - return nullptr; - } - - VLOG(3) << "CublasLtAlgoSelect Not Found in cache"; - - // Get Ids - // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoGetIds - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - int algo_ids[requested_algo_count_]; // NOLINT - - int num_algo_ids; - status = dynload::cublasLtMatmulAlgoGetIds(handle, - compute_type, - scale_type, - a_type, - b_type, - bias_type, - c_type, - requested_algo_count_, - algo_ids, - &num_algo_ids); - PADDLE_ENFORCE_GPU_SUCCESS(status); - - // Traverse all possible algo combinations - int step = 0; - int limit = 20000; - std::vector params; - - for (int idx = 0; idx < num_algo_ids; idx++) { - cublasLtMatmulAlgo_t algo; - - /* Initialize algo structure with given Algp ID */ - // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoInit - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoInit(handle, - compute_type, - scale_type, - a_type, - b_type, - bias_type, - c_type, - algo_ids[idx], - &algo)); - - // Query the tiles enums supported by that algo which is used to alloc - // enough space to store it - // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoCapGetAttribute - size_t attr_size = 0; - - int batch_support; - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, - CUBLASLT_ALGO_CAP_STRIDED_BATCH_SUPPORT, - &batch_support, - sizeof(batch_support), - &attr_size)); - if (batch_count > 1 && batch_support == 0) { - continue; - } - - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, CUBLASLT_ALGO_CAP_TILE_IDS, nullptr, 0, &attr_size)); - - int num_tiles = static_cast(attr_size / sizeof(int)); - std::vector tiles(num_tiles == 0 ? 1 : num_tiles); - if (num_tiles == 0) { - tiles[0] = CUBLASLT_MATMUL_TILE_UNDEFINED; - num_tiles = 1; - } else { - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, - CUBLASLT_ALGO_CAP_TILE_IDS, - tiles.data(), - sizeof(int) * num_tiles, - &attr_size)); - } - - // Query the stages enums supported by that algo (cuda must >= 11.0) - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, CUBLASLT_ALGO_CAP_STAGES_IDS, nullptr, 0, &attr_size)); - int num_stages = static_cast(attr_size / sizeof(int)); - std::vector stages(num_stages == 0 ? 
1 : num_stages); - if (num_stages == 0) { - stages[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED; - num_stages = 1; - } else { - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, - CUBLASLT_ALGO_CAP_STAGES_IDS, - stages.data(), - sizeof(int) * num_stages, - &attr_size)); - } - - // Retrieve Other Algo Capabilities attributes - int splitk_support, red_mask, swizzling_max, custom_option_max; - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, - CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, - &splitk_support, - sizeof(splitk_support), - &attr_size)); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, - CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, - &red_mask, - sizeof(red_mask), - &attr_size)); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, - CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, - &swizzling_max, - sizeof(swizzling_max), - &attr_size)); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulAlgoCapGetAttribute( - &algo, - CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, - &custom_option_max, - sizeof(custom_option_max), - &attr_size)); - - /* Loop over the different tiles */ - for (int tile_id = 0; tile_id < num_tiles && step < limit; tile_id++) { - /* Loop over different stages count */ - for (int stage_id = 0; stage_id < num_stages && step < limit; - stage_id++) { - /* Loop over the different custom option if any */ - for (int custom_option = 0; - custom_option <= custom_option_max && step < limit; - custom_option++) { - /* Loop over the CTAs swizzling support */ - for (int k = 0; k <= swizzling_max && step < limit; k++) { - int splir_k_trial = 0; - if (splitk_support) { - splir_k_trial += - sizeof(split_k_candidates) / sizeof(split_k_candidates[0]); - } - - for (int l = 0; (l < (1 + splir_k_trial)) && (step < limit); - l++) { - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_TILE_ID, - &tiles[tile_id], - sizeof(tiles[tile_id]))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_STAGES_ID, - &stages[stage_id], - sizeof(stages[stage_id]))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, - &custom_option, - sizeof(custom_option))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, - &k, - sizeof(k))); - int split_k_val = 1; - int reduction_scheme = CUBLASLT_REDUCTION_SCHEME_NONE; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_SPLITK_NUM, - &split_k_val, - sizeof(split_k_val))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, - &reduction_scheme, - sizeof(int))); - if (l > 0) { // Split-K case - split_k_val = split_k_candidates[l - 1]; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_SPLITK_NUM, - &split_k_candidates[l - 1], - sizeof(split_k_candidates[l - 1]))); - for (reduction_scheme = 1; - reduction_scheme < - static_cast(CUBLASLT_REDUCTION_SCHEME_MASK) && - (step < limit); - reduction_scheme = reduction_scheme << 1) { - if (reduction_scheme & red_mask) { - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoConfigSetAttribute( - &algo, - CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, - &reduction_scheme, - 
sizeof(reduction_scheme))); - - cublasLtMatmulHeuristicResult_t heurResult; - status = dynload::cublasLtMatmulAlgoCheck(handle, - matmul_desc, - a_desc, - b_desc, - bias_desc, - c_desc, - &algo, - &heurResult); - if (status == CUBLAS_STATUS_SUCCESS) { - CublasLtAlgoSelectorParam param; - param.algo = algo; - param.algo_config.m = m; - param.algo_config.n = n; - param.algo_config.k = k; - param.algo_config.algo_id = algo_ids[idx]; - param.algo_config.tile = tiles[tile_id]; - param.algo_config.swizzle = k; - param.algo_config.custom_option = custom_option; - param.algo_config.split_k_val = split_k_val; - param.algo_config.reduction_scheme = reduction_scheme; - param.algo_config.stages = stages[stage_id]; - params.emplace_back(param); - step++; - } - } // end if - } - } else { - // Prepare algos - cublasLtMatmulHeuristicResult_t heurResult; - // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoCheck - status = dynload::cublasLtMatmulAlgoCheck(handle, - matmul_desc, - a_desc, - b_desc, - bias_desc, - c_desc, - &algo, - &heurResult); - if (status == CUBLAS_STATUS_SUCCESS) { - CublasLtAlgoSelectorParam param; - param.algo = algo; - param.algo_config.m = m; - param.algo_config.n = n; - param.algo_config.k = k; - param.algo_config.algo_id = algo_ids[idx]; - param.algo_config.tile = tiles[tile_id]; - param.algo_config.swizzle = k; - param.algo_config.custom_option = custom_option; - param.algo_config.split_k_val = split_k_val; - param.algo_config.reduction_scheme = reduction_scheme; - param.algo_config.stages = stages[stage_id]; - params.emplace_back(param); - step++; - } - } - } - } - } - } - } - } - cudaEvent_t start_event; - cudaEvent_t stop_event; - - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&start_event)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&stop_event)); - - if (step == 0) { - VLOG(3) << "No algo can be used"; - return nullptr; - } - - VLOG(3) << "CublasLtAlgoSelect Start testRun " << step << " " - << params.size(); - - for (int i = 0; i < step; i++) { - RunAndMeasureAlgo(handle, - matmul_desc, - a_desc, - b_desc, - bias_desc, - c_desc, - alpha, - beta, - a, - b, - bias, - c, - params[i], - start_event, - stop_event, - stream); - } - std::sort(params.begin(), params.end(), compare_algo_time); - - size_t res_id = 0; - while (params[res_id].time == 0.0) { - res_id++; - if (res_id >= params.size()) break; - } - - if (res_id >= params.size()) { - VLOG(3) << "No algo can be used"; - return nullptr; - } - - VLOG(3) << "algo selected"; - - std::lock_guard lock(cache_mutex_); - algo_caches_[seed] = params[res_id].algo; - return &algo_caches_[seed]; - } - - ~CublasLtAlgoCache() { SerializeAlgoCachesToFile(); } - - private: - std::string algo_caches_file_{"./cublaslt_algo_caches_from_paddle"}; - std::unordered_map algo_caches_; - std::vector search_configs_; - int search_times_; - static constexpr int requested_algo_count_ = 100; - std::mutex cache_mutex_; - bool has_config_file_; - - explicit CublasLtAlgoCache(int search_times) - : search_times_(search_times), has_config_file_(true) { - // Init algo_caches_ from cache file - std::ifstream infile; - infile.open(algo_caches_file_); - if (!infile.is_open()) { - has_config_file_ = false; - VLOG(3) << "No CublasLtAlgoCache file found"; - return; - } - size_t cublaslt_version = 0, real_cublaslt_version = 0; - int64_t seed = 0; - std::array algo_data; - infile >> cublaslt_version; - VLOG(1) << "cublaslt_version " << cublaslt_version; - - if (dynload::cublasLtGetCudartVersion() != cublaslt_version) { - LOG(INFO) << 
algo_caches_file_ - << " is not compatible with current cublaslt_version " - << real_cublaslt_version; - return; - } - - while (!infile.eof()) { - infile >> seed >> algo_data[0] >> algo_data[1] >> algo_data[2] >> - algo_data[3] >> algo_data[4] >> algo_data[5] >> algo_data[6] >> - algo_data[7]; - - for (int i = 0; i < 8; ++i) { - algo_caches_[seed].data[i] = algo_data[i]; - } - } - infile.close(); - } - - // Serialize algo_caches_ to cache file - void SerializeAlgoCachesToFile() { - if (search_times_ > 0) { - int dev; - cudaGetDevice(&dev); - if (dev == 0) { - std::ofstream outfile; - outfile.open(algo_caches_file_, std::ios::out | std::ios::trunc); - outfile << dynload::cublasLtGetCudartVersion() << std::endl; - - for (const auto& [seed, algo] : algo_caches_) { - outfile << seed << " "; - for (size_t value : algo.data) { - outfile << value << " "; - } - outfile << std::endl; - } - outfile.close(); - } - } - } - - inline int64_t RoundToNextHighPowOfTwo(int64_t n, int64_t min_val) { - n--; - n |= (n >> 1); - n |= (n >> 2); - n |= (n >> 4); - n |= (n >> 8); - n |= (n >> 16); - return std::max(min_val, (n + 1)); - } - - void HashMatmulDesc(cublasLtMatmulDesc_t desc, - int64_t* seed, - const std::hash& hash_fn) { - size_t size_to_write; - int trans_a, trans_b; - uint32_t epilogue; - // int8_t fast_accum; - - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescGetAttribute(desc, - CUBLASLT_MATMUL_DESC_TRANSA, - &trans_a, - sizeof(trans_a), - &size_to_write)); - HashValue(seed, hash_fn, static_cast(trans_a)); - - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescGetAttribute(desc, - CUBLASLT_MATMUL_DESC_TRANSB, - &trans_b, - sizeof(trans_b), - &size_to_write)); - HashValue(seed, hash_fn, static_cast(trans_b)); - - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescGetAttribute(desc, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &epilogue, - sizeof(epilogue), - &size_to_write)); - HashValue(seed, hash_fn, static_cast(epilogue)); - - // PADDLE_ENFORCE_GPU_SUCCESS( - // dyl::cublasLtMatmulDescGetAttribute(desc, - // CUBLASLT_MATMUL_DESC_FAST_ACCUM, - // &fast_accum, - // sizeof(fast_accum), - // &size_to_write)); - // HashValue(seed, hash_fn, static_cast(fast_accum)); - } - - void HashMatrixLayoutDesc(cublasLtMatrixLayout_t desc, - int64_t* seed, - const std::hash& hash_fn) { - size_t size_to_write; - uint32_t dtype; - int32_t batch; - uint64_t row, col; - int64_t ld, batch_offset; - - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatrixLayoutGetAttribute(desc, - CUBLASLT_MATRIX_LAYOUT_TYPE, - &dtype, - sizeof(dtype), - &size_to_write)); - HashValue(seed, hash_fn, static_cast(dtype)); - - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutGetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batch, - sizeof(batch), - &size_to_write)); - HashValue(seed, hash_fn, static_cast(batch)); - - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutGetAttribute( - desc, CUBLASLT_MATRIX_LAYOUT_ROWS, &row, sizeof(row), &size_to_write)); - HashValue(seed, hash_fn, RoundToNextHighPowOfTwo(row, 32)); - - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutGetAttribute( - desc, CUBLASLT_MATRIX_LAYOUT_COLS, &col, sizeof(col), &size_to_write)); - HashValue(seed, hash_fn, RoundToNextHighPowOfTwo(col, 32)); - - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutGetAttribute( - desc, CUBLASLT_MATRIX_LAYOUT_LD, &ld, sizeof(ld), &size_to_write)); - HashValue(seed, hash_fn, RoundToNextHighPowOfTwo(ld, 32)); - - // PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( - // desc, 
CUBLASLT_MATRIX_LAYOUT_ROWS, &row, sizeof(row), - // &size_to_write)); - // HashValue(seed, hash_fn, row); - - // PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( - // desc, CUBLASLT_MATRIX_LAYOUT_COLS, &col, sizeof(col), - // &size_to_write)); - // HashValue(seed, hash_fn, col); - - // PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( - // desc, CUBLASLT_MATRIX_LAYOUT_LD, &ld, sizeof(ld), &size_to_write)); - // HashValue(seed, hash_fn, ld); - - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutGetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, - &batch_offset, - sizeof(batch_offset), - &size_to_write)); - HashValue(seed, hash_fn, static_cast(batch_offset)); - } - - void HashValue(int64_t* seed, - const std::hash& hash_fn, - int64_t value) { - *seed ^= hash_fn(value) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); - } -}; - -} // namespace cublaslt_internal -} // namespace funcs -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/blaslt_impl.cu.h b/backends/metax_gpu/kernels/funcs/blas/blaslt_impl.cu.h deleted file mode 100755 index d98182abef3..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/blaslt_impl.cu.h +++ /dev/null @@ -1,1137 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 && 0 - -#include // NOLINT - -#include "cuda.h" // NOLINT -#include "glog/logging.h" -// #include "paddle/phi/backends/dynload/cublasLt.h" -#include "paddle/phi/backends/gpu/cuda/cuda_helper.h" -#include "paddle/phi/common/amp_type_traits.h" -#include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/core/flags.h" -#include "paddle/phi/kernels/autotune/gpu_timer.h" -#include "paddle/phi/kernels/autotune/switch_autotune.h" - -PHI_DECLARE_int64(cublaslt_exhaustive_search_times); -#endif - -namespace phi { -namespace funcs { - -#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 && 0) - -// Set this enum according to -// https://docs.nvidia.com/cuda/cublas/index.html#cublasltepilogue-t -// While kMatmul, kMatmulGrad, kMatmulGradWithoutBias share the same -// enum value, but if all elements for MatmulPlanner->GetKey() is same, -// no matter forward or backward, they could share the same descriptor -// cache, in that the descriptor is for description of matmul operation. 
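The comment above concerns the epilogue types enumerated next. As a concrete reference for what such a fusion computes, here is a minimal sketch of the bias + ReLU case (the pattern kMatmulBiasRelu maps to CUBLASLT_EPILOGUE_RELU_BIAS below), written as plain loops rather than a cublasLt call; MatmulBiasReluRef is a hypothetical name.

#include <algorithm>
#include <vector>

// out = ReLU(x * y + bias); x is [M, K], y is [K, N], bias is [N], row-major.
// A fused epilogue lets the library apply the bias add and activation in the
// same kernel that produces the GEMM output, instead of launching separate
// elementwise kernels afterwards.
std::vector<float> MatmulBiasReluRef(const std::vector<float>& x,
                                     const std::vector<float>& y,
                                     const std::vector<float>& bias,
                                     int M, int K, int N) {
  std::vector<float> out(static_cast<size_t>(M) * N);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += x[m * K + k] * y[k * N + n];
      }
      out[m * N + n] = std::max(acc + bias[n], 0.f);  // bias epilogue + ReLU
    }
  }
  return out;
}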
-enum MatmulFusedType { - kMatmul = 0, - kMatmulGrad = 1, - kMatmulGradWithoutBias = 2, - kMatmulBias = 3, - kMatmulRelu = 4, - kMatmulBiasRelu = 5, - kMatmulBiasGelu = 6, - kMatmulBiasReluWithReservedData = 7, - kMatmulBiasGeluWithReservedData = 8, - kMatmulReluGrad = 9, - kMatmulGeluGrad = 10, - kMatmulBiasGradToA = 11, - kMatmulBiasGradToB = 12 -}; - -static cublasLtEpilogue_t ConvertFusedType(MatmulFusedType fused_type) { - static std::map fused_type_map = { - {MatmulFusedType::kMatmul, CUBLASLT_EPILOGUE_DEFAULT}, - {MatmulFusedType::kMatmulGrad, CUBLASLT_EPILOGUE_DEFAULT}, - {MatmulFusedType::kMatmulGradWithoutBias, CUBLASLT_EPILOGUE_DEFAULT}, - {MatmulFusedType::kMatmulBias, CUBLASLT_EPILOGUE_BIAS}, - {MatmulFusedType::kMatmulRelu, CUBLASLT_EPILOGUE_RELU}, - {MatmulFusedType::kMatmulBiasRelu, CUBLASLT_EPILOGUE_RELU_BIAS}, - {MatmulFusedType::kMatmulBiasGelu, CUBLASLT_EPILOGUE_GELU_BIAS}, - {MatmulFusedType::kMatmulBiasReluWithReservedData, - CUBLASLT_EPILOGUE_RELU_AUX_BIAS}, - {MatmulFusedType::kMatmulBiasGeluWithReservedData, - CUBLASLT_EPILOGUE_GELU_AUX_BIAS}, - {MatmulFusedType::kMatmulReluGrad, CUBLASLT_EPILOGUE_DRELU}, - {MatmulFusedType::kMatmulGeluGrad, CUBLASLT_EPILOGUE_DGELU}, - {MatmulFusedType::kMatmulBiasGradToA, CUBLASLT_EPILOGUE_BGRADA}, - {MatmulFusedType::kMatmulBiasGradToB, CUBLASLT_EPILOGUE_BGRADB}}; - - return fused_type_map[fused_type]; -} - -enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 }; - -template -struct FusedGEMMGradTrait; - -template <> -struct FusedGEMMGradTrait { - static constexpr auto kXGradA = FusedGEMMGradInType::kDZ; - static constexpr auto kXGradB = FusedGEMMGradInType::kDY; - static constexpr auto kXGradATrans = false; - static constexpr auto kXGradBTrans = true; - - static constexpr auto kYGradA = FusedGEMMGradInType::kDX; - static constexpr auto kYGradB = FusedGEMMGradInType::kDZ; - static constexpr auto kYGradATrans = true; - static constexpr auto kYGradBTrans = false; -}; - -template <> -struct FusedGEMMGradTrait { - static constexpr auto kXGradA = FusedGEMMGradInType::kDY; - static constexpr auto kXGradB = FusedGEMMGradInType::kDZ; - static constexpr auto kXGradATrans = false; - static constexpr auto kXGradBTrans = true; - - static constexpr auto kYGradA = FusedGEMMGradInType::kDX; - static constexpr auto kYGradB = FusedGEMMGradInType::kDZ; - static constexpr auto kYGradATrans = false; - static constexpr auto kYGradBTrans = false; -}; - -template <> -struct FusedGEMMGradTrait { - static constexpr auto kXGradA = FusedGEMMGradInType::kDZ; - static constexpr auto kXGradB = FusedGEMMGradInType::kDY; - static constexpr auto kXGradATrans = false; - static constexpr auto kXGradBTrans = false; - - static constexpr auto kYGradA = FusedGEMMGradInType::kDZ; - static constexpr auto kYGradB = FusedGEMMGradInType::kDX; - static constexpr auto kYGradATrans = true; - static constexpr auto kYGradBTrans = false; -}; - -template <> -struct FusedGEMMGradTrait { - static constexpr auto kXGradA = FusedGEMMGradInType::kDY; - static constexpr auto kXGradB = FusedGEMMGradInType::kDZ; - static constexpr auto kXGradATrans = true; - static constexpr auto kXGradBTrans = true; - - static constexpr auto kYGradA = FusedGEMMGradInType::kDZ; - static constexpr auto kYGradB = FusedGEMMGradInType::kDX; - static constexpr auto kYGradATrans = true; - static constexpr auto kYGradBTrans = true; -}; - -// To tell any matmul or fused matmul operation from each other. 
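The FusedGEMMGradTrait specializations above fix, per (trans_x, trans_y) combination, which operands feed the two backward GEMMs and how they are transposed. Restated as formulas plus a small runtime table (a sketch derived from those traits; DxGemmRule is a hypothetical helper, not part of the removed header):

// For out = op(x) * op(y) with incoming gradient dout, the traits encode:
//   trans_x=false, trans_y=false:  dx = dout * y^T     dy = x^T * dout
//   trans_x=true,  trans_y=false:  dx = y * dout^T     dy = x * dout
//   trans_x=false, trans_y=true:   dx = dout * y       dy = dout^T * x
//   trans_x=true,  trans_y=true:   dx = y^T * dout^T   dy = dout^T * x^T
#include <cstdio>

struct DxRule {
  const char* a;  // first operand of the dx GEMM
  const char* b;  // second operand of the dx GEMM
  bool trans_a;
  bool trans_b;
};

// The same selection the <TransX, TransY> specializations make at compile
// time, expressed at runtime for readability.
DxRule DxGemmRule(bool trans_x, bool trans_y) {
  if (!trans_x && !trans_y) return {"dout", "y", false, true};  // dx = dout * y^T
  if (trans_x && !trans_y) return {"y", "dout", false, true};   // dx = y * dout^T
  if (!trans_x && trans_y) return {"dout", "y", false, false};  // dx = dout * y
  return {"y", "dout", true, true};                             // dx = y^T * dout^T
}

int main() {
  DxRule r = DxGemmRule(false, false);
  std::printf("dx = %s%s * %s%s\n",
              r.a, r.trans_a ? "^T" : "", r.b, r.trans_b ? "^T" : "");
  return 0;
}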
-struct MatmulPlanner { - public: - const void* bias{nullptr}; - void* aux_data{nullptr}; - - MatmulPlanner() {} - MatmulPlanner(const std::vector& x_dims, - const std::vector& y_dims, - const bool trans_x, - const bool trans_y, - phi::DataType dtype, - MatmulFusedType fused_type, - const void* bias_data = nullptr, - void* reserve_data = nullptr, // Commonly for ReLu bit-mask. - bool use_addto = false, - bool no_exchange = true) - : bias(bias_data), aux_data(reserve_data), fused_type_(fused_type) { - use_addto_ = use_addto; - key_ = phi::autotune::GenKey(x_dims, - y_dims, - static_cast(trans_x), - static_cast(trans_y), - static_cast(dtype), - static_cast(fused_type_), - static_cast(use_addto_), - static_cast(no_exchange)); - } - - bool UseAddTo() const { return use_addto_; } - size_t GetKey() const { return key_; } - MatmulFusedType GetFusedType() const { return fused_type_; } - - size_t GenSubKey() const { return key_; } - - private: - MatmulFusedType fused_type_; - bool use_addto_; - size_t key_; -}; - -template -cublasComputeType_t GetCudaComputeType() { - if (std::is_same::value) { - return CUBLAS_COMPUTE_64F; - } else if (std::is_same::value) { - return CUBLAS_COMPUTE_32I; - } else { - return CUBLAS_COMPUTE_32F; - } -} - -struct MatmulDescriptor { - public: - cublasLtMatmulDesc_t op_desc{nullptr}; - cublasLtMatrixLayout_t x_desc{nullptr}; - cublasLtMatrixLayout_t y_desc{nullptr}; - cublasLtMatrixLayout_t out_desc{nullptr}; - cublasLtMatmulAlgo_t* algo{nullptr}; - bool is_cached{false}; - - MatmulDescriptor() {} - MatmulDescriptor(const MatmulDescriptor& obj) { - algo = obj.algo; - x_desc = obj.x_desc; - y_desc = obj.y_desc; - op_desc = obj.op_desc; - out_desc = obj.out_desc; - is_cached = obj.is_cached; - } - - MatmulDescriptor& operator=(const MatmulDescriptor& obj) { - algo = obj.algo; - x_desc = obj.x_desc; - y_desc = obj.y_desc; - op_desc = obj.op_desc; - out_desc = obj.out_desc; - is_cached = obj.is_cached; - - return *this; - } - - ~MatmulDescriptor() PADDLE_MAY_THROW { - if (!is_cached) { - PADDLE_WARN_GPU_SUCCESS(dynload::cublasLtMatmulDescDestroy(op_desc)); - PADDLE_WARN_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(y_desc)); - PADDLE_WARN_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(x_desc)); - PADDLE_WARN_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(out_desc)); - delete algo; - - op_desc = nullptr; - x_desc = nullptr; - y_desc = nullptr; - out_desc = nullptr; - algo = nullptr; - } - } - - // x_desc, y_desc, op_desc are allocated in heap memory. 
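The copy constructor, assignment operator, and destructor above implement a small ownership rule: copies are shallow, and the underlying cublasLt descriptors are destroyed only by an instance that was never cached (is_cached == false); once a descriptor has been handed to the autotune cache it stays alive for the rest of the process. A toy sketch of the same rule with a dummy handle, under hypothetical names (CachedDesc is only a stand-in):

#include <unordered_map>

struct CachedDesc {
  int* handle{nullptr};  // stands in for op_desc / x_desc / y_desc / out_desc
  bool is_cached{false};

  CachedDesc() = default;
  CachedDesc(const CachedDesc&) = default;             // shallow copy
  CachedDesc& operator=(const CachedDesc&) = default;  // shallow copy

  void Create() { handle = new int(0); }
  void MarkCached() { is_cached = true; }  // like SetAlgo(): entering the cache
  ~CachedDesc() {
    // Only an uncached instance releases the handle; cached copies (and the
    // local object they were copied from) intentionally leave it alive, just
    // as a cached MatmulDescriptor keeps its cublasLt descriptors usable
    // until process exit.
    if (!is_cached) delete handle;
  }
};

int main() {
  std::unordered_map<size_t, CachedDesc> cache;
  CachedDesc local;
  local.Create();
  local.MarkCached();
  cache[1] = local;  // shallow copy: both objects refer to the same handle
  return 0;          // neither destructor frees it, by design
}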
- template - void Create(const int64_t M, - const int64_t N, - const int64_t K, - const bool trans_x, - const bool trans_y, - phi::funcs::MatmulPlanner* planner, - const int batch_size = 1, - const int64_t stride_x = 0, - const int64_t stride_y = 0, - const int64_t stride_out = 0, - bool grad_for_dx = true) { - using MT = typename phi::dtype::MPTypeTrait::Type; - cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType(); - cudaDataType_t out_mat_type = phi::backends::gpu::ToCudaDataType(); - cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType(); - cublasComputeType_t compute_type = GetCudaComputeType(); - - if (std::is_same::value) { - out_mat_type = phi::backends::gpu::ToCudaDataType(); - scale_type = phi::backends::gpu::ToCudaDataType(); - } - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t for - // details about defaults; just need to set the transforms for A and B - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type)); - SetFusedEpilogueOpDescriptor(planner, trans_x, trans_y, N); - - // Create matrix descriptors - CreateMatrixLayout(&x_desc, mat_type, M, K, trans_x); - CreateMatrixLayout(&y_desc, mat_type, K, N, trans_y); - CreateMatrixLayout(&out_desc, out_mat_type, M, N, false); - - // Config batch size and stride. - if (batch_size > 1) { - SetBatchAndStride(x_desc, batch_size, stride_x); - SetBatchAndStride(y_desc, batch_size, stride_y); - SetBatchAndStride(out_desc, batch_size, stride_out); - } - } - - cublasLtMatmulAlgo_t* SetAlgo() { - // while entering this function, the desc shall be cached. - is_cached = true; - algo = new cublasLtMatmulAlgo_t; - return algo; - } - - template - void SetFusedEpiloguePtr(phi::funcs::MatmulPlanner* planner) { - if (planner->bias != nullptr) { - const T* bias_data = static_cast(planner->bias); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute( - op_desc, - CUBLASLT_MATMUL_DESC_BIAS_POINTER, - &bias_data, - sizeof(bias_data))); - } - if (planner->aux_data != nullptr) { - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute( - op_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, - &(planner->aux_data), - sizeof(planner->aux_data))); - } - } - - std::string GetDescResultString(std::string prefix, - bool has_algo = true) const { - std::ostringstream out; - out << prefix << " \n"; -#define GET_DESC_DATA_STRING(src) \ - do { \ - out << " " << #src << " = ["; \ - int num = sizeof((*src)) / sizeof(src->data[0]); \ - for (int i = 0; i < num; ++i) { \ - if (i == 0) { \ - out << src->data[i]; \ - } else { \ - out << ", " << src->data[i]; \ - } \ - } \ - out << "]\n"; \ - } while (0); - - if (has_algo) { - GET_DESC_DATA_STRING(algo); - } - GET_DESC_DATA_STRING(x_desc); - GET_DESC_DATA_STRING(y_desc); - GET_DESC_DATA_STRING(out_desc); - GET_DESC_DATA_STRING(op_desc); -#undef GET_DESC_DATA_STRING - return out.str(); - } - - void ExchangeXYDesc(bool no_exchange) {} - - protected: - void SetFusedEpilogueOpDescriptor(phi::funcs::MatmulPlanner* planner, - const bool trans_x, - const bool trans_y, - int64_t lead_dim) { - cublasOperation_t cublas_trans_x = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t cublas_trans_y = trans_y ? 
CUBLAS_OP_T : CUBLAS_OP_N; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescSetAttribute(op_desc, - CUBLASLT_MATMUL_DESC_TRANSB, - &cublas_trans_x, - sizeof(cublas_trans_x))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescSetAttribute(op_desc, - CUBLASLT_MATMUL_DESC_TRANSA, - &cublas_trans_y, - sizeof(cublas_trans_y))); - MatmulFusedType fused_type = planner->GetFusedType(); - if (fused_type != MatmulFusedType::kMatmul) { - cublasLtEpilogue_t cublaslt_fused_type = ConvertFusedType(fused_type); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescSetAttribute(op_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &cublaslt_fused_type, - sizeof(fused_type))); - } - if (planner->aux_data) { - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute( - op_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, - &lead_dim, - sizeof(lead_dim))); - } - } - - void CreateMatrixLayout(cublasLtMatrixLayout_t* desc, - cudaDataType type, - uint64_t rows, - uint64_t cols, - bool trans) { - if (trans) { - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatrixLayoutCreate(desc, type, rows, cols, rows)); - } else { - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatrixLayoutCreate(desc, type, cols, rows, cols)); - } - } - - void SetBatchAndStride(cublasLtMatrixLayout_t desc, - int batch_size, - int64_t stride) { - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutSetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batch_size, - sizeof(batch_size))); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutSetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, - &stride, - sizeof(stride))); - } -}; - -struct MatmulGradDescriptor : MatmulDescriptor { - public: - MatmulGradDescriptor() {} - - template - void Create(const int64_t M, - const int64_t N, - const int64_t K, - const bool trans_x, - const bool trans_y, - phi::funcs::MatmulPlanner* planner, - const int batch_size = 1, - int64_t stride_x = 0, - int64_t stride_y = 0, - int64_t stride_out = 0, - bool grad_for_dx = true) { - using MT = typename phi::dtype::MPTypeTrait::Type; - cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType(); - cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType(); - cublasComputeType_t compute_type = GetCudaComputeType(); - - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type)); - this->SetFusedEpilogueOpDescriptor( - planner, trans_x, trans_y, TransX ? 
M : K); - - // Create operation desciriptor; see cublasLtMatmulDescAttributes_t for - // details about defaults; just need to set the transforms for A and B - this->CreateMatrixLayout(&x_desc, mat_type, N, M, true); - if (grad_for_dx) { - this->CreateMatrixLayout(&y_desc, mat_type, K, N, TransY); - this->CreateMatrixLayout( - &out_desc, phi::backends::gpu::ToCudaDataType(), M, K, TransX); - } else { - this->CreateMatrixLayout(&y_desc, mat_type, M, K, TransX); - this->CreateMatrixLayout( - &out_desc, phi::backends::gpu::ToCudaDataType(), K, N, TransY); - } - } - - void ExchangeXYDesc(bool no_exchange) { - if (no_exchange) { - return; - } - auto* temp = y_desc; - y_desc = x_desc; - x_desc = temp; - } -}; - -template -struct CublasLtBase { - public: - using MT = typename phi::dtype::MPTypeTrait::Type; - static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx, - size_t workspace_size) { - return phi::memory_utils::Alloc( - ctx.GetPlace(), - workspace_size, - phi::Stream(reinterpret_cast(ctx.stream()))); - } - - static void RunImpl(const phi::GPUContext& ctx, - MatmulDescT* desc, - const size_t sub_key, - const T* x_ptr, - const T* y_ptr, - OutT* out_ptr, - phi::funcs::MatmulPlanner* planner) { - MT alpha = static_cast(1); - MT beta = planner->UseAddTo() ? static_cast(1) : static_cast(0); - cublasLtHandle_t cublaslt_handle = ctx.cublaslt_handle(); - - // NOTE(limingshu): As workspace_size varies from different DL framework, - // I wonder is there any smarter idea for workspace setting, currently I - // just followed the settings from the NVIDIA colleague`s setting. - size_t workspace_size = static_cast(4) * 1024 * 1024; - phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size); - - if (planner != nullptr) { - if (phi::autotune::AutoTuneStatus::Instance().UseAutoTune() && - (!desc->is_cached)) { - SearchBestAlgo(ctx, - cublaslt_handle, - desc, - static_cast(&alpha), - static_cast(&beta), - y_ptr, - x_ptr, - out_ptr, - workspace->ptr(), - workspace_size); - MatmulDescT* best_desc = new MatmulDescT(*desc); - VLOG(6) << best_desc->GetDescResultString( - "[Searched CublasltDescriptor] "); - - auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul(); - cache.SetSubKey(sub_key, reinterpret_cast(best_desc)); - } - } - - VLOG(7) << desc->GetDescResultString("[Impl CublasltDescriptor] "); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmul(cublaslt_handle, - desc->op_desc, - static_cast(&alpha), - y_ptr, - desc->y_desc, - x_ptr, - desc->x_desc, - static_cast(&beta), - out_ptr, - desc->out_desc, - out_ptr, - desc->out_desc, - desc->algo, - workspace->ptr(), - workspace_size, - ctx.stream())); - } - - static void SearchBestAlgo(const phi::GPUContext& ctx, - const cublasLtHandle_t& lt_handle, - MatmulDescT* desc, - const void* alpha, - const void* beta, - const void* y_data, - const void* x_data, - void* out_data, - void* workspace_ptr, - size_t workspace_size) { - cublasLtMatmulPreference_t preference; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulPreferenceCreate(&preference)); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulPreferenceSetAttribute( - preference, - CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &workspace_size, - sizeof(workspace_size))); - - int returned_results = 0; - constexpr int requested_algo_count = 10; - std::vector heuristic_results( - requested_algo_count); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoGetHeuristic(lt_handle, - desc->op_desc, - desc->y_desc, - desc->x_desc, - desc->out_desc, - 
desc->out_desc, - preference, - requested_algo_count, - heuristic_results.data(), - &returned_results)); - PADDLE_ENFORCE_GT(returned_results, - 0, - phi::errors::Unavailable("No GEMM algorithm avaliable.")); - int best_algo_idx = -1; - if (returned_results == 1 || FLAGS_cublaslt_exhaustive_search_times <= 0) { - best_algo_idx = 0; - } else { - float min_time_cost = std::numeric_limits::max(); - for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) { - float cur_time_cost = - RunAndMeasureAlgo(ctx, - lt_handle, - desc, - alpha, - beta, - y_data, - x_data, - out_data, - workspace_ptr, - workspace_size, - &(heuristic_results[algo_idx].algo)); - VLOG(6) << "[MatmulWithCublaslt] algo[" << algo_idx - << "] time: " << cur_time_cost << " s"; - - if ((best_algo_idx == 0 && (1.05 * cur_time_cost < min_time_cost)) || - (cur_time_cost < min_time_cost)) { - best_algo_idx = algo_idx; - min_time_cost = cur_time_cost; - } - } - } - VLOG(6) << "[MatmulWithCublaslt] best_algo_idx: " << best_algo_idx; - - cublasLtMatmulAlgo_t* best_algo = desc->SetAlgo(); - *best_algo = heuristic_results[best_algo_idx].algo; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulPreferenceDestroy(preference)); - } - - static float RunAndMeasureAlgo(const phi::GPUContext& ctx, - const cublasLtHandle_t& lt_handle, - MatmulDescT* desc, - const void* alpha, - const void* beta, - const void* y_data, - const void* x_data, - void* out_data, - void* workspace_ptr, - size_t workspace_size, - cublasLtMatmulAlgo_t* algo) { - int repeats = FLAGS_cublaslt_exhaustive_search_times; - if (repeats <= 0) { - return std::numeric_limits::max(); - } - - phi::GpuTimer timer; - float time_cost = 0.f; - const auto& stream = ctx.stream(); - - for (int i = 0; i < repeats; ++i) { - timer.Start(stream); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmul(lt_handle, - desc->op_desc, - alpha, - y_data, - desc->y_desc, - x_data, - desc->x_desc, - beta, - out_data, - desc->out_desc, - out_data, - desc->out_desc, - algo, - workspace_ptr, - workspace_size, - stream)); - timer.Stop(stream); - ctx.Wait(); - auto time = timer.ElapsedTime(); - if (i > 0) { - // Exclude the warmup runtime. - time_cost += time; - } - } - return (time_cost / (repeats - 1)); - } -}; - -template <> -struct CublasLtBase { - public: - static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx, - size_t workspace_size) { - return phi::memory_utils::Alloc( - ctx.GetPlace(), - workspace_size, - phi::Stream(reinterpret_cast(ctx.stream()))); - } - - static void RunImpl(const phi::GPUContext& ctx, - MatmulDescriptor* desc, - const size_t sub_key, - const int8_t* x_ptr, - const int8_t* y_ptr, - int32_t* out_ptr, - phi::funcs::MatmulPlanner* planner) { - int32_t alpha = 1; - int32_t beta = - planner->UseAddTo() ? 
static_cast(1) : static_cast(0); - cublasLtHandle_t cublaslt_handle = ctx.cublaslt_handle(); - - size_t workspace_size = static_cast(4) * 1024 * 1024; - phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size); - - if (planner != nullptr) { - if (phi::autotune::AutoTuneStatus::Instance().UseAutoTune() && - (!desc->is_cached)) { - SearchBestAlgo(ctx, - cublaslt_handle, - desc, - static_cast(&alpha), - static_cast(&beta), - y_ptr, - x_ptr, - out_ptr, - workspace->ptr(), - workspace_size); - MatmulDescriptor* best_desc = new MatmulDescriptor(*desc); - VLOG(6) << best_desc->GetDescResultString( - "[Searched CublasltDescriptor] "); - - auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul(); - cache.SetSubKey(sub_key, reinterpret_cast(best_desc)); - } - } - - VLOG(7) << desc->GetDescResultString("[Impl CublasltDescriptor] "); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmul(cublaslt_handle, - desc->op_desc, - static_cast(&alpha), - y_ptr, - desc->y_desc, - x_ptr, - desc->x_desc, - static_cast(&beta), - out_ptr, - desc->out_desc, - out_ptr, - desc->out_desc, - desc->algo, - workspace->ptr(), - workspace_size, - ctx.stream())); - } - - static void SearchBestAlgo(const phi::GPUContext& ctx, - const cublasLtHandle_t& lt_handle, - MatmulDescriptor* desc, - const void* alpha, - const void* beta, - const void* y_data, - const void* x_data, - void* out_data, - void* workspace_ptr, - size_t workspace_size) { - cublasLtMatmulPreference_t preference; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulPreferenceCreate(&preference)); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulPreferenceSetAttribute( - preference, - CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &workspace_size, - sizeof(workspace_size))); - - int returned_results = 0; - constexpr int requested_algo_count = 10; - std::vector heuristic_results( - requested_algo_count); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulAlgoGetHeuristic(lt_handle, - desc->op_desc, - desc->y_desc, - desc->x_desc, - desc->out_desc, - desc->out_desc, - preference, - requested_algo_count, - heuristic_results.data(), - &returned_results)); - PADDLE_ENFORCE_GT(returned_results, - 0, - phi::errors::Unavailable("No GEMM algorithm avaliable.")); - int best_algo_idx = -1; - if (returned_results == 1 || FLAGS_cublaslt_exhaustive_search_times <= 0) { - best_algo_idx = 0; - } else { - float min_time_cost = std::numeric_limits::max(); - for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) { - float cur_time_cost = - RunAndMeasureAlgo(ctx, - lt_handle, - desc, - alpha, - beta, - y_data, - x_data, - out_data, - workspace_ptr, - workspace_size, - &(heuristic_results[algo_idx].algo)); - VLOG(6) << "[MatmulWithCublaslt] algo[" << algo_idx - << "] time: " << cur_time_cost << " s"; - - if ((best_algo_idx == 0 && (1.05 * cur_time_cost < min_time_cost)) || - (cur_time_cost < min_time_cost)) { - best_algo_idx = algo_idx; - min_time_cost = cur_time_cost; - } - } - } - VLOG(6) << "[MatmulWithCublaslt] best_algo_idx: " << best_algo_idx; - - cublasLtMatmulAlgo_t* best_algo = desc->SetAlgo(); - *best_algo = heuristic_results[best_algo_idx].algo; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulPreferenceDestroy(preference)); - } - - static float RunAndMeasureAlgo(const phi::GPUContext& ctx, - const cublasLtHandle_t& lt_handle, - MatmulDescriptor* desc, - const void* alpha, - const void* beta, - const void* y_data, - const void* x_data, - void* out_data, - void* workspace_ptr, - size_t workspace_size, - 
cublasLtMatmulAlgo_t* algo) { - int repeats = FLAGS_cublaslt_exhaustive_search_times; - if (repeats <= 0) { - return std::numeric_limits::max(); - } - - phi::GpuTimer timer; - float time_cost = 0.f; - const auto& stream = ctx.stream(); - - for (int i = 0; i < repeats; ++i) { - timer.Start(stream); - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmul(lt_handle, - desc->op_desc, - alpha, - y_data, - desc->y_desc, - x_data, - desc->x_desc, - beta, - out_data, - desc->out_desc, - out_data, - desc->out_desc, - algo, - workspace_ptr, - workspace_size, - stream)); - timer.Stop(stream); - ctx.Wait(); - auto time = timer.ElapsedTime(); - if (i > 0) { - // Exclude the warmup runtime. - time_cost += time; - } - } - return (time_cost / (repeats - 1)); - } -}; - -// To judge if desc is cached or not. -template -struct DescriptorSetter { - public: - DescT desc; - size_t sub_key{std::numeric_limits::min()}; - - DescriptorSetter(phi::funcs::MatmulPlanner* planner, - const int64_t M, - const int64_t N, - const int64_t K, - const bool trans_x, - const bool trans_y, - const int batch_size = 1, - int64_t stride_x = 0, - int64_t stride_y = 0, - int64_t stride_out = 0, - const bool no_exchange = true, - bool grad_for_dx = true) { - if (std::is_same::value) { - if (!trans_x && !trans_y) { - PADDLE_ENFORCE_EQ( - (N % 4 == 0 || N == 1), - true, - phi::errors::InvalidArgument( - "The dimension size N used in int8 matmul must be 1 or a " - "multiple of 4 does not " - "match the size (%d) currently contained in the container.", - N)); - PADDLE_ENFORCE_EQ( - (K % 4 == 0), - true, - phi::errors::InvalidArgument( - "The dimension size K used in int8 matmul must be a multiple " - "of 4 does not " - "match the size (%d) currently contained in the container.", - K)); - } else if (!trans_x && trans_y) { - PADDLE_ENFORCE_EQ( - (K % 4 == 0), - true, - phi::errors::InvalidArgument( - "The dimension size K used in int8 matmul must be a multiple " - "of 4 does not " - "match the size (%d) currently contained in the container.", - K)); - } else if (trans_x && !trans_y) { - PADDLE_ENFORCE_EQ( - (M % 4 == 0 || M == 1), - true, - phi::errors::InvalidArgument( - "The dimension size M used in int8 matmul must be 1 or a " - "multiple of 4 does not " - "match the size (%d) currently contained in the container.", - M)); - PADDLE_ENFORCE_EQ( - (N % 4 == 0 || N == 1), - true, - phi::errors::InvalidArgument( - "The dimension size N used in int8 matmul must be 1 or a " - "multiple of 4 does not " - "match the size (%d) currently contained in the container.", - N)); - } else { - PADDLE_ENFORCE_EQ( - (M % 4 == 0 || M == 1), - true, - phi::errors::InvalidArgument( - "The dimension size M used in int8 matmul must be 1 or a " - "multiple of 4 does not " - "match the size (%d) currently contained in the container.", - M)); - PADDLE_ENFORCE_EQ( - (K % 4 == 0), - true, - phi::errors::InvalidArgument( - "The dimension size K used in int8 matmul must be a multiple " - "of 4 does not " - "match the size (%d) currently contained in the container.", - K)); - } - } - - if (planner != nullptr) { - sub_key = planner->GenSubKey(); - } - - auto& mamtul_cache = phi::autotune::AutoTuneCache::Instance().GetMatmul(); - if (mamtul_cache.FindSubKey(sub_key)) { - desc = *(reinterpret_cast(mamtul_cache.GetSubKey(sub_key))); - desc.template SetFusedEpiloguePtr(planner); - VLOG(7) << desc.GetDescResultString("[Heap CublasltDescriptor] "); - } else { - desc.template Create(M, - N, - K, - trans_x, - trans_y, - planner, - batch_size, - stride_x, - stride_y, - stride_out, 
- grad_for_dx); - desc.ExchangeXYDesc(no_exchange); - if (planner != nullptr) { - desc.template SetFusedEpiloguePtr(planner); - } - VLOG(7) << desc.GetDescResultString("[Stack CublasltDescriptor] ", false); - } - } -}; - -// For matmul with kernels autotune -template -struct MatmulWithCublasLt : public CublasLtBase { - public: - static void Run(const phi::GPUContext& ctx, - const T* x_data, - const T* y_data, - OutT* out_data, - const int64_t M, - const int64_t N, - const int64_t K, - const bool trans_x, - const bool trans_y, - phi::funcs::MatmulPlanner* planner = nullptr) { - auto setter = DescriptorSetter( - planner, M, N, K, trans_x, trans_y); - CublasLtBase::RunImpl( - ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner); - } - - static void RunWithBatch(const phi::GPUContext& ctx, - const T* x_data, - const T* y_data, - OutT* out_data, - const int64_t M, - const int64_t N, - const int64_t K, - bool trans_x, - bool trans_y, - int batch_size, - int64_t stride_x, - int64_t stride_y, - int64_t stride_out, - phi::funcs::MatmulPlanner* planner = nullptr) { - auto setter = DescriptorSetter(planner, - M, - N, - K, - trans_x, - trans_y, - batch_size, - stride_x, - stride_y, - stride_out); - CublasLtBase::RunImpl( - ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner); - } - - static void RunWithBatch(const phi::GPUContext& ctx, - const T** x_data, - const T** y_data, - OutT** out_data, - const int64_t M, - const int64_t N, - const int64_t K, - bool trans_x, - bool trans_y, - int batch_size, - phi::funcs::MatmulPlanner* planner = nullptr) { - for (int i = 0; i < batch_size; ++i) { - Run(ctx, - x_data[i], - y_data[i], - out_data[i], - M, - N, - K, - trans_x, - trans_y, - planner); - } - } -}; - -// As for just Linear fused ephilogue below: out = matmul(x, y) + bias. -template -struct LinearWithCublasLt : public CublasLtBase { - static void Run(const phi::GPUContext& ctx, - const phi::DenseTensor* x, - const phi::DenseTensor* y, - phi::DenseTensor* out, - const void* bias_data, - void* reserve_data, - const int64_t M, - const int64_t N, - const int64_t K, - const bool trans_x, - const bool trans_y, - const MatmulFusedType fused_type) { - auto planner = phi::funcs::MatmulPlanner(common::vectorize(x->dims()), - common::vectorize(y->dims()), - trans_x, - trans_y, - phi::CppTypeToDataType::Type(), - fused_type, - bias_data, - reserve_data); - auto setter = DescriptorSetter( - &planner, M, N, K, trans_x, trans_y); - CublasLtBase::RunImpl(ctx, - &setter.desc, - setter.sub_key, - x->data(), - y->data(), - out->data(), - &planner); - } -}; - -template -struct LinearGradWithCublasLt : public CublasLtBase { - static void Run( - const phi::GPUContext& ctx, - const phi::DenseTensor* x, - const phi::DenseTensor* y, - phi::DenseTensor* out, - const void* bias_data, - void* reserve_data, - const int64_t M, - const int64_t N, - const int64_t K, - const MatmulFusedType fused_type, - const bool trans_x, - const bool trans_y, - const bool use_addto, - const bool no_exchange, // exchange x_desc and y_desc for grad. 
- bool grad_for_dx = true) { - auto planner = phi::funcs::MatmulPlanner(common::vectorize(x->dims()), - common::vectorize(y->dims()), - trans_x, - trans_y, - phi::CppTypeToDataType::Type(), - fused_type, - bias_data, - reserve_data, - use_addto, - no_exchange); - auto setter = - DescriptorSetter( - &planner, - M, - N, - K, - trans_x, - trans_y, - /*batch_size=*/1, - /*stride_x=*/0, - /*stride_y=*/0, - /*stride_out=*/0, - /*exchange_x_y_desc=*/no_exchange, - /*grad_for_dx=*/grad_for_dx); - - // To setting data type for different kinda out_data. - if (grad_for_dx) { - CublasLtBase::RunImpl( - ctx, - &setter.desc, - setter.sub_key, - no_exchange ? x->data() : y->data(), - no_exchange ? y->data() : x->data(), - out->data(), - &planner); - } else { - CublasLtBase::RunImpl( - ctx, - &setter.desc, - setter.sub_key, - no_exchange ? x->data() : y->data(), - no_exchange ? y->data() : x->data(), - out->data(), - &planner); - } - } -}; -#else -// A void structure just for successfully compile. -struct MatmulPlanner {}; -#endif // (PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 && 0 - -} // namespace funcs -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/cublas.cc b/backends/metax_gpu/kernels/funcs/blas/cublas.cc deleted file mode 100644 index 77a0cced00b..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/cublas.cc +++ /dev/null @@ -1,40 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "cublas.h" // NOLINT - -namespace phi { -namespace dynload { -std::once_flag cublas_dso_flag; -void *cublas_dso_handle = nullptr; - -#define DEFINE_WRAP(__name) DynLoad__##__name __name - -CUBLAS_BLAS_ROUTINE_EACH(DEFINE_WRAP); - -#ifdef CUBLAS_BLAS_ROUTINE_EACH_R2 -CUBLAS_BLAS_ROUTINE_EACH_R2(DEFINE_WRAP); -#endif - -#ifdef CUBLAS_BLAS_ROUTINE_EACH_R3 -CUBLAS_BLAS_ROUTINE_EACH_R3(DEFINE_WRAP); -#endif - -#ifdef CUBLAS_BLAS_ROUTINE_EACH_R4 -CUBLAS_BLAS_ROUTINE_EACH_R4(DEFINE_WRAP); -#endif -} // namespace dynload -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/cublas.h b/backends/metax_gpu/kernels/funcs/blas/cublas.h deleted file mode 100755 index 776c7a1723b..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/cublas.h +++ /dev/null @@ -1,148 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. */ -// clang-format off -#pragma once -#include -#include - -#include // NOLINT -#include - -#include "kernels/dynload/dynamic_loader.h" -#include "./port.h" // NOLINT -// clang-format on -namespace phi { -namespace dynload { - -extern std::once_flag cublas_dso_flag; -extern void* cublas_dso_handle; - -/** - * The following macro definition can generate structs - * (for each function) to dynamic load cublas routine - * via operator overloading. - * - * note: default dynamic linked libs - */ -#define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - inline auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ - using cublas_func = \ - decltype(::__name(std::declval()...)) (*)(Args...); \ - std::call_once(cublas_dso_flag, []() { \ - cublas_dso_handle = phi::dynload::GetCublasDsoHandle(); \ - }); \ - std::string replaced_name = #__name; \ - replaced_name = replaced_name.replace(0, 2, "mc"); \ - int index = replaced_name.find("_", 0); \ - if (index != -1) replaced_name = replaced_name.substr(0, index); \ - static void* p_##__name = \ - dlsym(cublas_dso_handle, replaced_name.c_str()); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ - }; \ - extern DynLoad__##__name __name - -#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ - __macro(cublasSaxpy_v2); \ - __macro(cublasDaxpy_v2); \ - __macro(cublasCaxpy_v2); \ - __macro(cublasZaxpy_v2); \ - __macro(cublasSscal_v2); \ - __macro(cublasDscal_v2); \ - __macro(cublasScopy_v2); \ - __macro(cublasDcopy_v2); \ - __macro(cublasSgemv_v2); \ - __macro(cublasDgemv_v2); \ - __macro(cublasCgemv_v2); \ - __macro(cublasZgemv_v2); \ - __macro(cublasSgemm_v2); \ - __macro(cublasDgemm_v2); \ - __macro(cublasCgemm_v2); \ - __macro(cublasZgemm_v2); \ - __macro(cublasHgemm); \ - __macro(cublasSgemmEx); \ - __macro(cublasSgeam); \ - __macro(cublasDgeam); \ - __macro(cublasStrsm_v2); \ - __macro(cublasDtrsm_v2); \ - __macro(cublasCtrsm_v2); \ - __macro(cublasZtrsm_v2); \ - __macro(cublasCreate_v2); \ - __macro(cublasDestroy_v2); \ - __macro(cublasSetStream_v2); \ - __macro(cublasSetPointerMode_v2); \ - __macro(cublasGetPointerMode_v2); \ - __macro(cublasSgemmBatched); \ - __macro(cublasDgemmBatched); \ - __macro(cublasCgemmBatched); \ - __macro(cublasZgemmBatched); \ - __macro(cublasStrsmBatched); \ - __macro(cublasDtrsmBatched); \ - __macro(cublasCtrsmBatched); \ - __macro(cublasZtrsmBatched); \ - __macro(cublasSgetrfBatched); \ - __macro(cublasSgetriBatched); \ - __macro(cublasDgetrfBatched); \ - __macro(cublasDgetriBatched); \ - __macro(cublasSmatinvBatched); \ - __macro(cublasDmatinvBatched); \ - __macro(cublasSgetrsBatched); \ - __macro(cublasDgetrsBatched); \ - __macro(cublasCgetrfBatched); \ - __macro(cublasCgetriBatched); \ - __macro(cublasCmatinvBatched); \ - __macro(cublasZgetrfBatched); \ - __macro(cublasZgetriBatched); \ - __macro(cublasZmatinvBatched); - -CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) - -// APIs available after CUDA 8.0 -#if CUDA_VERSION >= 8000 -#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) \ - __macro(cublasGemmEx); \ - __macro(cublasSgemmStridedBatched); \ - __macro(cublasDgemmStridedBatched); \ - __macro(cublasCgemmStridedBatched); \ - __macro(cublasZgemmStridedBatched); \ - __macro(cublasHgemmStridedBatched); - -CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) -#endif - -// APIs available after CUDA 9.0 -#if CUDA_VERSION >= 9000 -#define 
CUBLAS_BLAS_ROUTINE_EACH_R3(__macro) \ - __macro(cublasSetMathMode); \ - __macro(cublasGetMathMode); - -CUBLAS_BLAS_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) -#endif - -// APIs available after CUDA 9.1 -#if CUDA_VERSION >= 9010 -#define CUBLAS_BLAS_ROUTINE_EACH_R4(__macro) \ - __macro(cublasGemmBatchedEx); \ - __macro(cublasGemmStridedBatchedEx); - -CUBLAS_BLAS_ROUTINE_EACH_R4(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) -#endif - -#undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP -} // namespace dynload -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/cublasLt.cc b/backends/metax_gpu/kernels/funcs/blas/cublasLt.cc deleted file mode 100644 index 776f7fdd812..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/cublasLt.cc +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "cublasLt.h" - -namespace phi { -namespace dynload { -std::once_flag cublasLt_dso_flag; -void *cublasLt_dso_handle = nullptr; - -#define DEFINE_WRAP(__name) DynLoad__##__name __name - -CUBLASLT_BLAS_ROUTINE_EACH(DEFINE_WRAP); - -} // namespace dynload -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/cublasLt.h b/backends/metax_gpu/kernels/funcs/blas/cublasLt.h deleted file mode 100644 index 2f8a929dd0c..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/cublasLt.h +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include - -#include // NOLINT -#include - -#include "./port.h" -#include "kernels/dynload/dynamic_loader.h" - -namespace phi { -namespace dynload { - -extern std::once_flag cublasLt_dso_flag; -extern void* cublasLt_dso_handle; - -/** - * The following macro definition can generate structs - * (for each function) to dynamic load cublasLt routine - * via operator overloading. - * - * note: default dynamic linked libs - */ -#define DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - inline auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) 
{ \ - using cublasLt_func = \ - decltype(::__name(std::declval()...)) (*)(Args...); \ - std::call_once(cublasLt_dso_flag, []() { \ - cublasLt_dso_handle = phi::dynload::GetCublasLtDsoHandle(); \ - }); \ - std::string replaced_name = #__name; \ - replaced_name = replaced_name.replace(0, 2, "mc"); \ - static void* p_##__name = \ - dlsym(cublasLt_dso_handle, replaced_name.c_str()); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ - }; \ - extern DynLoad__##__name __name - -// APIs available after CUDA 11.1 -#if CUDA_VERSION >= 11010 -#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ - __macro(cublasLtCreate); \ - __macro(cublasLtDestroy); \ - __macro(cublasLtMatmul); \ - __macro(cublasLtMatmulDescCreate); \ - __macro(cublasLtMatmulDescDestroy); \ - __macro(cublasLtMatmulDescSetAttribute); \ - __macro(cublasLtMatmulDescGetAttribute); \ - __macro(cublasLtMatrixLayoutCreate); \ - __macro(cublasLtMatrixLayoutDestroy); \ - __macro(cublasLtMatrixLayoutSetAttribute); \ - __macro(cublasLtMatrixLayoutGetAttribute); \ - __macro(cublasLtMatmulPreferenceCreate); \ - __macro(cublasLtMatmulPreferenceDestroy); \ - __macro(cublasLtMatmulPreferenceSetAttribute); \ - __macro(cublasLtMatmulAlgoGetHeuristic); \ - __macro(cublasLtMatrixTransform); \ - __macro(cublasLtMatrixTransformDescCreate); \ - __macro(cublasLtMatrixTransformDescDestroy); \ - __macro(cublasLtMatrixTransformDescSetAttribute); \ - __macro(cublasLtMatmulAlgoInit); \ - __macro(cublasLtMatmulAlgoConfigSetAttribute); \ - __macro(cublasLtMatmulAlgoConfigGetAttribute); \ - __macro(cublasLtMatmulAlgoGetIds); \ - __macro(cublasLtMatmulAlgoCapGetAttribute); \ - __macro(cublasLtMatmulAlgoCheck); -// __macro(cublasLtGetCudartVersion); -#else -#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ - __macro(cublasLtCreate); \ - __macro(cublasLtDestroy); \ - __macro(cublasLtMatmul); \ - __macro(cublasLtMatmulDescCreate); \ - __macro(cublasLtMatmulDescDestroy); \ - __macro(cublasLtMatmulDescSetAttribute); \ - __macro(cublasLtMatmulDescGetAttribute); \ - __macro(cublasLtMatrixLayoutCreate); \ - __macro(cublasLtMatrixLayoutDestroy); \ - __macro(cublasLtMatrixLayoutSetAttribute); \ - __macro(cublasLtMatrixLayoutGetAttribute); \ - __macro(cublasLtMatmulPreferenceCreate); \ - __macro(cublasLtMatmulPreferenceDestroy); \ - __macro(cublasLtMatmulPreferenceSetAttribute); \ - __macro(cublasLtMatmulAlgoGetHeuristic); \ - __macro(cublasLtMatrixTransform); \ - __macro(cublasLtMatrixTransformDescCreate); \ - __macro(cublasLtMatrixTransformDescDestroy); \ - __macro(cublasLtMatrixTransformDescSetAttribute); -#endif - -CUBLASLT_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) -// #endif - -#undef DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP -} // namespace dynload -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/cublaslt.h b/backends/metax_gpu/kernels/funcs/blas/cublaslt.h deleted file mode 100755 index 24505567baf..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/cublaslt.h +++ /dev/null @@ -1,328 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include - -#include "./cublasLt.h" -#include "paddle/phi/common/float8_e4m3fn.h" -#include "paddle/phi/core/dense_tensor.h" - -namespace dyl = phi::dynload; - -namespace phi { - -struct CublasLtAlgoParam { - int algoId; - int swizzle; - int customOption; - int tile; - int splitK_val; - int reductionScheme; - int stages; - size_t workspace_size; -}; - -const std::map, CublasLtAlgoParam> AlgoParamCache{}; - -class CublasLtHelper { - public: - CublasLtHelper(int m, int k, int n, cublasLtHandle_t handle) - : handle_(handle), alpha_(1), beta_(0), m_(m), k_(k), n_(n) { - cublasStatus_t status; - - cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; - - // matmul desc - status = dyl::cublasLtMatmulDescCreate( - &matmul_desc_, cudaComputeType, CUDA_R_32I); - - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - common::errors::External( - "cublasLtMatmulDescCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); - cublasOperation_t op_transpose = CUBLAS_OP_T; - status = dyl::cublasLtMatmulDescSetAttribute(matmul_desc_, - CUBLASLT_MATMUL_DESC_TRANSA, - &op_transpose, - sizeof(op_transpose)); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - common::errors::External( - "cublasLtMatmulDescSetAttribute execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); - - // matrix desc - status = dyl::cublasLtMatrixLayoutCreate(&B_desc_, CUDA_R_8I, k, n, k); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - common::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); - - status = dyl::cublasLtMatrixLayoutCreate(&A_desc_, CUDA_R_8I, k, m, k); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - common::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); - - status = dyl::cublasLtMatrixLayoutCreate(&C_desc_, CUDA_R_32I, n, m, n); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - common::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); - -#if CUDA_VERSION >= 11020 - - int algoId = 21; - int swizzle = 0; - int customOption = 0; - int tile = 15; - int splitK_val = 0; - int reductionScheme = 0; - int stages = 23; - workspace_size_ = 0; - if (m >= 128) { - tile = 20; - stages = 17; - } - - std::tuple key(m_, k_, n_); - if (AlgoParamCache.count(key) != 0) { - auto value = AlgoParamCache.at(key); - algoId = value.algoId; - swizzle = value.swizzle; - customOption = value.customOption; - tile = value.tile; - splitK_val = value.splitK_val; - reductionScheme = value.reductionScheme; - stages = value.stages; - workspace_size_ = value.workspace_size; - } - - dyl::cublasLtMatmulAlgoInit(handle_, - cudaComputeType, - CUDA_R_32I, - CUDA_R_8I, - CUDA_R_8I, - CUDA_R_32I, - CUDA_R_32I, - algoId, - &algo_); - dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo_, - CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, - &(customOption), - sizeof(customOption)); - dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo_, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile)); - dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo_, - CUBLASLT_ALGO_CONFIG_SPLITK_NUM, - 
&(splitK_val), - sizeof(splitK_val)); - dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo_, - CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, - &(swizzle), - sizeof(swizzle)); - dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo_, - CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, - &(reductionScheme), - sizeof(int)); -#if CUDA_VERSION >= 11000 - dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo_, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); -#endif -#endif - } - ~CublasLtHelper() {} - - void GEMM(const int8_t* A_dev, - const int8_t* B_dev, - int32_t* C_dev, - cudaStream_t stream, - void* workspace = nullptr) { - cublasStatus_t status; - - status = dyl::cublasLtMatmul(handle_, - matmul_desc_, - &alpha_, - B_dev, - B_desc_, - A_dev, - A_desc_, - &beta_, - C_dev, - C_desc_, - C_dev, - C_desc_, -#if CUDA_VERSION >= 11020 - &algo_, - workspace, - workspace_size_, -#else - nullptr, - nullptr, - 0, -#endif - stream); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - common::errors::External( - "cublasLtMatmul execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); - } - - private: - cublasLtHandle_t handle_; - cublasLtMatmulDesc_t matmul_desc_; - cublasLtMatrixLayout_t A_desc_; - cublasLtMatrixLayout_t B_desc_; - cublasLtMatrixLayout_t C_desc_; - - cublasLtMatmulAlgo_t algo_; - - int32_t alpha_ = 1; - int32_t beta_ = 0; - - int m_ = 0; - int k_ = 0; - int n_ = 0; - - size_t workspace_size_ = 0; -}; - -template -inline cudaDataType_t GetCublasLtDataType() { - return CUDA_R_32F; -} - -template <> -inline cudaDataType_t GetCublasLtDataType() { - return CUDA_R_16F; -} - -template <> -inline cudaDataType_t GetCublasLtDataType() { - return CUDA_R_16BF; -} - -#if CUDA_VERSION >= 12010 -template -void CublasLtMatmulFP8(const phi::GPUContext& dev_ctx, - const phi::DenseTensor& mat_a, - const phi::DenseTensor& mat_b, - phi::DenseTensor* workspace, - phi::DenseTensor* out) { - int m = mat_a.dims()[0]; - int k = mat_a.dims()[1]; - int n = mat_b.dims()[1]; - - // init data structure - cublasStatus_t status; - auto A_type = CUDA_R_8F_E4M3; - auto B_type = CUDA_R_8F_E4M3; - auto C_type = GetCublasLtDataType(); - - cublasLtMatmulDesc_t matmul_desc_; - cublasLtMatrixLayout_t A_desc_; - cublasLtMatrixLayout_t B_desc_; - cublasLtMatrixLayout_t C_desc_; - float alpha_ = 1.0f; - float beta_ = 0.0f; - - cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32F; - status = - dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType, CUDA_R_32F); - cublasOperation_t op_transpose = CUBLAS_OP_T; - status = dyl::cublasLtMatmulDescSetAttribute(matmul_desc_, - CUBLASLT_MATMUL_DESC_TRANSA, - &op_transpose, - sizeof(op_transpose)); - status = dyl::cublasLtMatrixLayoutCreate(&B_desc_, B_type, k, n, k); - status = dyl::cublasLtMatrixLayoutCreate(&A_desc_, A_type, k, m, k); - status = dyl::cublasLtMatrixLayoutCreate(&C_desc_, C_type, n, m, n); - - // Need to use heuristic - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtMatmulPreference_t preference = NULL; - size_t work_space_size = workspace->numel(); - - status = dyl::cublasLtMatmulPreferenceCreate(&preference); - status = dyl::cublasLtMatmulPreferenceSetAttribute( - preference, - CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &work_space_size, - sizeof(work_space_size)); - - status = dyl::cublasLtMatmulAlgoGetHeuristic(dev_ctx.cublaslt_handle(), - matmul_desc_, - B_desc_, - A_desc_, - C_desc_, - C_desc_, - preference, - 1, - &heuristicResult, - &returnedResults); - - 
PADDLE_ENFORCE_NE(returnedResults, - 0, - common::errors::NotFound( - "Unable to find suitable cuBLAS GEMM algorithm")); - - status = - dyl::cublasLtMatmul(dev_ctx.cublaslt_handle(), - matmul_desc_, - &alpha_, - mat_b.data(), - B_desc_, - mat_a.data(), - A_desc_, - &beta_, - out->data(), - C_desc_, - out->data(), - C_desc_, - // nullptr, - &heuristicResult.algo, - // nullptr, - reinterpret_cast(workspace->data()), - // 0, - work_space_size, - dev_ctx.stream()); -} -#endif - -} // namespace phi diff --git a/backends/metax_gpu/kernels/funcs/blas/port.cc b/backends/metax_gpu/kernels/funcs/blas/port.cc deleted file mode 100644 index bc6d54e5c5f..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/port.cc +++ /dev/null @@ -1,163 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// clang-format off -#include "port.h" // NOLINT - -#include -#include -#include -#include -#include "glog/logging.h" -#if !defined(_WIN32) -#include // dladdr -#include -#include - -#else -#include // std::accumulate in msvc -// clang-format on -void *dlsym(void *handle, const char *symbol_name) { - FARPROC found_symbol; - found_symbol = GetProcAddress((HMODULE)handle, symbol_name); - - if (found_symbol == NULL) { - LOG(ERROR) << "Load symbol " << symbol_name << " failed."; - throw std::runtime_error(std::string(symbol_name) + " not found."); - } - return reinterpret_cast(found_symbol); -} - -void *dlopen(const char *filename, int flag) { - std::string file_name(filename); - HMODULE hModule = LoadLibrary(file_name.c_str()); - if (!hModule) { - if (flag) { - throw std::runtime_error(file_name + " not found."); - } else { - return nullptr; - } - } - return reinterpret_cast(hModule); -} - -int gettimeofday(struct timeval *tp, void *tzp) { - time_t clock; - struct tm tm; - SYSTEMTIME wtm; - - GetLocalTime(&wtm); - tm.tm_year = wtm.wYear - 1900; - tm.tm_mon = wtm.wMonth - 1; - tm.tm_mday = wtm.wDay; - tm.tm_hour = wtm.wHour; - tm.tm_min = wtm.wMinute; - tm.tm_sec = wtm.wSecond; - tm.tm_isdst = -1; - clock = mktime(&tm); - tp->tv_sec = clock; - tp->tv_usec = wtm.wMilliseconds * 1000; - - return (0); -} -#endif // !_WIN32 - -void ExecShellCommand(const std::string &cmd, std::string *message) { - std::array buffer; -#if !defined(_WIN32) - std::shared_ptr pipe(popen(cmd.c_str(), "r"), pclose); -#else - std::shared_ptr pipe(_popen(cmd.c_str(), "r"), _pclose); -#endif // _WIN32 - if (!pipe) { - LOG(ERROR) << "error running command: " << cmd; - return; - } - while (!feof(pipe.get())) { - if (fgets(buffer.data(), 128, pipe.get()) != nullptr) { - *message += buffer.data(); - } - } -} - -bool PathExists(const std::string &path) { -#if !defined(_WIN32) - struct stat statbuf; - if (stat(path.c_str(), &statbuf) != -1) { - if (S_ISDIR(statbuf.st_mode)) { - return true; - } - } -#else - struct _stat statbuf; - if (_stat(path.c_str(), &statbuf) != -1) { - if 
(S_ISDIR(statbuf.st_mode)) { - return true; - } - } -#endif // !_WIN32 - return false; -} - -#if !defined(_WIN32) -constexpr char kSEP = '/'; -#else -constexpr char kSEP = '\\'; -#endif // _WIN32 - -bool FileExists(const std::string &filepath) { -#if !defined(_WIN32) - struct stat buffer; - return (stat(filepath.c_str(), &buffer) == 0); -#else - struct _stat buffer; - return (_stat(filepath.c_str(), &buffer) == 0); -#endif // !_WIN32 -} - -std::string DirName(const std::string &filepath) { - auto pos = filepath.rfind(kSEP); - if (pos == std::string::npos) { - return ""; - } - return filepath.substr(0, pos); -} - -void MkDir(const char *path) { - std::string path_error(path); - path_error += " mkdir failed!"; -#if !defined(_WIN32) - if (mkdir(path, 0755)) { - if (errno != EEXIST) { - throw std::runtime_error(path_error); - } - } -#else - BOOL return_value = CreateDirectory(path, NULL); - if (!return_value) { - auto errorno = GetLastError(); - if (errorno != ERROR_ALREADY_EXISTS) { - throw std::runtime_error(path_error); - } - } -#endif // !_WIN32 -} - -void MkDirRecursively(const char *fullpath) { - if (*fullpath == '\0') return; // empty string - if (FileExists(fullpath)) return; - - MkDirRecursively(DirName(fullpath).c_str()); - MkDir(fullpath); -} diff --git a/backends/metax_gpu/kernels/funcs/blas/port.h b/backends/metax_gpu/kernels/funcs/blas/port.h deleted file mode 100644 index d2a59199bb7..00000000000 --- a/backends/metax_gpu/kernels/funcs/blas/port.h +++ /dev/null @@ -1,61 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h - -#if !defined(_WIN32) -#include // dladdr -#include - -#else -#ifndef NOMINMAX -#define NOMINMAX // msvc max/min macro conflict with std::min/max -#endif -// solve static linking error in windows -// https://github.com/google/glog/issues/301 -#define GOOGLE_GLOG_DLL_DECL -#include // _popen, _pclose -#include -#include -#include - -#ifndef S_ISDIR // windows port for sys/stat.h -#define S_ISDIR(mode) (((mode)&S_IFMT) == S_IFDIR) -#endif // S_ISDIR - -void *dlsym(void *handle, const char *symbol_name); - -void *dlopen(const char *filename, int flag); - -int gettimeofday(struct timeval *tp, void *tzp); -#endif // !_WIN32 - -void ExecShellCommand(const std::string &cmd, std::string *message); - -bool PathExists(const std::string &path); - -// TODO(yuyang18): If the functions below are needed by other files, move them -// to paddle::filesystem namespace. 
-bool FileExists(const std::string &filepath); - -std::string DirName(const std::string &filepath); - -void MkDir(const char *path); - -void MkDirRecursively(const char *fullpath); diff --git a/backends/metax_gpu/kernels/funcs/layer_norm_util.h b/backends/metax_gpu/kernels/funcs/layer_norm_util.h index 3e16e615b1d..0f8210d8b8f 100644 --- a/backends/metax_gpu/kernels/funcs/layer_norm_util.h +++ b/backends/metax_gpu/kernels/funcs/layer_norm_util.h @@ -18,7 +18,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/device_context.h" -#include "../funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" // clang-format on namespace phi { diff --git a/backends/metax_gpu/kernels/funcs/quant_dequant.h b/backends/metax_gpu/kernels/funcs/quant_dequant.h deleted file mode 100644 index 301ae351c40..00000000000 --- a/backends/metax_gpu/kernels/funcs/quant_dequant.h +++ /dev/null @@ -1,430 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -// clang-format off -#include -#include "paddle/common/hostdevice.h" -#include "paddle/phi/backends/gpu/gpu_launch_config.h" -#include "paddle/phi/common/float16.h" -#include "paddle/phi/common/transform.h" -#include "paddle/phi/kernels/funcs/aligned_vector.h" -#include "blas/blas.h" -// clang-format on -namespace phi { - -using backends::gpu::GpuLaunchConfig; - -constexpr int DequantKernelVecSize = 4; - -template -inline HOSTDEVICE T roundWithTiesToEven(T x) { - T xLower = floor(x); - T xUpper = ceil(x); - // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to - // even. - T dLower = x - xLower; - T dUpper = xUpper - x; - return static_cast( - (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper) - ? xLower - : xUpper); -} - -template -inline HOSTDEVICE T roundWithTiesAwayFromZero(T x) { - return static_cast(x > 0 ? ceil(x) : floor(x)); -} - -template -__forceinline__ __device__ int8_t quant_helper(const T input, - const float scale, - const int round_type, - const float max_bound, - const float min_bound) { - float quant_value = max_bound * scale * static_cast(input); - - if (round_type == 0) { - quant_value = static_cast(roundWithTiesToEven(quant_value)); - } else { - quant_value = static_cast(round(quant_value)); - } - quant_value = quant_value > max_bound ? max_bound : quant_value; - quant_value = quant_value < min_bound ? 
min_bound : quant_value; - return static_cast(quant_value); -} - -template -__forceinline__ __device__ int8_t -quant_helper_ties_to_even_or_away_from_zero(const T input, - const float scale, - const int round_type, - const float max_bound, - const float min_bound) { - float quant_value = max_bound * scale * static_cast(input); - - if (round_type == 0) { - quant_value = static_cast(roundWithTiesToEven(quant_value)); - } else { - quant_value = static_cast(roundWithTiesAwayFromZero(quant_value)); - } - quant_value = quant_value > max_bound ? max_bound : quant_value; - quant_value = quant_value < min_bound ? min_bound : quant_value; - return static_cast(quant_value); -} - -template -__global__ void QuantKernel(const T* input, - char4* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; - int m_id = blockIdx.y * blockDim.y + threadIdx.y; - - bool check = ((m_id < m) && (n_id < n)); - if (check) { - char4 tmp; - tmp.x = quant_helper( - input[m_id * n + n_id], scale, round_type, max_bound, min_bound); - tmp.y = quant_helper( - input[m_id * n + n_id + 1], scale, round_type, max_bound, min_bound); - tmp.z = quant_helper( - input[m_id * n + n_id + 2], scale, round_type, max_bound, min_bound); - tmp.w = quant_helper( - input[m_id * n + n_id + 3], scale, round_type, max_bound, min_bound); - output[(m_id * n + n_id) >> 2] = tmp; - } -} - -template -__global__ void QuantKernelWithVecSize(const T* input, - char4* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; - int m_id = blockIdx.y * blockDim.y + threadIdx.y; - - bool check = ((m_id < m) && (n_id < n)); - if (check) { - char4 tmp; - tmp.x = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id], scale, round_type, max_bound, min_bound); - tmp.y = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id + 1], scale, round_type, max_bound, min_bound); - tmp.z = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id + 2], scale, round_type, max_bound, min_bound); - tmp.w = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id + 3], scale, round_type, max_bound, min_bound); - output[(m_id * n + n_id) >> 2] = tmp; - } -} - -template -__global__ void QuantKernelWithVecSize(const T* input, - char3* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x) * 3; - int m_id = blockIdx.y * blockDim.y + threadIdx.y; - - bool check = ((m_id < m) && (n_id < n)); - if (check) { - char3 tmp; - tmp.x = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id], scale, round_type, max_bound, min_bound); - tmp.y = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id + 1], scale, round_type, max_bound, min_bound); - tmp.z = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id + 2], scale, round_type, max_bound, min_bound); - output[(m_id * n + n_id) / 3] = tmp; - } -} - -template -__global__ void QuantKernelWithVecSize(const T* input, - char2* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x) * 2; - int m_id = blockIdx.y * 
blockDim.y + threadIdx.y; - - bool check = ((m_id < m) && (n_id < n)); - if (check) { - char2 tmp; - tmp.x = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id], scale, round_type, max_bound, min_bound); - tmp.y = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id + 1], scale, round_type, max_bound, min_bound); - output[(m_id * n + n_id) >> 1] = tmp; - } -} - -template -__global__ void QuantKernelWithVecSize(const T* input, - char* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x); - int m_id = blockIdx.y * blockDim.y + threadIdx.y; - - bool check = ((m_id < m) && (n_id < n)); - if (check) { - char tmp; - tmp = quant_helper_ties_to_even_or_away_from_zero( - input[m_id * n + n_id], scale, round_type, max_bound, min_bound); - output[m_id * n + n_id] = tmp; - } -} - -template -void LaunchQuantKernel(const T* input, - int8_t* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound, - gpuStream_t stream) { - // TODO(minghaoBD): optimize the kennel launch times when m==1 or n==1 -#ifdef PADDLE_WITH_HIP - dim3 grid(((n >> 2) + 63) / 64, (m + 7) / 8); - dim3 block(64, 8); -#else - dim3 grid(((n >> 2) + 31) / 32, (m + 31) / 32); - dim3 block(32, 32); -#endif - - QuantKernel<<>>(input, - (char4*)output, // NOLINT - scale, - m, - n, - round_type, - max_bound, - min_bound); -} - -template -void LaunchQuantKernelWithVecSize(const T* input, - int8_t* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound, - gpuStream_t stream) { - int vec_size = 1; - if (n % 4 == 0) { - vec_size = 4; - } else if (n % 3 == 0) { - vec_size = 3; - } else if (n % 2 == 0) { - vec_size = 2; - } - -#ifdef PADDLE_WITH_HIP - dim3 grid(((n / vec_size) + 63) / 64, (m + 7) / 8); - dim3 block(64, 8); -#else - dim3 grid(((n / vec_size) + 31) / 32, (m + 31) / 32); - dim3 block(32, 32); -#endif - - switch (vec_size) { - case 4: - QuantKernelWithVecSize<<>>( - input, - reinterpret_cast(output), - scale, - m, - n, - round_type, - max_bound, - min_bound); - break; - case 3: - QuantKernelWithVecSize<<>>( - input, - reinterpret_cast(output), - scale, - m, - n, - round_type, - max_bound, - min_bound); - break; - case 2: - QuantKernelWithVecSize<<>>( - input, - reinterpret_cast(output), - scale, - m, - n, - round_type, - max_bound, - min_bound); - break; - case 1: - QuantKernelWithVecSize<<>>( - input, - reinterpret_cast(output), - scale, - m, - n, - round_type, - max_bound, - min_bound); - break; - default: - return; - } -} - -template -__global__ void DequantKernel(T* output, - const int32_t* input, - const int m, // batch size - const int n, // hidden - const float quant_in_scale, - const float* dequant_out_scale_data) { - int numel = m * n; - int stride = blockDim.x * gridDim.x * VecSize; - int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; - int col_id = idx % n; - - phi::AlignedVector in_vec; - phi::AlignedVector out_scale_vec; - phi::AlignedVector out_vec; - - for (; idx < numel; idx += stride) { - phi::Load(input + idx, &in_vec); - phi::Load(dequant_out_scale_data + col_id, &out_scale_vec); - -#pragma unroll - for (int i = 0; i < VecSize; ++i) { - out_vec[i] = - static_cast(static_cast(in_vec[i]) * out_scale_vec[i]); - } - - phi::Store(out_vec, output + idx); - } -} - -template -void 
LaunchDequantKernel(const int32_t* input, - T* output, - const int m, // m - const int n, // n - gpuStream_t stream, - GpuLaunchConfig* gpu_config, - const float quant_in_scale, - const float* dequant_out_scale_data) { - DequantKernel - <<block_per_grid, gpu_config->thread_per_block, 0, stream>>>( - output, input, m, n, quant_in_scale, dequant_out_scale_data); -} - -template -__global__ void DequantKernelWithScaleOfInputAndWeight( - T* output, - const int32_t* input, - const int m, // batch size - const int n, // hidden - const float quant_in_scale, - const float* quant_weight_scale, - float quant_max_bound) { - int numel = m * n; - int stride = blockDim.x * gridDim.x * VecSize; - int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; - int col_id = idx % n; - - phi::AlignedVector in_vec; - phi::AlignedVector out_scale_vec; - phi::AlignedVector out_vec; - - for (; idx < numel; idx += stride) { - phi::Load(input + idx, &in_vec); - phi::Load(quant_weight_scale + col_id, &out_scale_vec); - -#pragma unroll - for (int i = 0; i < VecSize; ++i) { - out_vec[i] = static_cast(static_cast(in_vec[i]) / - (quant_max_bound * quant_max_bound * - quant_in_scale * out_scale_vec[i])); - } - - phi::Store(out_vec, output + idx); - } -} - -template -void LaunchDequantKernelWithScaleOfInputAndWeight( - const int32_t* input, - T* output, - const int m, // m - const int n, // n - gpuStream_t stream, - GpuLaunchConfig* gpu_config, - const float quant_in_scale, - const float* quant_weight_scale, - float quant_max_bound) { - if (n % DequantKernelVecSize != 0) { - DequantKernelWithScaleOfInputAndWeight<<block_per_grid, - gpu_config->thread_per_block, - 0, - stream>>>(output, - input, - m, - n, - quant_in_scale, - quant_weight_scale, - quant_max_bound); - return; - } - DequantKernelWithScaleOfInputAndWeight - <<block_per_grid, gpu_config->thread_per_block, 0, stream>>>( - output, - input, - m, - n, - quant_in_scale, - quant_weight_scale, - quant_max_bound); -} - -} // namespace phi diff --git a/backends/metax_gpu/kernels/gpudnn/cudnn.cc b/backends/metax_gpu/kernels/gpudnn/cudnn.cc deleted file mode 100644 index dc403282c1c..00000000000 --- a/backends/metax_gpu/kernels/gpudnn/cudnn.cc +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/phi/backends/dynload/cudnn.h" // NOLINT - -#include "paddle/phi/core/enforce.h" - -namespace phi::dynload { - -std::once_flag cudnn_dso_flag; -void* cudnn_dso_handle = nullptr; - -#define DEFINE_WRAP(__name) DynLoad__##__name __name - -CUDNN_DNN_ROUTINE_EACH(DEFINE_WRAP); - -#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R7_LESS_R8 -CUDNN_DNN_ROUTINE_EACH_AFTER_R7_LESS_R8(DEFINE_WRAP); -#endif - -#ifdef CUDNN_DNN_ROUTINE_EACH_R7 -CUDNN_DNN_ROUTINE_EACH_R7(DEFINE_WRAP); -#endif - -#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7 -CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(DEFINE_WRAP); -#endif - -#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R7 -CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DEFINE_WRAP); -#endif - -#ifdef CUDNN_DNN_ROUTINE_EACH_R8 -CUDNN_DNN_ROUTINE_EACH_R8(DEFINE_WRAP); -#endif - -#ifdef CUDNN_DNN_ROUTINE_EACH_FRONTEND -CUDNN_DNN_ROUTINE_EACH_FRONTEND(DEFINE_WRAP); -#endif - -#ifdef CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9 -CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9(DEFINE_WRAP); -#endif - -#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9 -CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9(DEFINE_WRAP); -#endif - -#ifdef CUDNN_DNN_ROUTINE_EACH_R9 -CUDNN_DNN_ROUTINE_EACH_R9(DEFINE_WRAP); -#endif - -bool HasCUDNN() { - std::call_once(cudnn_dso_flag, - []() { cudnn_dso_handle = GetCUDNNDsoHandle(); }); - return cudnn_dso_handle != nullptr; -} - -void EnforceCUDNNLoaded(const char* fn_name) { - PADDLE_ENFORCE_NOT_NULL( - cudnn_dso_handle, - common::errors::PreconditionNotMet( - "Cannot load cudnn shared library. Cannot invoke method %s.", - fn_name)); -} - -} // namespace phi::dynload diff --git a/backends/metax_gpu/kernels/gpudnn/cudnn.h b/backends/metax_gpu/kernels/gpudnn/cudnn.h deleted file mode 100644 index 65cb6b338b7..00000000000 --- a/backends/metax_gpu/kernels/gpudnn/cudnn.h +++ /dev/null @@ -1,218 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#ifdef PADDLE_WITH_CUDA -#include - -#include // NOLINT - -#include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/common/port.h" - -namespace phi { -namespace dynload { - -extern std::once_flag cudnn_dso_flag; -extern void* cudnn_dso_handle; -extern bool HasCUDNN(); - -extern void EnforceCUDNNLoaded(const char* fn_name); -#define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) 
{ \ - using cudnn_func = decltype(&::__name); \ - std::call_once(cudnn_dso_flag, []() { \ - cudnn_dso_handle = phi::dynload::GetCUDNNDsoHandle(); \ - }); \ - EnforceCUDNNLoaded(#__name); \ - std::string replaced_name = #__name; \ - replaced_name = replaced_name.replace(0, 2, "mc"); \ - static void* p_##__name = \ - dlsym(cudnn_dso_handle, replaced_name.c_str()); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ - }; \ - extern struct DynLoad__##__name __name - -/** - * include all needed cudnn functions in HPPL - * different cudnn version has different interfaces - **/ -#define CUDNN_DNN_ROUTINE_EACH(__macro) \ - __macro(cudnnSetTensor4dDescriptor); \ - __macro(cudnnSetTensor4dDescriptorEx); \ - __macro(cudnnSetTensorNdDescriptor); \ - __macro(cudnnGetTensorNdDescriptor); \ - __macro(cudnnGetConvolutionNdForwardOutputDim); \ - __macro(cudnnCreateTensorDescriptor); \ - __macro(cudnnDestroyTensorDescriptor); \ - __macro(cudnnCreateFilterDescriptor); \ - __macro(cudnnSetFilter4dDescriptor); \ - __macro(cudnnSetFilterNdDescriptor); \ - __macro(cudnnGetFilterNdDescriptor); \ - __macro(cudnnSetPooling2dDescriptor); \ - __macro(cudnnSetPoolingNdDescriptor); \ - __macro(cudnnGetPoolingNdDescriptor); \ - __macro(cudnnDestroyFilterDescriptor); \ - __macro(cudnnCreateConvolutionDescriptor); \ - __macro(cudnnCreatePoolingDescriptor); \ - __macro(cudnnDestroyPoolingDescriptor); \ - __macro(cudnnSetConvolution2dDescriptor); \ - __macro(cudnnDestroyConvolutionDescriptor); \ - __macro(cudnnSetConvolutionNdDescriptor); \ - __macro(cudnnGetConvolutionNdDescriptor); \ - __macro(cudnnDeriveBNTensorDescriptor); \ - __macro(cudnnCreateSpatialTransformerDescriptor); \ - __macro(cudnnSetSpatialTransformerNdDescriptor); \ - __macro(cudnnDestroySpatialTransformerDescriptor); \ - __macro(cudnnSpatialTfGridGeneratorForward); \ - __macro(cudnnSpatialTfGridGeneratorBackward); \ - __macro(cudnnSpatialTfSamplerForward); \ - __macro(cudnnSpatialTfSamplerBackward); \ - __macro(cudnnCreate); \ - __macro(cudnnDestroy); \ - __macro(cudnnSetStream); \ - __macro(cudnnActivationForward); \ - __macro(cudnnActivationBackward); \ - __macro(cudnnConvolutionForward); \ - __macro(cudnnConvolutionBackwardBias); \ - __macro(cudnnGetConvolutionForwardWorkspaceSize); \ - __macro(cudnnTransformTensor); \ - __macro(cudnnPoolingForward); \ - __macro(cudnnPoolingBackward); \ - __macro(cudnnSoftmaxBackward); \ - __macro(cudnnSoftmaxForward); \ - __macro(cudnnGetVersion); \ - __macro(cudnnFindConvolutionForwardAlgorithmEx); \ - __macro(cudnnFindConvolutionBackwardFilterAlgorithmEx); \ - __macro(cudnnFindConvolutionBackwardFilterAlgorithm); \ - __macro(cudnnFindConvolutionBackwardDataAlgorithmEx); \ - __macro(cudnnGetErrorString); \ - __macro(cudnnCreateDropoutDescriptor); \ - __macro(cudnnDropoutGetStatesSize); \ - __macro(cudnnSetDropoutDescriptor); \ - __macro(cudnnRestoreDropoutDescriptor); \ - __macro(cudnnCreateRNNDescriptor); \ - __macro(cudnnGetRNNParamsSize); \ - __macro(cudnnGetRNNWorkspaceSize); \ - __macro(cudnnGetRNNTrainingReserveSize); \ - __macro(cudnnRNNForwardTraining); \ - __macro(cudnnRNNBackwardData); \ - __macro(cudnnRNNBackwardWeights); \ - __macro(cudnnRNNForwardInference); \ - __macro(cudnnDestroyDropoutDescriptor); \ - __macro(cudnnDestroyRNNDescriptor); \ - __macro(cudnnSetTensorNdDescriptorEx); \ - __macro(cudnnAddTensor); \ - __macro(cudnnConvolutionBackwardData); \ - __macro(cudnnConvolutionBackwardFilter); \ - __macro(cudnnGetConvolutionBackwardFilterWorkspaceSize); \ - 
__macro(cudnnGetConvolutionBackwardDataWorkspaceSize); \ - __macro(cudnnBatchNormalizationForwardTraining); \ - __macro(cudnnBatchNormalizationForwardInference); \ - __macro(cudnnBatchNormalizationBackward); \ - __macro(cudnnCreateActivationDescriptor); \ - __macro(cudnnSetActivationDescriptor); \ - __macro(cudnnGetActivationDescriptor); \ - __macro(cudnnDestroyActivationDescriptor); \ - __macro(cudnnSetRNNDescriptor_v6); -CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) - -#if CUDNN_VERSION >= 7000 && CUDNN_VERSION < 8000 -#define CUDNN_DNN_ROUTINE_EACH_AFTER_R7_LESS_R8(__macro) \ - __macro(cudnnGetConvolutionBackwardFilterAlgorithm); \ - __macro(cudnnGetConvolutionForwardAlgorithm); \ - __macro(cudnnGetConvolutionBackwardDataAlgorithm); \ - __macro(cudnnSetRNNDescriptor); -CUDNN_DNN_ROUTINE_EACH_AFTER_R7_LESS_R8(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) -#endif - -#if CUDNN_VERSION >= 7001 -#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \ - __macro(cudnnSetConvolutionGroupCount); \ - __macro(cudnnSetConvolutionMathType); \ - __macro(cudnnConvolutionBiasActivationForward); \ - __macro(cudnnCreateCTCLossDescriptor); \ - __macro(cudnnDestroyCTCLossDescriptor); \ - __macro(cudnnGetCTCLossDescriptor); \ - __macro(cudnnSetCTCLossDescriptor); \ - __macro(cudnnGetCTCLossWorkspaceSize); \ - __macro(cudnnCTCLoss); \ - __macro(cudnnGetConvolutionBackwardDataAlgorithm_v7); \ - __macro(cudnnGetConvolutionBackwardFilterAlgorithm_v7); \ - __macro(cudnnGetConvolutionForwardAlgorithm_v7); \ - __macro(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount); -CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) -#endif - -#if CUDNN_VERSION >= 7201 -#define CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(__macro) \ - __macro(cudnnCreateRNNDataDescriptor); \ - __macro(cudnnDestroyRNNDataDescriptor); \ - __macro(cudnnSetRNNDataDescriptor); \ - __macro(cudnnSetRNNPaddingMode); \ - __macro(cudnnRNNForwardTrainingEx); \ - __macro(cudnnRNNBackwardDataEx); \ - __macro(cudnnRNNBackwardWeightsEx); \ - __macro(cudnnRNNForwardInferenceEx); -CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) -#endif - -#if CUDNN_VERSION >= 7401 -#define CUDNN_DNN_ROUTINE_EACH_AFTER_R7(__macro) \ - __macro(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize); \ - __macro(cudnnBatchNormalizationForwardTrainingEx); \ - __macro(cudnnGetBatchNormalizationBackwardExWorkspaceSize); \ - __macro(cudnnBatchNormalizationBackwardEx); \ - __macro(cudnnGetBatchNormalizationTrainingExReserveSpaceSize); -CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) -#endif - -#if CUDNN_VERSION >= 8000 -#define CUDNN_DNN_ROUTINE_EACH_R8(__macro) \ - __macro(cudnnSetRNNDescriptor_v8); \ - __macro(cudnnCreateFusedOpsPlan); \ - __macro(cudnnCreateFusedOpsConstParamPack); \ - __macro(cudnnCreateFusedOpsVariantParamPack); \ - __macro(cudnnDestroyFusedOpsPlan); \ - __macro(cudnnDestroyFusedOpsConstParamPack); \ - __macro(cudnnDestroyFusedOpsVariantParamPack); \ - __macro(cudnnFusedOpsExecute); \ - __macro(cudnnSetFusedOpsConstParamPackAttribute); \ - __macro(cudnnSetFusedOpsVariantParamPackAttribute); \ - __macro(cudnnMakeFusedOpsPlan); -CUDNN_DNN_ROUTINE_EACH_R8(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) -#endif - -#ifdef PADDLE_WITH_CUDNN_FRONTEND -#define CUDNN_DNN_ROUTINE_EACH_FRONTEND(__macro) \ - __macro(cudnnBackendCreateDescriptor); \ - __macro(cudnnBackendDestroyDescriptor); \ - __macro(cudnnBackendExecute); \ - __macro(cudnnBackendFinalize); \ - __macro(cudnnBackendGetAttribute); \ - __macro(cudnnBackendSetAttribute); \ - __macro(cudnnGetStream); \ 
- __macro(cudnnReorderFilterAndBias); -CUDNN_DNN_ROUTINE_EACH_FRONTEND(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) -#endif - -} // namespace dynload -} // namespace phi - -#endif diff --git a/backends/metax_gpu/kernels/impl/addmm_kernel_impl.h b/backends/metax_gpu/kernels/impl/addmm_kernel_impl.h index b517b719d49..a2c69b6adf0 100644 --- a/backends/metax_gpu/kernels/impl/addmm_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/addmm_kernel_impl.h @@ -20,7 +20,7 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/phi/kernels/addmm_kernel.h" -#include "../funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" // clang-format on diff --git a/backends/metax_gpu/kernels/impl/baddbmm_kernel_impl.h b/backends/metax_gpu/kernels/impl/baddbmm_kernel_impl.h index 593c044fc76..1c52ea22e4e 100644 --- a/backends/metax_gpu/kernels/impl/baddbmm_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/baddbmm_kernel_impl.h @@ -17,9 +17,9 @@ limitations under the License. */ #include #include "glog/logging.h" -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/kernels/baddbmm_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" diff --git a/backends/metax_gpu/kernels/impl/bilinear_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/bilinear_grad_kernel_impl.h index ef61d48202f..b64f94bc7ef 100644 --- a/backends/metax_gpu/kernels/impl/bilinear_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/bilinear_grad_kernel_impl.h @@ -14,8 +14,8 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { diff --git a/backends/metax_gpu/kernels/impl/bilinear_kernel_impl.h b/backends/metax_gpu/kernels/impl/bilinear_kernel_impl.h index c124e84eb6d..48861d48932 100644 --- a/backends/metax_gpu/kernels/impl/bilinear_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/bilinear_kernel_impl.h @@ -14,8 +14,8 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/utils/optional.h" diff --git a/backends/metax_gpu/kernels/impl/bmm_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/bmm_grad_kernel_impl.h index 543df3ee964..cd5978ae59f 100644 --- a/backends/metax_gpu/kernels/impl/bmm_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/bmm_grad_kernel_impl.h @@ -14,9 +14,9 @@ #pragma once -#include "kernels/funcs/blas/blas.h" -#include "kernels/impl/matmul_grad_kernel_impl.h" #include "paddle/phi/kernels/bmm_grad_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h" namespace phi { diff --git a/backends/metax_gpu/kernels/impl/bmm_kernel_impl.h b/backends/metax_gpu/kernels/impl/bmm_kernel_impl.h index 7b4164032b2..ce493b4908a 100644 --- a/backends/metax_gpu/kernels/impl/bmm_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/bmm_kernel_impl.h @@ -14,8 +14,8 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/bmm_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" namespace phi { diff --git 
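[Editor's note] The backend-local wrapper removed above resolves each cuDNN-style entry point lazily: the vendor DNN library is opened once under a std::once_flag, the leading "cu" of the requested symbol is rewritten to "mc" (replaced_name.replace(0, 2, "mc")), and the renamed symbol is bound through dlsym. The following is a minimal standalone sketch of that lazy-load and prefix-rewrite pattern; the library name "libmcdnn.so", the helper names, and the error handling are illustrative assumptions, not the phi::dynload API.

// Minimal sketch of the lazy-load + prefix-rewrite pattern used by the deleted
// wrappers above (illustrative only; not the real phi::dynload implementation).
#include <dlfcn.h>
#include <mutex>
#include <stdexcept>
#include <string>

static std::once_flag g_dnn_flag;
static void* g_dnn_handle = nullptr;

// "libmcdnn.so" is an assumption for illustration; the real library name is
// supplied by the build configuration.
static void* GetDnnHandle() {
  std::call_once(g_dnn_flag,
                 [] { g_dnn_handle = dlopen("libmcdnn.so", RTLD_LAZY); });
  return g_dnn_handle;
}

// Resolve a cuDNN-style entry point from the vendor library by rewriting the
// "cu" prefix to "mc", e.g. cudnnCreate -> mcdnnCreate.
template <typename FuncT>
FuncT LoadDnnSymbol(const char* cudnn_name) {
  void* handle = GetDnnHandle();
  if (handle == nullptr) {
    throw std::runtime_error("Cannot load DNN shared library.");
  }
  std::string replaced = cudnn_name;
  replaced.replace(0, 2, "mc");
  return reinterpret_cast<FuncT>(dlsym(handle, replaced.c_str()));
}

A caller would use it roughly as LoadDnnSymbol<decltype(&cudnnCreate)>("cudnnCreate"); the per-symbol DynLoad__ functors above cache the resolved pointer in a function-local static the same way.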
a/backends/metax_gpu/kernels/impl/cholesky_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/cholesky_grad_kernel_impl.h index 02332652660..5d146dae8d5 100644 --- a/backends/metax_gpu/kernels/impl/cholesky_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/cholesky_grad_kernel_impl.h @@ -14,8 +14,8 @@ limitations under the License. */ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/cholesky_grad_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/for_range.h" namespace phi { diff --git a/backends/metax_gpu/kernels/impl/cholesky_solve_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/cholesky_solve_grad_kernel_impl.h index 62115e9ee6a..098092767c4 100644 --- a/backends/metax_gpu/kernels/impl/cholesky_solve_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/cholesky_solve_grad_kernel_impl.h @@ -14,7 +14,6 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/cholesky_solve_grad_kernel.h" #include "paddle/phi/kernels/cholesky_solve_kernel.h" #include "paddle/phi/kernels/complex_kernel.h" @@ -22,6 +21,7 @@ #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/expand_kernel.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/for_range.h" diff --git a/backends/metax_gpu/kernels/impl/conv_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/conv_grad_kernel_impl.h index 25e0d93a6a4..6066720ab07 100644 --- a/backends/metax_gpu/kernels/impl/conv_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/conv_grad_kernel_impl.h @@ -14,10 +14,10 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/batch_norm_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/im2col.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/vol2col.h" diff --git a/backends/metax_gpu/kernels/impl/conv_kernel_impl.h b/backends/metax_gpu/kernels/impl/conv_kernel_impl.h index 2cf5fa166e7..4395e5d5782 100644 --- a/backends/metax_gpu/kernels/impl/conv_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/conv_kernel_impl.h @@ -14,11 +14,11 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/conv_kernel.h" #include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/batch_norm_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/im2col.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/vol2col.h" diff --git a/backends/metax_gpu/kernels/impl/conv_transpose_kernel_impl.h b/backends/metax_gpu/kernels/impl/conv_transpose_kernel_impl.h index c7c002d4e9e..aadc5d2b8a0 100644 --- a/backends/metax_gpu/kernels/impl/conv_transpose_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/conv_transpose_kernel_impl.h @@ -14,12 +14,12 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/common/ddim.h" #include "paddle/common/layout.h" #include "paddle/phi/kernels/conv_transpose_kernel.h" #include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include 
"paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/im2col.h" #include "paddle/phi/kernels/funcs/slice.h" diff --git a/backends/metax_gpu/kernels/impl/deformable_conv_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/deformable_conv_grad_kernel_impl.h index d2419966342..b9931a89978 100644 --- a/backends/metax_gpu/kernels/impl/deformable_conv_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/deformable_conv_grad_kernel_impl.h @@ -14,11 +14,11 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/common/hostdevice.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/deformable_conv_functor.h" namespace phi { diff --git a/backends/metax_gpu/kernels/impl/elementwise.h b/backends/metax_gpu/kernels/impl/elementwise.h index 52a7709424b..b9f3d8af1c9 100644 --- a/backends/metax_gpu/kernels/impl/elementwise.h +++ b/backends/metax_gpu/kernels/impl/elementwise.h @@ -14,9 +14,9 @@ limitations under the License. */ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/eigen/common.h" diff --git a/backends/metax_gpu/kernels/impl/flatten2_kernel_impl.h b/backends/metax_gpu/kernels/impl/flatten2_kernel_impl.h index d4526922c7b..dc4059a7225 100644 --- a/backends/metax_gpu/kernels/impl/flatten2_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/flatten2_kernel_impl.h @@ -15,10 +15,10 @@ #pragma once #include -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/flatten_grad_kernel.h" #include "paddle/phi/kernels/flatten_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/flatten2_utils.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/backends/metax_gpu/kernels/impl/gru_unit_kernel_impl.h b/backends/metax_gpu/kernels/impl/gru_unit_kernel_impl.h index 0929a327035..ef12141f911 100644 --- a/backends/metax_gpu/kernels/impl/gru_unit_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/gru_unit_kernel_impl.h @@ -16,10 +16,10 @@ #include -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/activation_functor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/utils/optional.h" namespace phi { diff --git a/backends/metax_gpu/kernels/impl/index_select_impl.h b/backends/metax_gpu/kernels/impl/index_select_impl.h index 78284107d34..ac39cab2704 100644 --- a/backends/metax_gpu/kernels/impl/index_select_impl.h +++ b/backends/metax_gpu/kernels/impl/index_select_impl.h @@ -15,9 +15,9 @@ #pragma once #include "glog/logging.h" -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/backends/metax_gpu/kernels/impl/inverse_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/inverse_grad_kernel_impl.h index 
85aff008b4e..64b56f2cd1c 100644 --- a/backends/metax_gpu/kernels/impl/inverse_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/inverse_grad_kernel_impl.h @@ -14,10 +14,10 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/complex_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/matrix_inverse.h" #include "paddle/phi/kernels/inverse_grad_kernel.h" diff --git a/backends/metax_gpu/kernels/impl/lstm_kernel_impl.h b/backends/metax_gpu/kernels/impl/lstm_kernel_impl.h index 079548b4ad0..4a061fe4716 100644 --- a/backends/metax_gpu/kernels/impl/lstm_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/lstm_kernel_impl.h @@ -15,8 +15,8 @@ #pragma once #include -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/detail/activation_functions.h" #include "paddle/phi/kernels/funcs/lstm_compute.h" #include "paddle/phi/kernels/funcs/lstm_utils.h" diff --git a/backends/metax_gpu/kernels/impl/lu_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/lu_grad_kernel_impl.h index e9ef47490bc..5a2e5d48a11 100644 --- a/backends/metax_gpu/kernels/impl/lu_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/lu_grad_kernel_impl.h @@ -14,7 +14,7 @@ #pragma once -#include "kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/impl/lu_kernel_impl.h" #include "paddle/phi/kernels/triangular_solve_kernel.h" diff --git a/backends/metax_gpu/kernels/impl/lu_solve_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/lu_solve_grad_kernel_impl.h index 21c711c53ef..24dee650dfe 100644 --- a/backends/metax_gpu/kernels/impl/lu_solve_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/lu_solve_grad_kernel_impl.h @@ -15,9 +15,9 @@ #pragma once #include "paddle/phi/infermeta/binary.h" -// #include "paddle/phi/kernels/funcs/blas/blas.h" +// #include "paddle/phi/paddle/phi/kernels/funcs/blas/blas.h" -#include "kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/matrix_solve.h" #include "paddle/phi/kernels/impl/lu_kernel_impl.h" diff --git a/backends/metax_gpu/kernels/impl/matmul_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/matmul_grad_kernel_impl.h deleted file mode 100644 index 823851666f1..00000000000 --- a/backends/metax_gpu/kernels/impl/matmul_grad_kernel_impl.h +++ /dev/null @@ -1,2042 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once -// clang-format off -#include "glog/logging.h" - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/kernels/complex_kernel.h" -#include "paddle/phi/kernels/empty_kernel.h" -#include "paddle/phi/kernels/full_kernel.h" -#include "paddle/phi/kernels/funcs/reduce_function.h" -#include "paddle/phi/kernels/funcs/reduce_functor.h" -#include "paddle/phi/kernels/impl/dot_grad_kernel_impl.h" -// #include "paddle/phi/kernels/impl/matmul_kernel_impl.h" -#include "paddle/phi/kernels/reduce_sum_kernel.h" - -#include "../impl/matmul_kernel_impl.h" -// clang-format on - -#if defined(__NVCC__) || defined(__HIPCC__) -#include "paddle/phi/kernels/gpu/reduce.h" -#endif - -namespace phi { - -template -struct ReduceSumForMatmulGrad { - void operator()(const Context& dev_ctx, - const DenseTensor& input, - DenseTensor* output, - const std::vector& reduce_dims); -}; - -template -struct ReduceSumForMatmulGrad { - void operator()(const CPUContext& dev_ctx, - const DenseTensor& input, - DenseTensor* output, - const std::vector& reduce_dims) { - std::vector reduce_dims_tmp(reduce_dims.begin(), - reduce_dims.end()); - funcs::ReduceKernelImpl( - dev_ctx, input, output, reduce_dims_tmp, true, false); - } -}; - -#if defined(__NVCC__) || defined(__HIPCC__) -template -struct ReduceSumForMatmulGrad { - void operator()(const GPUContext& dev_ctx, - const DenseTensor& input, - DenseTensor* output, - const std::vector& reduce_dims) { - phi::SumKernel( - dev_ctx, input, reduce_dims, input.dtype(), false, output); - } -}; -#endif - -// Reshape a rank-3 tensor from P x M x N to (P * M) x N. -// Identity op if the tensor is not of rank 3. -static DenseTensor FoldInitDims(const DenseTensor& input) { - DenseTensor output = input; - auto in_dims = input.dims(); - if (in_dims.size() == 3) { - output.Resize({in_dims[0] * in_dims[1], in_dims[2]}); - } - return output; -} - -// Reshape a rank-3 tensor from P x M x N to M x (P * N). -// (Warning: This requires transposing data and writes into new memory.) -// Identity op if the tensor is not of rank 3. -template -static DenseTensor FoldHeadAndLastDims(const Context& dev_ctx, - const DenseTensor& input) { - auto in_dims = input.dims(); - if (in_dims.size() != 3) { - return input; - } - DenseTensor output = EmptyLike(dev_ctx, input); - output.Resize({in_dims[1], in_dims[0], in_dims[2]}); - std::vector axis = {1, 0, 2}; - funcs::Transpose trans; - trans(dev_ctx, input, &output, axis); - output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); - return output; -} - -template -typename std::enable_if::value>::type MatMul( - const Context& dev_ctx, - const DenseTensor& a, - bool trans_a, - const DenseTensor& b, - bool trans_b, - DenseTensor* out, - bool flag = false) { - dev_ctx.template Alloc(out); - auto blas = phi::funcs::GetBlas(dev_ctx); - auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a); - auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b); - if (a.dims().size() == 3 && b.dims().size() <= 2) { - // the transpose_X must be false, if is true, the transpose cost much time - if (!trans_a) { - mat_dim_a.height_ *= mat_dim_a.batch_size_; - mat_dim_a.batch_size_ = 0; - } - } - blas.MatMul(a.data(), - mat_dim_a, - b.data(), - mat_dim_b, - static_cast(1), - dev_ctx.template Alloc(out), - static_cast(flag)); -} - -/** - * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the - * original x_dim is returned. 
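[Editor's note] Right above, FoldInitDims and FoldHeadAndLastDims collapse a rank-3 operand before handing it to a 2-D GEMM: the first treats P x M x N as (P*M) x N, which is a free reshape in row-major storage, while the second first permutes to M x P x N and only then views the result as M x (P*N), so it must materialize new memory. A small standalone illustration of the two layouts follows; it is hypothetical code assuming row-major storage, not the phi implementation.

// Shape-only illustration of the two folding helpers above.
#include <cstdio>

int main() {
  const int P = 2, M = 2, N = 3;
  float t[P][M][N];
  int v = 0;
  for (int p = 0; p < P; ++p)
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n) t[p][m][n] = static_cast<float>(v++);

  // FoldInitDims: view P x M x N as (P*M) x N. Row r of the folded view is
  // simply t[r / M][r % M]; no data movement is needed.
  int r = 3, c = 1;
  std::printf("(P*M) x N element [%d][%d] = %g\n", r, c, t[r / M][r % M][c]);

  // FoldHeadAndLastDims: transpose to M x P x N, then view as M x (P*N).
  // The permutation genuinely moves data, so a new buffer is required.
  float folded[M][P * N];
  for (int m = 0; m < M; ++m)
    for (int p = 0; p < P; ++p)
      for (int n = 0; n < N; ++n) folded[m][p * N + n] = t[p][m][n];
  std::printf("M x (P*N) element [1][4] = %g\n", folded[1][4]);  // same t[1][1][1]
  return 0;
}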
- */ -static DDim RowMatrixFromVector(const DDim& x_dim) { - if (x_dim.size() > 1) { - return x_dim; - } - return common::make_ddim({1, x_dim[0]}); -} - -/** - * Get column matrix shape from a vector shape. If the ran of y_dim > 1, the - * original y_dim is returned. - */ -static DDim ColumnMatrixFromVector(const DDim& y_dim) { - if (y_dim.size() > 1) { - return y_dim; - } - return common::make_ddim({y_dim[0], 1}); -} - -/** - * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor. - * - * The shape would be [BatchSize, H, W] or [H, W]. - * If transposed, `H,W` will be swapped. - */ -static void ReshapeTensorIntoMatrixSequence( - DenseTensor* x, const phi::funcs::MatDescriptor& descriptor) { - int64_t h, w; - h = descriptor.height_; - w = descriptor.width_; - if (descriptor.trans_) { - std::swap(w, h); - } - if (descriptor.batch_size_) { - x->Resize({descriptor.batch_size_, h, w}); - } else { - x->Resize({h, w}); - } -} - -static void ReshapeXYOutIntoMatrixSequence(DenseTensor* x, - DenseTensor* y, - DenseTensor* out, - bool trans_x, - bool trans_y) { - auto x_dim = RowMatrixFromVector(x->dims()); - auto y_dim = ColumnMatrixFromVector(y->dims()); - auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x); - auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y); - if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { - out->Resize({mat_dim_x.height_, mat_dim_y.width_}); - } else { - out->Resize({(std::max)(mat_dim_x.batch_size_, mat_dim_y.batch_size_), - mat_dim_x.height_, - mat_dim_y.width_}); - } - - ReshapeTensorIntoMatrixSequence(x, mat_dim_x); - ReshapeTensorIntoMatrixSequence(y, mat_dim_y); -} - -template -void CalcInputGrad(const Context& dev_ctx, - const DenseTensor& a, - bool trans_a, - bool is_fold_init_dims_a, - const DenseTensor& b, - bool trans_b, - bool is_fold_init_dims_b, - DenseTensor* out, - bool flag = false) { - if (out == nullptr) return; - bool need_combine = - (a.dims().size() == 3 || b.dims().size() == 3) && out->dims().size() == 2; - if (!need_combine) { - MatMul(dev_ctx, a, trans_a, b, trans_b, out, flag); - } else { - MatMul( - dev_ctx, - is_fold_init_dims_a ? FoldInitDims(a) - : FoldHeadAndLastDims(dev_ctx, a), - trans_a, - is_fold_init_dims_b ? FoldInitDims(b) - : FoldHeadAndLastDims(dev_ctx, b), - trans_b, - out, - flag); - } -} - -template -void MatmulGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out_grad, - bool transpose_x, - bool transpose_y, - DenseTensor* dx, - DenseTensor* dy) { - // get dims - std::vector x_dims = common::vectorize(x.dims()); - std::vector y_dims = common::vectorize(y.dims()); - std::vector dout_dims = common::vectorize(out_grad.dims()); - - int x_ndim = x_dims.size(); - int y_ndim = y_dims.size(); - int ndim = dout_dims.size(); - - // Case1 : x's or y's dim = 1 - if (x_ndim == 1 && y_ndim == 1) { - if (dx) dev_ctx.template Alloc(dx); - if (dy) dev_ctx.template Alloc(dy); - if (out_grad.numel() == 1) { - DotGradFunction()(dev_ctx, &x, &y, &out_grad, dx, dy); - return; - } - } - - bool is_broadcast = true; - if (x_ndim <= 2 || y_ndim <= 2) { - is_broadcast = false; - } else if (x_ndim != y_ndim) { - is_broadcast = true; - } else { - is_broadcast = !std::equal( - x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); - } - - // for complex - DenseTensor x_conj; - DenseTensor y_conj; - - // Case2: no broadcast or no batch size, it aims to speed and it is same as - // matmul in old version. 
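[Editor's note] The non-broadcast branch that follows only reorders GEMM operands according to the transpose flags; the identities behind it are, for Out = X * Y, dX = dOut * Y^T and dY = X^T * dOut, with the operands conjugated first for complex types (that is what the Conj calls provide). Below is a tiny numerical check of the dX identity against a finite difference; it is a hypothetical standalone sketch with made-up names, not the kernel code, which dispatches the same products through BLAS.

// Tiny numerical check of the gradient identities used below for Out = X * Y:
//   dX = dOut * Y^T,   dY = X^T * dOut
#include <cstdio>

int main() {
  double X[2][2] = {{1, 2}, {3, 4}};
  double Y[2][2] = {{5, 6}, {7, 8}};
  double G[2][2] = {{0.1, 0.2}, {0.3, 0.4}};  // upstream gradient dOut

  // Analytic gradient: dX = G * Y^T.
  double dX[2][2] = {};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 2; ++k) dX[i][j] += G[i][k] * Y[j][k];  // Y^T[k][j] == Y[j][k]

  // Finite-difference check of dX[0][0] on L = sum_ij G_ij * (X*Y)_ij.
  auto loss = [&](double x00) {
    double Xp[2][2] = {{x00, X[0][1]}, {X[1][0], X[1][1]}};
    double l = 0;
    for (int i = 0; i < 2; ++i)
      for (int j = 0; j < 2; ++j) {
        double out = 0;
        for (int k = 0; k < 2; ++k) out += Xp[i][k] * Y[k][j];
        l += G[i][j] * out;
      }
    return l;
  };
  const double eps = 1e-6;
  double fd = (loss(X[0][0] + eps) - loss(X[0][0] - eps)) / (2 * eps);
  std::printf("analytic dX[0][0] = %.6f, finite-diff = %.6f\n", dX[0][0], fd);
  return 0;
}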
- if (!is_broadcast) { - DenseTensor x_help = x; - DenseTensor y_help = y; - DenseTensor out_grad_help = out_grad; - - ReshapeXYOutIntoMatrixSequence( - &x_help, &y_help, &out_grad_help, transpose_x, transpose_y); - - DDim dx_dims; - if (dx) { - dx_dims = dx->dims(); - if (dx_dims != x_help.dims()) { - dx->Resize(x_help.dims()); - } - - y_conj = Conj(dev_ctx, y_help); - } - - DDim dy_dims; - if (dy) { - dy_dims = dy->dims(); - if (dy_dims != y_help.dims()) { - dy->Resize(y_help.dims()); - } - - x_conj = Conj(dev_ctx, x_help); - } - - if (transpose_x && transpose_y) { - CalcInputGrad( - dev_ctx, y_conj, true, true, out_grad_help, true, false, dx); - CalcInputGrad( - dev_ctx, out_grad_help, true, true, x_conj, true, false, dy); - } else if (transpose_x) { - CalcInputGrad( - dev_ctx, y_conj, false, false, out_grad_help, true, false, dx); - CalcInputGrad( - dev_ctx, x_conj, false, false, out_grad_help, false, true, dy); - } else if (transpose_y) { - CalcInputGrad( - dev_ctx, out_grad_help, false, false, y_conj, false, true, dx); - CalcInputGrad( - dev_ctx, out_grad_help, true, true, x_conj, false, true, dy); - } else { - CalcInputGrad( - dev_ctx, out_grad_help, false, false, y_conj, true, false, dx); - CalcInputGrad( - dev_ctx, x_conj, true, true, out_grad_help, false, true, dy); - } - - if (dx) { - if (dx_dims != x_help.dims()) { - dx->Resize(dx_dims); - } - } - if (dy) { - if (dy_dims != y_help.dims()) { - dy->Resize(dy_dims); - } - } - } else { - // Case3: broadcast. It need cost much time to reduce sum for the - // broadcast and wastes the memory. - // So we should avoid the case in reality. - VLOG(3) << "It need cost much time to reduce sum for the broadcast and " - "wastes the memory. So we should avoid the case in reality"; - x_conj = Conj(dev_ctx, x); - y_conj = Conj(dev_ctx, y); - - DenseTensor dx_help; - DenseTensor dy_help; - - if (transpose_x) { - if (transpose_y) { - // X'Y': dA = Y'G', dB = G'X' - if (dx) - MatMulFunction(dev_ctx, - y_conj, - out_grad, - y_dims, - dout_dims, - &dx_help, - true, - true); - if (dy) - MatMulFunction(dev_ctx, - out_grad, - x_conj, - dout_dims, - x_dims, - &dy_help, - true, - true); - } else { - // X'Y: dX = YG', dY = XG - if (dx) - MatMulFunction(dev_ctx, - y_conj, - out_grad, - y_dims, - dout_dims, - &dx_help, - false, - true); - if (dy) - MatMulFunction(dev_ctx, - x_conj, - out_grad, - x_dims, - dout_dims, - &dy_help, - false, - false); - } - } else { - if (transpose_y) { - // XY': dX = GY, dY = G'X - if (dx) - MatMulFunction(dev_ctx, - out_grad, - y_conj, - dout_dims, - y_dims, - &dx_help, - false, - false); - if (dy) - MatMulFunction(dev_ctx, - out_grad, - x_conj, - dout_dims, - x_dims, - &dy_help, - true, - false); - } else { - // XY: dX = GY', dY = X'G - if (dx) - MatMulFunction(dev_ctx, - out_grad, - y_conj, - dout_dims, - y_dims, - &dx_help, - false, - true); - if (dy) - MatMulFunction(dev_ctx, - x_conj, - out_grad, - x_dims, - dout_dims, - &dy_help, - true, - false); - } - } - - // get help dims - const std::vector dx_help_dims = - common::vectorize(dx_help.dims()); - const std::vector dy_help_dims = - common::vectorize(dy_help.dims()); - - std::vector dx_broadcast_dims(ndim); - std::vector dy_broadcast_dims(ndim); - - std::fill( - dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1); - std::fill( - dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1); - std::copy(x_dims.data(), - x_dims.data() + x_ndim, - dx_broadcast_dims.data() + ndim - x_ndim); - std::copy(y_dims.data(), - y_dims.data() + y_ndim, - 
dy_broadcast_dims.data() + ndim - y_ndim); - - std::vector dx_reduce_dims; - std::vector dy_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { - dx_reduce_dims.push_back(idx); - } - if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { - dy_reduce_dims.push_back(idx); - } - } - // reduce sum to get grad by ReduceSum - if (dx) { - if (dx_reduce_dims.empty()) { - *dx = std::move(dx_help); - } else { - ReduceSumForMatmulGrad()( - dev_ctx, dx_help, dx, dx_reduce_dims); - } - dx->Resize(x.dims()); - } - if (dy) { - if (dy_reduce_dims.empty()) { - *dy = std::move(dy_help); - } else { - ReduceSumForMatmulGrad()( - dev_ctx, dy_help, dy, dy_reduce_dims); - } - dy->Resize(y.dims()); - } - // Get the OutputGrad(out) - } -} - -template -void MatmulDoubleGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& dout, - const paddle::optional& ddx, - const paddle::optional& ddy, - bool transpose_x, - bool transpose_y, - DenseTensor* dx, - DenseTensor* dy, - DenseTensor* ddout) { - // Get dims from the input x, y, output_grad - std::vector x_dims = common::vectorize(x.dims()); - std::vector y_dims = common::vectorize(y.dims()); - std::vector dout_dims = common::vectorize(dout.dims()); - - int x_ndim = x_dims.size(); - int y_ndim = y_dims.size(); - int ndim = dout_dims.size(); - - // Case1 : x's or y's dim = 1 - if (x_ndim == 1 && y_ndim == 1) { - DotDoubleGradFunction()( - dev_ctx, &x, &y, &dout, &ddx, &ddy, dx, dy, ddout); - return; - } - - DenseTensor x_conj; - DenseTensor y_conj; - DenseTensor dout_conj; - - bool is_broadcast = true; - if (x_ndim <= 2 || y_ndim <= 2) { - is_broadcast = false; - } else if (x_ndim != y_ndim) { - is_broadcast = true; - } else { - is_broadcast = !std::equal( - x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); - } - - if (!is_broadcast) { - // Case2: no broadcast or no batch size - DenseTensor x_help = x; - DenseTensor y_help = y; - DenseTensor dout_help = dout; - ReshapeXYOutIntoMatrixSequence( - &x_help, &y_help, &dout_help, transpose_x, transpose_y); - DDim dx_dims; - - if (dx) { - dx_dims = dx->dims(); - if (dx_dims != x_help.dims()) { - dx->Resize(x_help.dims()); - } - } - - DDim dy_dims; - if (dy) { - dy_dims = dy->dims(); - if (dy_dims != y_help.dims()) { - dy->Resize(y_help.dims()); - } - } - - DDim ddout_dims; - if (ddout) { - ddout_dims = ddout->dims(); - if (ddout_dims != dout_help.dims()) { - ddout->Resize(dout_help.dims()); - } - - x_conj = Conj(dev_ctx, x_help); - y_conj = Conj(dev_ctx, y_help); - } - - if (dx || dy) { - dout_conj = Conj(dev_ctx, dout_help); - } - - bool ddout_flag = false; - if (ddx) { - auto ddx_mat = ddx.get(); - if (ddx_mat.dims() != x_help.dims()) { - ddx_mat.Resize(x_help.dims()); - } - if (dy) { - if (transpose_x && transpose_y) { - // dy = dout' * ddx' - CalcInputGrad( - dev_ctx, dout_conj, true, true, ddx_mat, true, false, dy, false); - } else if (transpose_x) { - // dy = ddx * dout - CalcInputGrad(dev_ctx, - ddx_mat, - false, - false, - dout_conj, - false, - true, - dy, - false); - } else if (transpose_y) { - // dy = dout' * ddx - CalcInputGrad( - dev_ctx, dout_conj, true, true, ddx_mat, false, true, dy, false); - } else { - // dy = ddx' * dout - CalcInputGrad( - dev_ctx, ddx_mat, true, true, dout_conj, false, true, dy, false); - } - } - - if (ddout) { - CalcInputGrad(dev_ctx, - ddx_mat, - transpose_x, - true, - y_conj, - transpose_y, - false, - ddout, - ddout_flag); - ddout_flag = true; - } - } 
else if (!ddx && dy) { - FullLikeKernel(dev_ctx, y, Scalar(0.0), y.dtype(), dy); - } - if (ddy) { - auto ddy_mat = ddy.get(); - if (ddy_mat.dims() != y_help.dims()) { - ddy_mat.Resize(y_help.dims()); - } - if (dx) { - if (transpose_x && transpose_y) { - // dx = ddy' * dout' - CalcInputGrad( - dev_ctx, ddy_mat, true, true, dout_conj, true, false, dx, false); - } else if (transpose_x) { - // dx = ddy * dout' - CalcInputGrad(dev_ctx, - ddy_mat, - false, - false, - dout_conj, - true, - false, - dx, - false); - } else if (transpose_y) { - // dx = dout * ddy - CalcInputGrad(dev_ctx, - dout_conj, - false, - false, - ddy_mat, - false, - true, - dx, - false); - } else { - // dx = dout * ddy' - CalcInputGrad(dev_ctx, - dout_conj, - false, - false, - ddy_mat, - true, - false, - dx, - false); - } - } - - if (ddout) { - CalcInputGrad(dev_ctx, - x_conj, - transpose_x, - true, - ddy_mat, - transpose_y, - false, - ddout, - ddout_flag); - } - } else if (!ddy && dx) { - FullLikeKernel(dev_ctx, x, Scalar(0.0), x.dtype(), dx); - } - if (ddout && !ddx && !ddy) { - FullLikeKernel( - dev_ctx, dout, Scalar(0.0), dout.dtype(), ddout); - } - - if (dx) { - if (dx_dims != x_help.dims()) { - dx->Resize(dx_dims); - } - } - - if (dy) { - if (dy_dims != y_help.dims()) { - dy->Resize(dy_dims); - } - } - - if (ddout) { - if (ddout_dims != dout_help.dims()) { - ddout->Resize(ddout_dims); - } - } - } else { - // Case3: broadcast. It need cost much time to reduce sum for the - // broadcast and wastes the memory. - // So we should avoid the case in reality. - VLOG(3) << "It need cost much time to reduce sum for the broadcast and " - "wastes the memory. So we should avoid the case in reality"; - if (dx || dy) { - dout_conj = Conj(dev_ctx, dout); - } - if (ddout) { - x_conj = Conj(dev_ctx, x); - y_conj = Conj(dev_ctx, y); - } - - DenseTensor dx_help; - DenseTensor dy_help; - - if (transpose_x) { - if (transpose_y) { - if (dx && ddy) { - MatMulFunction(dev_ctx, - ddy.get(), - dout_conj, - y_dims, - dout_dims, - &dx_help, - true, - true); - } - if (dy && ddx) { - MatMulFunction(dev_ctx, - dout_conj, - ddx.get(), - dout_dims, - x_dims, - &dy_help, - true, - true); - } - } else { - if (dx && ddy) { - MatMulFunction(dev_ctx, - ddy.get(), - dout_conj, - y_dims, - dout_dims, - &dx_help, - false, - true); - } - if (dy && ddx) { - MatMulFunction(dev_ctx, - ddx.get(), - dout_conj, - x_dims, - dout_dims, - &dy_help, - false, - false); - } - } - } else { - if (transpose_y) { - if (dx && ddy) { - MatMulFunction(dev_ctx, - dout_conj, - ddy.get(), - dout_dims, - y_dims, - &dx_help, - false, - false); - } - if (dy && ddx) { - MatMulFunction(dev_ctx, - dout_conj, - ddx.get(), - dout_dims, - x_dims, - &dy_help, - true, - false); - } - } else { - if (dx && ddy) { - MatMulFunction(dev_ctx, - dout_conj, - ddy.get(), - dout_dims, - y_dims, - &dx_help, - false, - true); - } - if (dy && ddx) { - MatMulFunction(dev_ctx, - ddx.get(), - dout_conj, - x_dims, - dout_dims, - &dy_help, - true, - false); - } - } - } - - // get help dims - const std::vector dx_help_dims = - common::vectorize(dx_help.dims()); - const std::vector dy_help_dims = - common::vectorize(dy_help.dims()); - - std::vector dx_broadcast_dims(ndim); - std::vector dy_broadcast_dims(ndim); - - std::fill( - dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1); - std::fill( - dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1); - std::copy(x_dims.data(), - x_dims.data() + x_ndim, - dx_broadcast_dims.data() + ndim - x_ndim); - std::copy(y_dims.data(), - 
y_dims.data() + y_ndim, - dy_broadcast_dims.data() + ndim - y_ndim); - - std::vector dx_reduce_dims; - std::vector dy_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { - dx_reduce_dims.push_back(idx); - } - if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { - dy_reduce_dims.push_back(idx); - } - } - // Reduce sum to get grad by ReduceSum - if (dx && dx_help.initialized()) { - if (dx_reduce_dims.empty()) { - *dx = std::move(dx_help); - } else { - ReduceSumForMatmulGrad()( - dev_ctx, dx_help, dx, dx_reduce_dims); - } - dx->Resize(x.dims()); - } else if (dx && !dx_help.initialized()) { - FullLikeKernel(dev_ctx, x, Scalar(0.0), x.dtype(), dx); - } - if (dy && dy_help.initialized()) { - if (dy_reduce_dims.empty()) { - *dy = std::move(dy_help); - } else { - ReduceSumForMatmulGrad()( - dev_ctx, dy_help, dy, dy_reduce_dims); - } - dy->Resize(y.dims()); - } else if (dy && !dy_help.initialized()) { - FullLikeKernel(dev_ctx, y, Scalar(0.0), y.dtype(), dy); - } - - if (ddout) { - // Calculate the gradient of OutputGrad(Out) - if (ddx) { - MatMulFunction(dev_ctx, - ddx.get(), - y_conj, - x_dims, - y_dims, - ddout, - transpose_x, - transpose_y); - } - - if (ddy) { - MatMulFunction(dev_ctx, - x_conj, - ddy.get(), - x_dims, - y_dims, - ddout, - transpose_x, - transpose_y, - true); - } - } - } -} - -template -void MatmulTripleGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& dout, - const paddle::optional& ddx, - const paddle::optional& ddy, - const paddle::optional& d_dx, - const paddle::optional& d_dy, - const paddle::optional& d_ddout, - bool transpose_x, - bool transpose_y, - DenseTensor* out_d_x, - DenseTensor* out_d_y, - DenseTensor* out_d_dout, - DenseTensor* out_d_ddx, - DenseTensor* out_d_ddy) { - // Get dims from the input x, y, output_grad - std::vector x_dims = common::vectorize(x.dims()); - std::vector y_dims = common::vectorize(y.dims()); - std::vector dout_dims = common::vectorize(dout.dims()); - - int x_ndim = x_dims.size(); - int y_ndim = y_dims.size(); - int ndim = dout_dims.size(); - - // Case1 : x's and y's dim = 1 - if (x_ndim == 1 && y_ndim == 1) { - VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 1"; - DotTripleGradFunction()(dev_ctx, - &x, - &y, - &dout, - &ddx, - &ddy, - &d_dx, - &d_dy, - &d_ddout, - out_d_x, - out_d_y, - out_d_dout, - out_d_ddx, - out_d_ddy); - return; - } - - DenseTensor x_conj; - DenseTensor y_conj; - DenseTensor dout_conj; - DenseTensor ddx_conj; - DenseTensor ddy_conj; - - bool is_broadcast = true; - if (x_ndim <= 2 || y_ndim <= 2) { - is_broadcast = false; - } else if (x_ndim != y_ndim) { - is_broadcast = true; - } else { - is_broadcast = !std::equal( - x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); - } - - if (!is_broadcast) { - // Case2: no broadcast or no batch size - VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 2"; - DenseTensor x_help = x; - DenseTensor y_help = y; - DenseTensor dout_help = dout; - - DenseTensor ddx_help; - DenseTensor ddy_help; - ReshapeXYOutIntoMatrixSequence( - &x_help, &y_help, &dout_help, transpose_x, transpose_y); - if (ddx) { - ddx_help = ddx.get(); - if (ddx_help.dims() != x_help.dims()) { - ddx_help.Resize(x_help.dims()); - } - } - - if (ddy) { - ddy_help = ddy.get(); - if (ddy_help.dims() != y_help.dims()) { - ddy_help.Resize(y_help.dims()); - } - } - - DDim out_dx_dims; - if (out_d_x) { - out_dx_dims = out_d_x->dims(); - if 
(out_dx_dims != x_help.dims()) { - out_d_x->Resize(x_help.dims()); - } - if (ddy) { - ddy_conj = Conj(dev_ctx, ddy_help); - } - } - DDim out_dy_dims; - if (out_d_y) { - out_dy_dims = out_d_y->dims(); - if (out_dy_dims != y_help.dims()) { - out_d_y->Resize(y_help.dims()); - } - if (ddx) { - ddx_conj = Conj(dev_ctx, ddx_help); - } - } - DDim out_d_dout_dims; - if (out_d_dout) { - out_d_dout_dims = out_d_dout->dims(); - if (out_d_dout_dims != dout_help.dims()) { - out_d_dout->Resize(dout_help.dims()); - } - if (ddx && !ddx_conj.IsInitialized()) { - ddx_conj = Conj(dev_ctx, ddx_help); - } - if (ddy && !ddy_conj.IsInitialized()) { - ddy_conj = Conj(dev_ctx, ddy_help); - } - } - DDim out_d_ddx_dims; - if (out_d_ddx) { - out_d_ddx_dims = out_d_ddx->dims(); - if (out_d_ddx_dims != x_help.dims()) { - out_d_ddx->Resize(x_help.dims()); - } - dout_conj = Conj(dev_ctx, dout_help); - y_conj = Conj(dev_ctx, y_help); - } - DDim out_d_ddy_dims; - if (out_d_ddy) { - out_d_ddy_dims = out_d_ddy->dims(); - if (out_d_ddy_dims != y_help.dims()) { - out_d_ddy->Resize(y_help.dims()); - } - if (!dout_conj.IsInitialized()) { - dout_conj = Conj(dev_ctx, dout_help); - } - x_conj = Conj(dev_ctx, x_help); - } - - bool d_dout_flag = false; - bool d_ddx_flag = false; - bool d_ddy_flag = false; - if (d_ddout) { - auto d_ddout_mat = d_ddout.get(); - if (d_ddout_mat.dims() != dout_help.dims()) { - d_ddout_mat.Resize(dout_help.dims()); - } - - if (out_d_y && ddx) { - if (transpose_x && transpose_y) { - // out_d_y = d_ddout' * ddx' - CalcInputGrad(dev_ctx, - d_ddout_mat, - true, - true, - ddx_conj, - true, - false, - out_d_y, - false); - } else if (transpose_x) { - // out_d_y = ddx * d_ddout - CalcInputGrad(dev_ctx, - ddx_conj, - false, - false, - d_ddout_mat, - false, - true, - out_d_y, - false); - } else if (transpose_y) { - // out_d_y = d_ddout' * ddx - CalcInputGrad(dev_ctx, - d_ddout_mat, - true, - true, - ddx_conj, - false, - true, - out_d_y, - false); - } else { - // out_d_y = ddx' * d_ddout - CalcInputGrad(dev_ctx, - ddx_conj, - true, - true, - d_ddout_mat, - false, - true, - out_d_y, - false); - } - } else if (out_d_y) { - FullLikeKernel(dev_ctx, y, Scalar(0.0), y.dtype(), out_d_y); - } - if (out_d_x && ddy) { - if (transpose_x && transpose_y) { - // out_d_x = ddy' * d_ddout' - CalcInputGrad(dev_ctx, - ddy_conj, - true, - true, - d_ddout_mat, - true, - false, - out_d_x, - false); - } else if (transpose_x) { - // out_d_x = ddy * d_ddout' - CalcInputGrad(dev_ctx, - ddy_conj, - false, - false, - d_ddout_mat, - true, - false, - out_d_x, - false); - } else if (transpose_y) { - // out_d_x = d_ddout * ddy - CalcInputGrad(dev_ctx, - d_ddout_mat, - false, - false, - ddy_conj, - false, - true, - out_d_x, - false); - } else { - // out_d_x = d_ddout * ddy' - CalcInputGrad(dev_ctx, - d_ddout_mat, - false, - false, - ddy_conj, - true, - false, - out_d_x, - false); - } - } else if (out_d_x) { - FullLikeKernel(dev_ctx, x, Scalar(0.0), x.dtype(), out_d_x); - } - - // equations: - // d_ddx = DOut * D_DY + Y * D_DDOut - // Let: d_ddx1 = Y * D_DDOut - // Let: d_ddx2 = DOut * D_DY - - // d_ddy = DOut * D_DX + X * D_DDOut - // Let: d_ddy1 = X * D_DDOut - // Let: d_ddy2 = DOut * D_DX - - // d_dout = DDY * D_DX + DDX * D_DY - // Let: d_dout1 = DDX * D_DY - // Let: d_dout2 = DDY * D_DX - - // compute d_ddx1 - if (out_d_ddx) { - if (transpose_x && transpose_y) { - // out_d_ddx1 = y' * d_ddout' - CalcInputGrad(dev_ctx, - y_conj, - true, - true, - d_ddout_mat, - true, - false, - out_d_ddx, - d_ddx_flag); - } else if (transpose_x) { - // 
out_d_ddx1 = y * d_ddout' - CalcInputGrad(dev_ctx, - y_conj, - false, - false, - d_ddout_mat, - true, - false, - out_d_ddx, - d_ddx_flag); - } else if (transpose_y) { - // out_d_ddx1 = d_ddout * y - CalcInputGrad(dev_ctx, - d_ddout_mat, - false, - false, - y_conj, - false, - true, - out_d_ddx, - d_ddx_flag); - } else { - // out_d_ddx1 = d_ddout * y' - CalcInputGrad(dev_ctx, - d_ddout_mat, - false, - false, - y_conj, - true, - false, - out_d_ddx, - d_ddx_flag); - } - d_ddx_flag = true; - } - - // compute d_ddy1 - if (out_d_ddy) { - if (transpose_x && transpose_y) { - // out_d_ddy1 = d_ddout' * x' - CalcInputGrad(dev_ctx, - d_ddout_mat, - true, - true, - x_conj, - true, - false, - out_d_ddy, - false); - } else if (transpose_x) { - // out_d_ddy1 = x * d_ddout - CalcInputGrad(dev_ctx, - x_conj, - false, - false, - d_ddout_mat, - false, - true, - out_d_ddy, - false); - } else if (transpose_y) { - // out_d_ddy1 = d_ddout' * x - CalcInputGrad(dev_ctx, - d_ddout_mat, - true, - true, - x_conj, - false, - true, - out_d_ddy, - false); - } else { - // out_d_ddy1 = x' * d_ddout - CalcInputGrad(dev_ctx, - x_conj, - true, - true, - d_ddout_mat, - false, - true, - out_d_ddy, - false); - } - d_ddy_flag = true; - } - } else { - // d_ddout is none - if (out_d_x) { - FullLikeKernel(dev_ctx, x, Scalar(0.0), x.dtype(), out_d_x); - } - - if (out_d_y) { - FullLikeKernel(dev_ctx, y, Scalar(0.0), y.dtype(), out_d_y); - } - } - - if (d_dy) { - auto d_dy_mat = d_dy.get(); - if (d_dy_mat.dims() != y_help.dims()) { - d_dy_mat.Resize(y_help.dims()); - } - - // compute d_dout1 - if (out_d_dout && ddx) { - CalcInputGrad(dev_ctx, - ddx_conj, - transpose_x, - true, - d_dy_mat, - transpose_y, - false, - out_d_dout, - d_dout_flag); - d_dout_flag = true; - } - - // compute d_ddx2 - if (out_d_ddx) { - if (transpose_x && transpose_y) { - // out_d_ddx2 = D_DY' * DOut' - CalcInputGrad(dev_ctx, - d_dy_mat, - true, - true, - dout_conj, - true, - false, - out_d_ddx, - d_ddx_flag); - } else if (transpose_x) { - // out_d_ddx2 = D_DY * Dout' - CalcInputGrad(dev_ctx, - d_dy_mat, - false, - false, - dout_conj, - true, - false, - out_d_ddx, - d_ddx_flag); - } else if (transpose_y) { - // out_d_ddx2 = Dout * D_DY - CalcInputGrad(dev_ctx, - dout_conj, - false, - false, - d_dy_mat, - false, - true, - out_d_ddx, - d_ddx_flag); - } else { - // out_d_ddx2 = Dout * D_DY' - CalcInputGrad(dev_ctx, - dout_conj, - false, - false, - d_dy_mat, - true, - false, - out_d_ddx, - d_ddx_flag); - } - } - } - - if (d_dx) { - auto d_dx_mat = d_dx.get(); - if (d_dx_mat.dims() != x_help.dims()) { - d_dx_mat.Resize(x_help.dims()); - } - - // compute d_dout2 - if (out_d_dout && ddy) { - CalcInputGrad(dev_ctx, - d_dx_mat, - transpose_x, - true, - ddy_conj, - transpose_y, - false, - out_d_dout, - d_dout_flag); - } - - // compute d_ddy2 - if (out_d_ddy) { - if (transpose_x && transpose_y) { - // out_d_ddy2 = dout' * d_dx' - CalcInputGrad(dev_ctx, - dout_conj, - true, - true, - d_dx_mat, - true, - false, - out_d_ddy, - d_ddy_flag); - } else if (transpose_x) { - // out_d_ddy2 = d_dx * dout - CalcInputGrad(dev_ctx, - d_dx_mat, - false, - false, - dout_conj, - false, - true, - out_d_ddy, - d_ddy_flag); - } else if (transpose_y) { - // out_d_ddy2 = dout' * d_dx - CalcInputGrad(dev_ctx, - dout_conj, - true, - true, - d_dx_mat, - false, - true, - out_d_ddy, - d_ddy_flag); - } else { - // out_d_ddy2 = d_dx' * dout - CalcInputGrad(dev_ctx, - d_dx_mat, - true, - true, - dout_conj, - false, - true, - out_d_ddy, - d_ddy_flag); - } - } - } - - if (out_d_x) { - if (out_dx_dims != 
x_help.dims()) { - out_d_x->Resize(out_dx_dims); - } - } - - if (out_d_y) { - if (out_dy_dims != y_help.dims()) { - out_d_y->Resize(out_dy_dims); - } - } - - if (out_d_dout) { - if (out_d_dout_dims != dout_help.dims()) { - out_d_dout->Resize(out_d_dout_dims); - } - } - - if (out_d_ddx) { - if (out_d_ddx_dims != x_help.dims()) { - out_d_ddx->Resize(out_d_ddx_dims); - } - } - - if (out_d_ddy) { - if (out_d_ddy_dims != y_help.dims()) { - out_d_ddy->Resize(out_d_ddy_dims); - } - } - - if (out_d_dout && !out_d_dout->IsInitialized()) { - FullLikeKernel( - dev_ctx, dout, Scalar(0.0), dout.dtype(), out_d_dout); - } - - if (out_d_ddx && !out_d_ddx->IsInitialized()) { - FullLikeKernel(dev_ctx, x, Scalar(0.0), x.dtype(), out_d_ddx); - } - - if (out_d_ddy && !out_d_ddy->IsInitialized()) { - FullLikeKernel(dev_ctx, y, Scalar(0.0), y.dtype(), out_d_ddy); - } - } else { - // Case3: broadcast. It need cost much time to reduce sum for the - // broadcast and wastes the memory. - // So we should avoid the case in reality. - VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 3"; - VLOG(3) << "It need cost much time to reduce sum for the broadcast and " - "wastes the memory. So we should avoid the case in reality"; - - DenseTensor out_dx_help; - DenseTensor out_dy_help; - DenseTensor out_d_ddx_help; - DenseTensor out_d_ddy_help; - - if (out_d_dout) { - if (ddx) { - ddx_conj = Conj(dev_ctx, ddx.get()); - } - if (ddy) { - ddy_conj = Conj(dev_ctx, ddy.get()); - } - } - if (out_d_ddx || out_d_ddy) { - x_conj = Conj(dev_ctx, x); - y_conj = Conj(dev_ctx, y); - dout_conj = Conj(dev_ctx, dout); - } - - if (transpose_x) { - if (transpose_y) { - // dX = ddY' d_ddout’, dY = d_ddout’ ddX' - if (out_d_x && ddy && d_ddout) - MatMulFunction(dev_ctx, - ddy_conj, - d_ddout.get(), - y_dims, - dout_dims, - &out_dx_help, - true, - true); - if (out_d_y && ddx && d_ddout) - MatMulFunction(dev_ctx, - d_ddout.get(), - ddx_conj, - dout_dims, - x_dims, - &out_dy_help, - true, - true); - } else { - // dX = ddY d_ddout', dY = ddX d_ddout - if (out_d_x && ddy && d_ddout) - MatMulFunction(dev_ctx, - ddy_conj, - d_ddout.get(), - y_dims, - dout_dims, - &out_dx_help, - false, - true); - if (out_d_y && ddx && d_ddout) - MatMulFunction(dev_ctx, - ddx_conj, - d_ddout.get(), - x_dims, - dout_dims, - &out_dy_help, - false, - false); - } - - } else { - if (transpose_y) { - // dX = d_ddout ddY, dY = d_ddout’ ddX - if (out_d_x && ddy && d_ddout) - MatMulFunction(dev_ctx, - d_ddout.get(), - ddy_conj, - dout_dims, - y_dims, - &out_dx_help, - false, - false); - if (out_d_y && ddx && d_ddout) - MatMulFunction(dev_ctx, - d_ddout.get(), - ddx_conj, - dout_dims, - x_dims, - &out_dy_help, - true, - false); - } else { - // dX = d_ddout ddY', dY = ddX' d_ddout - if (out_d_x && ddy && d_ddout) - MatMulFunction(dev_ctx, - d_ddout.get(), - ddy_conj, - dout_dims, - y_dims, - &out_dx_help, - false, - true); - if (out_d_y && ddx && d_ddout) - MatMulFunction(dev_ctx, - ddx_conj, - d_ddout.get(), - x_dims, - dout_dims, - &out_dy_help, - true, - false); - } - } - - // get help dims - const std::vector dx_help_dims = - common::vectorize(out_dx_help.dims()); - const std::vector dy_help_dims = - common::vectorize(out_dx_help.dims()); - - std::vector dx_broadcast_dims(ndim); - std::vector dy_broadcast_dims(ndim); - - std::fill( - dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1); - std::fill( - dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1); - std::copy(x_dims.data(), - x_dims.data() + x_ndim, - 
dx_broadcast_dims.data() + ndim - x_ndim); - std::copy(y_dims.data(), - y_dims.data() + y_ndim, - dy_broadcast_dims.data() + ndim - y_ndim); - - std::vector dx_reduce_dims; - std::vector dy_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { - dx_reduce_dims.push_back(idx); - } - if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { - dy_reduce_dims.push_back(idx); - } - } - - // Reduce sum to get grad by ReduceSum - if (out_d_x && out_dx_help.initialized()) { - if (dx_reduce_dims.empty()) { - *out_d_x = std::move(out_dx_help); - } else { - ReduceSumForMatmulGrad()( - dev_ctx, out_dx_help, out_d_x, dx_reduce_dims); - } - out_d_x->Resize(x.dims()); - } else if (out_d_x) { - FullLikeKernel(dev_ctx, x, Scalar(0.0), x.dtype(), out_d_x); - } - - if (out_d_y && out_dy_help.initialized()) { - if (dy_reduce_dims.empty()) { - *out_d_y = std::move(out_dy_help); - } else { - ReduceSumForMatmulGrad()( - dev_ctx, out_dy_help, out_d_y, dy_reduce_dims); - } - out_d_y->Resize(y.dims()); - } else if (out_d_y) { - FullLikeKernel(dev_ctx, y, Scalar(0.0), y.dtype(), out_d_y); - } - - // compute d_dout - if (out_d_dout) { - if (d_dx && ddy) { - MatMulFunction(dev_ctx, - d_dx.get(), - ddy_conj, - x_dims, - y_dims, - out_d_dout, - transpose_x, - transpose_y); - } - if (d_dy && ddx) { - MatMulFunction(dev_ctx, - ddx_conj, - d_dy.get(), - x_dims, - y_dims, - out_d_dout, - transpose_x, - transpose_y, - true); - } - - if (!out_d_dout->initialized()) { - FullLikeKernel( - dev_ctx, dout, Scalar(0.0), dout.dtype(), out_d_dout); - } - } - - // compute d_ddx - if (out_d_ddx) { - if (transpose_x && transpose_y) { - // out_d_ddx1 = y' * d_ddout' - if (d_ddout) { - MatMulFunction(dev_ctx, - y_conj, - d_ddout.get(), - y_dims, - dout_dims, - &out_d_ddx_help, - true, - true); - } - - // out_d_ddx2 = D_DY' * DOut' - if (d_dy) { - MatMulFunction(dev_ctx, - d_dy.get(), - dout_conj, - y_dims, - dout_dims, - &out_d_ddx_help, - true, - true, - true); - } - - } else if (transpose_x) { - // out_d_ddx1 = y * d_ddout' - if (d_ddout) { - MatMulFunction(dev_ctx, - y_conj, - d_ddout.get(), - y_dims, - dout_dims, - &out_d_ddx_help, - false, - true); - } - - // out_d_ddx2 = D_DY * Dout' - if (d_dy) { - MatMulFunction(dev_ctx, - d_dy.get(), - dout_conj, - y_dims, - dout_dims, - &out_d_ddx_help, - false, - true, - true); - } - - } else if (transpose_y) { - // out_d_ddx1 = d_ddout * y - if (d_ddout) { - MatMulFunction(dev_ctx, - d_ddout.get(), - y_conj, - dout_dims, - y_dims, - &out_d_ddx_help, - false, - false); - } - - // out_d_ddx2 = Dout * D_DY - if (d_dy) { - MatMulFunction(dev_ctx, - dout_conj, - d_dy.get(), - dout_dims, - y_dims, - &out_d_ddx_help, - false, - false, - true); - } - } else { - // out_d_ddx1 = d_ddout * y' - if (d_ddout) { - MatMulFunction(dev_ctx, - d_ddout.get(), - y_conj, - dout_dims, - y_dims, - &out_d_ddx_help, - false, - true); - } - - // out_d_ddx2 = Dout * D_DY' - if (d_dy) { - MatMulFunction(dev_ctx, - dout_conj, - d_dy.get(), - dout_dims, - y_dims, - &out_d_ddx_help, - false, - true, - true); - } - } - if (out_d_ddx_help.initialized()) { - if (dx_reduce_dims.empty()) { - *out_d_ddx = std::move(out_d_ddx_help); - } else { - ReduceSumForMatmulGrad()( - dev_ctx, out_d_ddx_help, out_d_ddx, dx_reduce_dims); - } - } else { - FullLikeKernel( - dev_ctx, x, Scalar(0.0), x.dtype(), out_d_ddx); - } - - out_d_ddx->Resize(x.dims()); - } - - // compute d_ddy - if (out_d_ddy) { - if (transpose_x && transpose_y) { - // out_d_ddy1 = d_ddout' * x' - if 
(d_ddout) { - MatMulFunction(dev_ctx, - d_ddout.get(), - x_conj, - dout_dims, - x_dims, - &out_d_ddy_help, - true, - true); - } - - // out_d_ddy2 = dout' * d_dx' - if (d_dx) { - MatMulFunction(dev_ctx, - dout_conj, - d_dx.get(), - dout_dims, - x_dims, - &out_d_ddy_help, - true, - true, - true); - } - - } else if (transpose_x) { - // out_d_ddy1 = x * d_ddout - if (d_ddout) { - MatMulFunction(dev_ctx, - x_conj, - d_ddout.get(), - x_dims, - dout_dims, - &out_d_ddy_help, - false, - false); - } - - // out_d_ddy2 = d_dx * dout - if (d_dx) { - MatMulFunction(dev_ctx, - d_dx.get(), - dout_conj, - x_dims, - dout_dims, - &out_d_ddy_help, - false, - false, - true); - } - - } else if (transpose_y) { - // out_d_ddy1 = d_ddout' * x - if (d_ddout) { - MatMulFunction(dev_ctx, - d_ddout.get(), - x_conj, - dout_dims, - x_dims, - &out_d_ddy_help, - true, - false); - } - - // out_d_ddy2 = dout' * d_dx - if (d_dx) { - MatMulFunction(dev_ctx, - dout_conj, - d_dx.get(), - dout_dims, - x_dims, - &out_d_ddy_help, - true, - false, - true); - } - - } else { - // out_d_ddy1 = x' * d_ddout - if (d_ddout) { - MatMulFunction(dev_ctx, - x_conj, - d_ddout.get(), - x_dims, - dout_dims, - &out_d_ddy_help, - true, - false); - } - - // out_d_ddy2 = d_dx' * dout - if (d_dx) { - MatMulFunction(dev_ctx, - d_dx.get(), - dout_conj, - x_dims, - dout_dims, - &out_d_ddy_help, - true, - false, - true); - } - } - - if (out_d_ddy_help.initialized()) { - if (dy_reduce_dims.empty()) { - *out_d_ddy = std::move(out_d_ddy_help); - } else { - ReduceSumForMatmulGrad()( - dev_ctx, out_d_ddy_help, out_d_ddy, dy_reduce_dims); - } - } else { - FullLikeKernel( - dev_ctx, y, Scalar(0.0), y.dtype(), out_d_ddy); - } - - out_d_ddy->Resize(y.dims()); - } - } -} - -template -void MatmulWithFlattenGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out_grad, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* x_grad, - DenseTensor* y_grad) { - auto x_matrix = - x.dims().size() > 2 ? phi::ReshapeToMatrix(x, x_num_col_dims) : x; - auto y_matrix = - y.dims().size() > 2 ? phi::ReshapeToMatrix(y, y_num_col_dims) : y; - auto* dout = &out_grad; - - DenseTensor dout_mat(*dout); - dout_mat.Resize({common::flatten_to_2d(x.dims(), x_num_col_dims)[0], - common::flatten_to_2d(y.dims(), y_num_col_dims)[1]}); - - auto* dx = x_grad; - auto* dy = y_grad; - - if (dx != nullptr) { - dx->set_lod(x.lod()); - } - if (dy != nullptr) { - dy->set_lod(y.lod()); - } - - auto blas = phi::funcs::GetBlas(dev_ctx); - if (dx) { - dev_ctx.template Alloc(dx); - DenseTensor dx_matrix = - dx->dims().size() > 2 ? phi::ReshapeToMatrix(*dx, x_num_col_dims) : *dx; - - // dx = dout * y'. dx: M x K, dout : M x N, y : K x N - blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix); - } - if (dy) { - dev_ctx.template Alloc(dy); - DenseTensor dy_matrix = - dy->dims().size() > 2 ? phi::ReshapeToMatrix(*dy, y_num_col_dims) : *dy; - // dy = x' * dout. dy K x N, dout : M x N, x : M x K - blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix); - } -} - -template -void MatmulWithFlattenDoubleGradKernel( - const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out_grad, - const paddle::optional& x_grad_grad, - const paddle::optional& y_grad_grad, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* x_grad, - DenseTensor* y_grad, - DenseTensor* out_grad_grad) { - auto x_mat = - x.dims().size() > 2 ? phi::ReshapeToMatrix(x, x_num_col_dims) : x; - auto y_mat = - y.dims().size() > 2 ? 
phi::ReshapeToMatrix(y, y_num_col_dims) : y; - - const int m = common::flatten_to_2d(x.dims(), x_num_col_dims)[0]; - const int n = common::flatten_to_2d(y.dims(), y_num_col_dims)[1]; - - auto* dout = &out_grad; - DenseTensor dout_mat(*dout); - dout_mat.Resize({m, n}); - - auto* ddx = x_grad_grad.get_ptr(); - auto* ddy = y_grad_grad.get_ptr(); - - auto* dx = x_grad; - auto* dy = y_grad; - auto* ddout = out_grad_grad; - - DenseTensor ddout_mat; - if (ddout) { - ddout->set_lod(dout->lod()); - // allocate and reshape ddout - dev_ctx.template Alloc(ddout); - ddout_mat.ShareDataWith(*ddout); - ddout_mat.Resize({m, n}); - } - - auto blas = phi::funcs::GetBlas(dev_ctx); - // a flag to specify whether ddout value has been set, if flag - // is false, MatMul beta should be 0 to set ddout, if flag is - // true, MatMul beta should be 1 to add result to ddout. - bool ddout_flag = false; - if (ddx) { - auto ddx_mat = ddx->dims().size() > 2 - ? phi::ReshapeToMatrix(*ddx, x_num_col_dims) - : static_cast(*ddx); - - // dy = ddx' * dout. dy : K x M, ddx' : K x M, dout : M x N - if (dy) { - dy->set_lod(y.lod()); - // allocate and reshape dy - dev_ctx.template Alloc(dy); - DenseTensor dy_mat = dy->dims().size() > 2 - ? phi::ReshapeToMatrix(*dy, y_num_col_dims) - : *dy; - blas.MatMul(ddx_mat, true, dout_mat, false, &dy_mat); - } - // ddout1 = ddx * y. ddx : M x K, y : K x N, ddout1 : M x N - if (ddout) { - blas.MatMul(ddx_mat, - false, - y_mat, - false, - static_cast(1.0), - &ddout_mat, - static_cast(ddout_flag)); - ddout_flag = true; - } - } - if (ddy) { - auto ddy_mat = ddy->dims().size() > 2 - ? phi::ReshapeToMatrix(*ddy, y_num_col_dims) - : static_cast(*ddy); - // dx = dout * ddy'. dout : M x N, ddy' : N x K, dx : M x K - if (dx) { - dx->set_lod(x.lod()); - // allocate and reshape dx - dev_ctx.template Alloc(dx); - DenseTensor dx_mat = dx->dims().size() > 2 - ? phi::ReshapeToMatrix(*dx, x_num_col_dims) - : *dx; - blas.MatMul(dout_mat, false, ddy_mat, true, &dx_mat); - } - // ddout2 = x * ddy. x : M x K, ddy : K x N, ddout2 : M x N - if (ddout) { - blas.MatMul(x_mat, - false, - ddy_mat, - false, - static_cast(1.0), - &ddout_mat, - static_cast(ddout_flag)); - } - } -} -template -void LegacyMatmulGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out_grad, - bool transpose_x, - bool transpose_y, - float alpha, - DenseTensor* dx, - DenseTensor* dy) { - MatmulGradKernel( - dev_ctx, x, y, out_grad, transpose_x, transpose_y, dx, dy); - if (std::fabs(alpha - 1.f) > 1e-6f) { - ScaleKernel(dev_ctx, *dx, Scalar(alpha), Scalar(0), false, dx); - ScaleKernel(dev_ctx, *dy, Scalar(alpha), Scalar(0), false, dy); - } -} -} // namespace phi diff --git a/backends/metax_gpu/kernels/impl/matmul_kernel_impl.h b/backends/metax_gpu/kernels/impl/matmul_kernel_impl.h deleted file mode 100755 index 5221bd93ba9..00000000000 --- a/backends/metax_gpu/kernels/impl/matmul_kernel_impl.h +++ /dev/null @@ -1,1717 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
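[Editor's note] The ddout_flag bookkeeping in the double-grad code above encodes the accumulation ddOut = ddX * Y + X * ddY with two GEMM calls into one buffer: the first call runs with beta = 0 and overwrites ddOut, the second with beta = 1 and adds to it. The sketch below reproduces that pattern with a plain triple-loop Gemm standing in for blas.MatMul; the function and variable names are illustrative assumptions, not the phi BLAS API.

// Sketch of the beta-accumulation pattern behind ddout_flag above:
// the first product writes C (beta = 0), the second adds into it (beta = 1).
#include <cstdio>

// C = alpha * A * B + beta * C for an M x K times K x N product (row-major).
// As in BLAS, C is not read when beta == 0, so it may start uninitialized.
void Gemm(int M, int N, int K, double alpha, const double* A, const double* B,
          double beta, double* C) {
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      double acc = 0;
      for (int k = 0; k < K; ++k) acc += A[i * K + k] * B[k * N + j];
      C[i * N + j] = alpha * acc + (beta == 0.0 ? 0.0 : beta * C[i * N + j]);
    }
}

int main() {
  double ddX[4] = {1, 2, 3, 4}, Y[4] = {1, 0, 0, 1};  // 2x2, Y = identity
  double X[4] = {1, 1, 1, 1}, ddY[4] = {2, 2, 2, 2};  // 2x2
  double ddOut[4];                                    // deliberately uninitialized

  bool flag = false;                                    // mirrors ddout_flag
  Gemm(2, 2, 2, 1.0, ddX, Y, flag ? 1.0 : 0.0, ddOut);  // ddOut  = ddX * Y
  flag = true;
  Gemm(2, 2, 2, 1.0, X, ddY, flag ? 1.0 : 0.0, ddOut);  // ddOut += X * ddY

  // Expected: [1 2; 3 4] + [4 4; 4 4] = [5 6; 7 8].
  std::printf("ddOut = [%g %g; %g %g]\n", ddOut[0], ddOut[1], ddOut[2], ddOut[3]);
  return 0;
}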
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -// clang-format off -#include "glog/logging.h" - -#include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/autotune/cache_base.h" -#include "paddle/phi/kernels/cast_kernel.h" -#include "../funcs/blas/blas.h" -#ifdef PADDLE_WITH_HIP -#include "paddle/phi/kernels/funcs/blas/blaslt_impl.hip.h" -#else -#include "../funcs/blas/blaslt_impl.cu.h" -#endif -#include "paddle/phi/kernels/funcs/complex_functors.h" -#include "paddle/phi/kernels/scale_kernel.h" -#if defined(PADDLE_WITH_CUDA) -// #include "paddle/phi/kernels/funcs/cublaslt.h" -#include "paddle/phi/kernels/gpu/cuda_gemm_kernel.h" -#include "paddle/phi/kernels/transpose_kernel.h" -#elif defined(PADDLE_WITH_HIP) -#include "paddle/phi/kernels/funcs/hipblaslt.h" -#endif -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 && 0 -#include "paddle/phi/kernels/autotune/auto_tune_base.h" -#endif -#include "paddle/phi/kernels/full_kernel.h" -// clang-format on -namespace phi { - -static void GetBroadcastFromDims(const int x_ndim, - const std::int64_t* x_dims, - const int y_ndim, - const std::int64_t* y_dims, - std::int64_t* x_bd_dims, - std::int64_t* y_bd_dims, - std::int64_t* out_bd_dims) { - const int ndim = (std::max)(x_ndim, y_ndim); - std::fill(x_bd_dims, x_bd_dims + ndim - x_ndim, 1); - std::fill(y_bd_dims, y_bd_dims + ndim - y_ndim, 1); - std::copy(x_dims, x_dims + x_ndim, x_bd_dims + ndim - x_ndim); - std::copy(y_dims, y_dims + y_ndim, y_bd_dims + ndim - y_ndim); - - for (int i = 0; i < ndim; ++i) { - PADDLE_ENFORCE_EQ( - x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] <= 1 || y_bd_dims[i] <= 1, - true, - phi::errors::InvalidArgument( - "Input(X) and Input(Y) has error dim. " - "X_broadcast's shape[%s] must be equal to Y_broadcast's shape[%s], " - "or X_broadcast's shape[%s] <= 1, or Y_broadcast's shape[%s] <= 1, " - "but received X_broadcast's shape[%s] = [%s]" - "received Y_broadcast's shape[%s] = [%s].", - i, - i, - i, - i, - i, - x_bd_dims[i], - i, - y_bd_dims[i])); - if (x_bd_dims[i] == 0 || y_bd_dims[i] == 0) { - out_bd_dims[i] = 0; - } else { - out_bd_dims[i] = (std::max)(x_bd_dims[i], y_bd_dims[i]); - } - } -} - -static int64_t GetIndexMessage(const int n, - const int64_t* dims, - const int64_t* index) { - int64_t sum = 0; - for (int i = 0; i < n; ++i) { - if (dims[i] > 1) { - sum = sum * dims[i] + index[i]; - } - } - return sum; -} - -static void IndexIncreaseFromDims(const int ndim, - const int64_t* dims, - int64_t* index) { - for (int i = ndim - 1; i >= 0; --i) { - ++index[i]; - if (index[i] >= dims[i]) { - index[i] -= dims[i]; - } else { - break; - } - } -} - -// The general implementation with blas. 
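For reference, the three helpers above implement NumPy-style batch broadcasting for matmul: batch dims are right-aligned and padded with 1, size-1 dims broadcast against the other operand, GetIndexMessage collapses a batch multi-index into a flat offset that skips broadcast dims (so a broadcast operand keeps reusing the same matrix slice), and IndexIncreaseFromDims advances the multi-index like an odometer. A small standalone sketch of the same bookkeeping, with illustrative names and example shapes (not taken from the file):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Broadcast two batch-dim vectors the way GetBroadcastFromDims does:
// right-align, pad with 1, and take the per-axis maximum.
std::vector<int64_t> BroadcastBatchDims(std::vector<int64_t> a,
                                        std::vector<int64_t> b) {
  const size_t n = std::max(a.size(), b.size());
  a.insert(a.begin(), n - a.size(), 1);
  b.insert(b.begin(), n - b.size(), 1);
  std::vector<int64_t> out(n);
  for (size_t i = 0; i < n; ++i) out[i] = std::max(a[i], b[i]);
  return out;
}

// Collapse a batch multi-index into a flat offset, skipping size-1 dims,
// mirroring GetIndexMessage: broadcast operands reuse the same slice.
int64_t FlatIndexSkippingBroadcast(const std::vector<int64_t>& dims,
                                   const std::vector<int64_t>& index) {
  int64_t sum = 0;
  for (size_t i = 0; i < dims.size(); ++i)
    if (dims[i] > 1) sum = sum * dims[i] + index[i];
  return sum;
}

int main() {
  // x batch dims [2, 1], y batch dims [3] -> broadcast batch dims [2, 3].
  auto out = BroadcastBatchDims({2, 1}, {3});
  std::cout << out[0] << "x" << out[1] << "\n";                     // 2x3
  // For output batch index {1, 2}: x reuses slice 1, y reuses slice 2.
  std::cout << FlatIndexSkippingBroadcast({2, 1}, {1, 2}) << " "    // 1
            << FlatIndexSkippingBroadcast({1, 3}, {1, 2}) << "\n";  // 2
  return 0;
}

In the fully broadcast case handled at the end of the function below (case 14), these offsets are exactly what populate the per-batch pointer arrays handed to BatchedGEMM.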
-template -void MatMulFunctionImplWithBlas( - const Context& dev_ctx, - const DenseTensor& X, - const DenseTensor& Y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* Out, - bool trans_x, - bool trans_y, - bool flag = false, - phi::funcs::MatmulPlanner* matmul_planner UNUSED = nullptr) { - const int x_ndim = x_dims.size(); - const int y_ndim = y_dims.size(); - - // Get data ptr - const T* x_data = X.data(); - const T* y_data = Y.data(); - - auto blas = phi::funcs::GetBlas(dev_ctx); - - if (x_ndim == 1 && y_ndim == 1) { - const int M = X.numel(); - const int N = Y.numel(); - PADDLE_ENFORCE_EQ( - M, - N, - phi::errors::InvalidArgument( - "X's numbers must be equal to Y's numbers, " - "when X/Y's dims =1. But received X has [%d] elements, " - "received Y has [%d] elements.", - M, - N)); - VLOG(3) << "MatMul's case 1"; - Out->Resize(common::make_ddim({})); - dev_ctx.template Alloc(Out); - blas.GEMM(CblasNoTrans, - CblasTrans, - 1, - 1, - M, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - return; - } - - if (x_ndim == 1) { - const int N = X.numel(); - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - N, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - N, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - std::vector out_dims(y_ndim - 1); - if (trans_y) { - std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); - } else { - std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); - out_dims.back() = y_dims.back(); - } - Out->ResizeAndAllocate(common::make_ddim(out_dims)); - dev_ctx.template Alloc(Out); - if (trans_y) { - const int M = Y.numel() / N; - VLOG(3) << "MatMul's case 2"; - blas.GEMV(false, - M, - N, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - const int M = y_dims[y_ndim - 1]; - const int batch_size = Y.numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul's case 3"; - blas.GEMV(true, - N, - M, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - VLOG(3) << "MatMul's case 4"; - blas.BatchedGEMM(CblasTrans, - CblasNoTrans, - M, - 1, - N, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - batch_size, - M * N, - 0); - } - } - return; - } - - if (y_ndim == 1) { - const int N = Y.numel(); - if (trans_x) { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 2], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 2, - N, - x_ndim - 2, - x_dims[x_ndim - 2])); - } else { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 1], - N, - phi::errors::InvalidArgument("Input(X) has error dim." 
- "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 1, - N, - x_ndim - 1, - x_dims[x_ndim - 1])); - } - std::vector out_dims(x_ndim - 1); - if (trans_x) { - std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); - out_dims.back() = x_dims.back(); - } else { - std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); - } - Out->ResizeAndAllocate(common::make_ddim(out_dims)); - dev_ctx.template Alloc(Out); - - if (trans_x) { - const int M = x_dims[x_ndim - 1]; - const int batch_size = X.numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul's case 5"; - blas.GEMV(true, - N, - M, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - VLOG(3) << "MatMul's case 6"; - blas.BatchedGEMM(CblasTrans, - CblasNoTrans, - M, - 1, - N, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - batch_size, - M * N, - 0); - } - } else { - const int M = X.numel() / N; - VLOG(3) << "MatMul's case 7"; - blas.GEMV(false, - M, - N, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } - return; - } - - const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; - const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - K, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - K, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; - const int ndim = (std::max)(x_ndim, y_ndim); - std::vector x_broadcast_dims(ndim); - std::vector y_broadcast_dims(ndim); - std::vector out_broadcast_dims(ndim); - - GetBroadcastFromDims(x_ndim - 2, - x_dims.data(), - y_ndim - 2, - y_dims.data(), - x_broadcast_dims.data(), - y_broadcast_dims.data(), - out_broadcast_dims.data()); - out_broadcast_dims[ndim - 2] = M; - out_broadcast_dims[ndim - 1] = N; - - Out->ResizeAndAllocate(common::make_ddim(out_broadcast_dims)); - dev_ctx.template Alloc(Out); - - const int batch_dim = ndim - 2; - // broadcast message - const bool is_broadcast_dims = - !std::equal(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - y_broadcast_dims.cbegin()); - - const std::int64_t x_batch_size = - std::accumulate(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t y_batch_size = - std::accumulate(y_broadcast_dims.cbegin(), - y_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t out_batch_size = - std::accumulate(out_broadcast_dims.cbegin(), - out_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - if (out_batch_size == 0) return; - if (x_batch_size == 1 && y_batch_size == 1) { - VLOG(3) << "MatMul's case 8"; - blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? 
CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else if (x_batch_size == 1) { - if (M == 1 && trans_y) { - VLOG(3) << "MatMul's case 9"; - blas.GEMV(false, - y_batch_size * N, - K, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - VLOG(3) << "MatMul's case 10"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - 0, - K * N); - } - } else if (y_batch_size == 1) { - if (!trans_x) { - VLOG(3) << "MatMul's case 11"; - blas.GEMM(CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - x_batch_size * M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - VLOG(3) << "MatMul's case 12"; - blas.BatchedGEMM(CblasTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - M * K, - 0); - } - } else if (!is_broadcast_dims) { - VLOG(3) << "MatMul's case 13"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - M * K, - K * N); - } else { - // in the case, can't use stridedgemm - std::vector x_ptr(out_batch_size); - std::vector y_ptr(out_batch_size); - std::vector out_ptr(out_batch_size); - std::vector index(batch_dim, 0); - for (std::int64_t i = 0; i < out_batch_size; ++i) { - // using the index to get offset - const std::int64_t x_index = - GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); - const std::int64_t y_index = - GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); - - x_ptr[i] = x_data + x_index * M * K; - y_ptr[i] = y_data + y_index * K * N; - out_ptr[i] = dev_ctx.template Alloc(Out) + i * M * N; - IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); - } - VLOG(3) << "MatMul's case 14"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_ptr.data(), - y_ptr.data(), - static_cast(flag), - out_ptr.data(), - out_batch_size); - } -} - -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 && 0 -// This is almost a copy from MatMulFunctionImplWithBlas, -// compare cublas with cublasLt kernels when Matmul autotune is on -template -void MatMulFunctionImplWithCublasLt( - const Context& dev_ctx, - const DenseTensor& X, - const DenseTensor& Y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* Out, - bool trans_x, - bool trans_y, - bool flag = false, - phi::funcs::MatmulPlanner* matmul_planner = nullptr) { - const int x_ndim = x_dims.size(); - const int y_ndim = y_dims.size(); - const T* x_data = X.data(); - const T* y_data = Y.data(); - using blaslt = phi::funcs::MatmulWithCublasLt; - - if (x_ndim == 1 && y_ndim == 1) { - const int M = X.numel(); - const int N = Y.numel(); - PADDLE_ENFORCE_EQ( - M, - N, - phi::errors::InvalidArgument( - "X's numbers must be equal to Y's numbers," - "when X/Y's dims =1. 
But received X has [%d] elements," - "received Y has [%d] elements", - M, - N)); - - // MatMul's case 0 => vector * vector - Out->Resize(common::make_ddim({})); - dev_ctx.template Alloc(Out); - VLOG(3) << "MatMul with blaslt case 1"; - blaslt::Run(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - 1, - 1, - M, - false, - true, - matmul_planner); - return; - } - - if (x_ndim == 1) { - const int N = X.numel(); - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - N, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - N, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - std::vector out_dims(y_ndim - 1); - if (trans_y) { - std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); - } else { - std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); - out_dims.back() = y_dims.back(); - } - Out->ResizeAndAllocate(common::make_ddim(out_dims)); - dev_ctx.template Alloc(Out); - if (trans_y) { - const int M = Y.numel() / N; - VLOG(3) << "MatMul with blaslt 2"; - blaslt::Run(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - false, - false, - matmul_planner); - } else { - const int M = y_dims[y_ndim - 1]; - const int batch_size = Y.numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul with blaslt 3"; - blaslt::Run(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - true, - false, - matmul_planner); - } else { - VLOG(3) << "MatMul with blaslt 4"; - blaslt::RunWithBatch(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - true, - false, - batch_size, - M * N, - 0, - M, - matmul_planner); - } - } - return; - } - - if (y_ndim == 1) { - const int N = Y.numel(); - if (trans_x) { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 2], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 2, - N, - x_ndim - 2, - x_dims[x_ndim - 2])); - } else { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 1], - N, - phi::errors::InvalidArgument("Input(X) has error dim." 
- "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 1, - N, - x_ndim - 1, - x_dims[x_ndim - 1])); - } - std::vector out_dims(x_ndim - 1); - if (trans_x) { - std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); - out_dims.back() = x_dims.back(); - } else { - std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); - } - Out->ResizeAndAllocate(common::make_ddim(out_dims)); - dev_ctx.template Alloc(Out); - - if (trans_x) { - const int M = x_dims[x_ndim - 1]; - const int batch_size = X.numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul with blaslt 5"; - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - true, - false, - matmul_planner); - } else { - VLOG(3) << "MatMul with blaslt 6"; - blaslt::RunWithBatch(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - true, - false, - batch_size, - M * N, - 0, - M, - matmul_planner); - } - } else { - const int M = X.numel() / N; - VLOG(3) << "MatMul with blaslt 7"; - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - false, - false, - matmul_planner); - } - return; - } - - const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; - const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - K, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - K, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - const int N = trans_y ? 
y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; - const int ndim = (std::max)(x_ndim, y_ndim); - std::vector x_broadcast_dims(ndim); - std::vector y_broadcast_dims(ndim); - std::vector out_broadcast_dims(ndim); - - GetBroadcastFromDims(x_ndim - 2, - x_dims.data(), - y_ndim - 2, - y_dims.data(), - x_broadcast_dims.data(), - y_broadcast_dims.data(), - out_broadcast_dims.data()); - out_broadcast_dims[ndim - 2] = M; - out_broadcast_dims[ndim - 1] = N; - - Out->ResizeAndAllocate(common::make_ddim(out_broadcast_dims)); - dev_ctx.template Alloc(Out); - - const int batch_dim = ndim - 2; - // broadcast message - const bool is_broadcast_dims = - !std::equal(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - y_broadcast_dims.cbegin()); - - const std::int64_t x_batch_size = - std::accumulate(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t y_batch_size = - std::accumulate(y_broadcast_dims.cbegin(), - y_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t out_batch_size = - std::accumulate(out_broadcast_dims.cbegin(), - out_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - if (out_batch_size == 0) return; - if (x_batch_size == 1 && y_batch_size == 1) { - VLOG(3) << "MatMul with blaslt 8"; - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - N, - K, - trans_x, - trans_y, - matmul_planner); - } else if (x_batch_size == 1) { - if (M == 1 && trans_y) { - VLOG(3) << "MatMul with blaslt 9"; - blaslt::Run(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - y_batch_size * N, - 1, - K, - false, - false, - matmul_planner); - } else { - VLOG(3) << "MatMul with blaslt 10"; - blaslt::RunWithBatch(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - 0, - K * N, - M * N, - matmul_planner); - } - } else if (y_batch_size == 1) { - if (!trans_x) { - VLOG(3) << "MatMul with blaslt 11"; - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - x_batch_size * M, - N, - K, - false, - trans_y, - matmul_planner); - } else { - VLOG(3) << "MatMul with blaslt 12"; - blaslt::RunWithBatch(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - N, - K, - true, - trans_y, - out_batch_size, - M * K, - 0, - M * N, - matmul_planner); - } - } else if (!is_broadcast_dims) { - VLOG(3) << "MatMul with blaslt 13"; - blaslt::RunWithBatch(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - M * K, - K * N, - M * N, - matmul_planner); - } else { - // in the case, can't use stridedgemm - std::vector x_ptr(out_batch_size); - std::vector y_ptr(out_batch_size); - std::vector out_ptr(out_batch_size); - std::vector index(batch_dim, 0); - for (std::int64_t i = 0; i < out_batch_size; ++i) { - // using the index to get offset - const std::int64_t x_index = - GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); - const std::int64_t y_index = - GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); - - x_ptr[i] = x_data + x_index * M * K; - y_ptr[i] = y_data + y_index * K * N; - out_ptr[i] = dev_ctx.template Alloc(Out) + i * M * N; - IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); - } - VLOG(3) << "MatMul with blaslt 14"; - blaslt::RunWithBatch(dev_ctx, - x_ptr.data(), - y_ptr.data(), - out_ptr.data(), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - matmul_planner); - } 
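Both the cuBLAS and cuBLASLt variants above single out x_batch_size == 1 && M == 1 && trans_y (case 9): a 1 x K row vector multiplied by a batch of transposed matrices collapses into one GEMV with y viewed as a (y_batch_size * N) x K matrix, rather than a batched GEMM with M = 1. A plain-loop sketch of why the two formulations agree (illustrative only, not the production kernel):

#include <cassert>
#include <vector>

// out[r] = sum_k mat[r * K + k] * vec[k], i.e. a plain (rows x K) GEMV.
std::vector<float> Gemv(const std::vector<float>& mat,
                        const std::vector<float>& vec, int rows, int K) {
  std::vector<float> out(rows, 0.f);
  for (int r = 0; r < rows; ++r)
    for (int k = 0; k < K; ++k) out[r] += mat[r * K + k] * vec[k];
  return out;
}

int main() {
  const int B = 2, N = 3, K = 4;  // y: [B, N, K], x: [1, K], trans_y = true
  std::vector<float> y(B * N * K), x(K);
  for (size_t i = 0; i < y.size(); ++i) y[i] = 0.1f * static_cast<float>(i);
  for (int k = 0; k < K; ++k) x[k] = 1.f + k;

  // Batched view: out[b][n] = sum_k x[k] * y[b][n][k]  (x * y_b^T per batch).
  std::vector<float> batched(B * N, 0.f);
  for (int b = 0; b < B; ++b)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        batched[b * N + n] += x[k] * y[(b * N + n) * K + k];

  // Collapsed view: one GEMV with y treated as a (B * N) x K matrix.
  std::vector<float> collapsed = Gemv(y, x, B * N, K);
  for (int i = 0; i < B * N; ++i) assert(batched[i] == collapsed[i]);
  return 0;
}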
-} -#endif - -template -struct MatMulDispatcher { - void operator()(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y, - bool flag = false) { - MatMulFunctionImplWithBlas( - ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag); - } -}; - -#ifdef PADDLE_WITH_CUDA -template -struct MatMulDispatcher { - void operator()(const phi::GPUContext& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y, - bool flag = false) { -#if CUDA_VERSION >= 11060 && 0 - auto* tuner = phi::autotune::MakeMatmulTuner( - MatMulFunctionImplWithBlas); - tuner->AddCallBack(MatMulFunctionImplWithCublasLt); - phi::funcs::MatmulPlanner matmul_planner(x_dims, - y_dims, - trans_x, - trans_y, - phi::CppTypeToDataType::Type(), - funcs::MatmulFusedType::kMatmul, - /* bias_data */ nullptr, - /* reserve_data */ nullptr, - /* use_addto */ flag, - /* no_exchange */ true); - tuner->Run(ctx, - matmul_planner.GetKey(), - ctx, - x, - y, - x_dims, - y_dims, - out, - trans_x, - trans_y, - flag, - &matmul_planner); -#else - MatMulFunctionImplWithBlas( - ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag); -#endif - } -}; - -#endif // PADDLE_WITH_CUDA - -template -void MatMulFunction(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y, - bool flag = false) { - MatMulDispatcher()( - ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag); -} - -template -bool MatMulInt8Function(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y) { - return false; -} - -#ifdef PADDLE_WITH_CUDA -template <> -bool inline MatMulInt8Function(const phi::GPUContext& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y) { - if (x.dtype() != DataType::INT8 || y.dtype() != DataType::INT8) { - return false; - } -#if CUDA_VERSION >= 11060 && 0 - const int x_ndim = x_dims.size(); - const int y_ndim = y_dims.size(); - const int8_t* x_data = x.data(); - const int8_t* y_data = y.data(); - using blaslt = phi::funcs::MatmulWithCublasLt; - - phi::funcs::MatmulPlanner matmul_planner( - x_dims, - y_dims, - trans_x, - trans_y, - phi::CppTypeToDataType::Type(), - funcs::MatmulFusedType::kMatmul, - /* bias_data */ nullptr, - /* reserve_data */ nullptr, - /* use_addto */ false, - /* no_exchange */ true); - - if (x_ndim == 1 && y_ndim == 1) { - const int M = x.numel(); - const int N = y.numel(); - PADDLE_ENFORCE_EQ( - M, - N, - phi::errors::InvalidArgument( - "X's numbers must be equal to Y's numbers," - "when X/Y's dims =1. But received X has [%d] elements," - "received Y has [%d] elements", - M, - N)); - if (!(M % 4 == 0)) { - return false; - } - - out->Resize(common::make_ddim({})); - ctx.template Alloc(out); - blaslt::Run(ctx, - y_data, - x_data, - ctx.template Alloc(out), - 1, - 1, - M, - false, - true, - &matmul_planner); - return true; - } - if (x_ndim == 1) { - const int N = x.numel(); - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. 
" - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - N, - y_ndim - 1, - y_dims[y_ndim - 1])); - if (!(N % 4 == 0)) { - return false; - } - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - N, - y_ndim - 2, - y_dims[y_ndim - 2])); - const int M = y.numel() / N; - if (!(M == 1 || M % 4 == 0)) { - return false; - } - } - std::vector out_dims(y_ndim - 1); - if (trans_y) { - std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); - } else { - std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); - out_dims.back() = y_dims.back(); - } - out->ResizeAndAllocate(common::make_ddim(out_dims)); - ctx.template Alloc(out); - if (trans_y) { - const int M = y.numel() / N; - blaslt::Run(ctx, - y_data, - x_data, - ctx.template Alloc(out), - M, - 1, - N, - false, - false, - &matmul_planner); - } else { - const int M = y_dims[y_ndim - 1]; - const int batch_size = y.numel() / (M * N); - if (batch_size == 1) { - blaslt::Run(ctx, - y_data, - x_data, - ctx.template Alloc(out), - M, - 1, - N, - true, - false, - &matmul_planner); - } else { - blaslt::RunWithBatch(ctx, - y_data, - x_data, - ctx.template Alloc(out), - M, - 1, - N, - true, - false, - batch_size, - M * N, - 0, - M, - &matmul_planner); - } - } - return true; - } - - if (y_ndim == 1) { - const int N = y.numel(); - if (trans_x) { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 2], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 2, - N, - x_ndim - 2, - x_dims[x_ndim - 2])); - const int M = x.numel() / N; - if (!((M == 1 || M % 4 == 0))) { - return false; - } - } else { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 1], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 1, - N, - x_ndim - 1, - x_dims[x_ndim - 1])); - if (N % 4 != 0) { - return false; - } - } - std::vector out_dims(x_ndim - 1); - if (trans_x) { - std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); - out_dims.back() = x_dims.back(); - } else { - std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); - } - out->ResizeAndAllocate(common::make_ddim(out_dims)); - ctx.template Alloc(out); - - if (trans_x) { - const int M = x_dims[x_ndim - 1]; - const int batch_size = x.numel() / (M * N); - if (batch_size == 1) { - blaslt::Run(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - 1, - N, - true, - false, - &matmul_planner); - } else { - blaslt::RunWithBatch(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - 1, - N, - true, - false, - batch_size, - M * N, - 0, - M, - &matmul_planner); - } - } else { - const int M = x.numel() / N; - blaslt::Run(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - 1, - N, - false, - false, - &matmul_planner); - } - return true; - } - - const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; - const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - K, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. 
" - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - K, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; - const int ndim = (std::max)(x_ndim, y_ndim); - std::vector x_broadcast_dims(ndim); - std::vector y_broadcast_dims(ndim); - std::vector out_broadcast_dims(ndim); - GetBroadcastFromDims(x_ndim - 2, - x_dims.data(), - y_ndim - 2, - y_dims.data(), - x_broadcast_dims.data(), - y_broadcast_dims.data(), - out_broadcast_dims.data()); - out_broadcast_dims[ndim - 2] = M; - out_broadcast_dims[ndim - 1] = N; - - out->ResizeAndAllocate(common::make_ddim(out_broadcast_dims)); - ctx.template Alloc(out); - - const int batch_dim = ndim - 2; - // broadcast message - const bool is_broadcast_dims = - !std::equal(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - y_broadcast_dims.cbegin()); - - const std::int64_t x_batch_size = - std::accumulate(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t y_batch_size = - std::accumulate(y_broadcast_dims.cbegin(), - y_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t out_batch_size = - std::accumulate(out_broadcast_dims.cbegin(), - out_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - if (out_batch_size == 0) return true; - - if (x_batch_size == 1 && M == 1 && trans_y) { - if (!(K % 4 == 0)) { - return false; - } - } else if (!trans_x && !trans_y) { - if (!(N % 4 == 0 || N == 1) || !(K % 4 == 0) || (M == 1 && N == 1)) { - return false; - } - } else if (!trans_x && trans_y) { - if (!(K % 4 == 0)) { - return false; - } - } else if (trans_x && !trans_y) { - if (!(M % 4 == 0 || M == 1) || !(N % 4 == 0 || N == 1)) { - return false; - } - } else { - if (!(M % 4 == 0 || M == 1) || !(K % 4 == 0)) { - return false; - } - } - if (x_batch_size == 1 && y_batch_size == 1) { - blaslt::Run(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - N, - K, - trans_x, - trans_y, - &matmul_planner); - } else if (x_batch_size == 1) { - if (M == 1 && trans_y) { - blaslt::Run(ctx, - y_data, - x_data, - ctx.template Alloc(out), - y_batch_size * N, - 1, - K, - false, - false, - &matmul_planner); - } else { - blaslt::RunWithBatch(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - 0, - K * N, - M * N, - &matmul_planner); - } - } else if (y_batch_size == 1) { - if (!trans_x) { - blaslt::Run(ctx, - x_data, - y_data, - ctx.template Alloc(out), - x_batch_size * M, - N, - K, - false, - trans_y, - &matmul_planner); - } else { - blaslt::RunWithBatch(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - N, - K, - true, - trans_y, - out_batch_size, - M * K, - 0, - M * N, - &matmul_planner); - } - } else if (!is_broadcast_dims) { - blaslt::RunWithBatch(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - M * K, - K * N, - M * N, - &matmul_planner); - } else { - // in the case, can't use stridedgemm - std::vector x_ptr(out_batch_size); - std::vector y_ptr(out_batch_size); - std::vector out_ptr(out_batch_size); - std::vector index(batch_dim, 0); - for (std::int64_t i = 0; i < out_batch_size; ++i) { - // using the index to get offset - const std::int64_t x_index = - GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); - const std::int64_t y_index = - GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); - - x_ptr[i] = x_data 
+ x_index * M * K; - y_ptr[i] = y_data + y_index * K * N; - out_ptr[i] = ctx.template Alloc(out) + i * M * N; - IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); - } - blaslt::RunWithBatch(ctx, - x_ptr.data(), - y_ptr.data(), - out_ptr.data(), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - &matmul_planner); - } - return true; -#else - return false; -#endif -} -#endif - -template -typename std::enable_if::value>::type -MatmulJudgeDtypeKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool transpose_x, - bool transpose_y) { - bool try_matmul_int8 = MatMulInt8Function( - ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); - if (try_matmul_int8) { - return; - } - auto x_tmp = phi::Cast(ctx, x, phi::DataType::FLOAT32); - auto y_tmp = phi::Cast(ctx, y, phi::DataType::FLOAT32); - DenseTensor out_tmp; - MatMulFunction( - ctx, x_tmp, y_tmp, x_dims, y_dims, &out_tmp, transpose_x, transpose_y); - if (x.dtype() == phi::DataType::INT8) { - phi::CastKernel(ctx, out_tmp, phi::DataType::INT32, out); - return; - } - phi::CastKernel(ctx, out_tmp, x.dtype(), out); -} - -template -typename std::enable_if::value>::type -MatmulJudgeDtypeKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool transpose_x, - bool transpose_y) { - MatMulFunction( - ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); -} - -template -void MatmulKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - bool transpose_x, - bool transpose_y, - DenseTensor* out) { - if (x.numel() == 0 || y.numel() == 0) { - // input shape [1, 1, 5, 0], [1, 1, 0, 5], result shape is [1, 1, 5, 5] - phi::Full( - ctx, phi::IntArray(common::vectorize(out->dims())), 0, out); - return; - } - PADDLE_ENFORCE_GE( - common::product(x.dims()), - 0, - common::errors::InvalidArgument( - "The dims of Input(X) should be greater than or equal to 0.")); - PADDLE_ENFORCE_GE( - common::product(y.dims()), - 0, - common::errors::InvalidArgument( - "The dims of Input(Y) should be greater than or equal to 0.")); - const std::vector x_dims = common::vectorize(x.dims()); - const std::vector y_dims = common::vectorize(y.dims()); - MatmulJudgeDtypeKernel( - ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); -} - -template -void MatmulWithFlattenKernelImpl(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - const DenseTensor x_matrix = - x.dims().size() > 2 ? phi::ReshapeToMatrix(x, x_num_col_dims) : x; - const DenseTensor y_matrix = - y.dims().size() > 2 ? 
phi::ReshapeToMatrix(y, y_num_col_dims) : y; - - dev_ctx.template Alloc(out); - auto z_dim = out->dims(); - if (z_dim.size() != 2) { - out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); - } - - auto blas = phi::funcs::GetBlas(dev_ctx); - - blas.MatMul(x_matrix, y_matrix, out); - if (z_dim.size() != 2) { - out->Resize(z_dim); - } -} - -#ifdef PADDLE_WITH_CUDA - -template -void MatmulWithFlattenKernelInt8Impl(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - PADDLE_ENFORCE_EQ( - x.dtype(), - DataType::INT8, - phi::errors::InvalidArgument( - "The type of input(x) used in int8 mul must be (%s) " - "does not match the " - "type of data (%s) currently contained in the container.", - phi::CppTypeToDataType::Type(), - x.dtype())); - PADDLE_ENFORCE_EQ( - y.dtype(), - DataType::INT8, - phi::errors::InvalidArgument( - "The type of input(y) used in int8 mul must be (%s) " - "does not match the " - "type of data (%s) currently contained in the container.", - phi::CppTypeToDataType::Type(), - y.dtype())); - - const DenseTensor x_matrix = - x.dims().size() > 2 ? phi::ReshapeToMatrix(x, x_num_col_dims) : x; - const DenseTensor y_matrix = - y.dims().size() > 2 ? phi::ReshapeToMatrix(y, y_num_col_dims) : y; - - PADDLE_ENFORCE_EQ( - x_matrix.dims()[1], - y_matrix.dims()[0], - phi::errors::InvalidArgument( - "X's numbers of columns must be equal to Y's numbers of rows." - "But received X has [%d] columns," - "received Y has [%d] rows", - x_matrix.dims()[1], - y_matrix.dims()[0])); - - PADDLE_ENFORCE_EQ((y_matrix.dims()[1] % 4 == 0 || y_matrix.dims()[1] == 1), - true, - phi::errors::InvalidArgument( - "The dimension size N used in int8 mul must be 1" - "or a multiple of 4 does not match the size (%d)" - "currently contained in the container.", - y_matrix.dims()[1])); - PADDLE_ENFORCE_EQ((x_matrix.dims()[1] % 4 == 0), - true, - phi::errors::InvalidArgument( - "The dimension size K used in int8 mul must be a" - "multiple of 4 does not match the size (%d) currently" - "contained in the container.", - x_matrix.dims()[1])); - - dev_ctx.template Alloc(out); - auto z_dim = out->dims(); - if (z_dim.size() != 2) { - out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); - } - -#if CUDA_VERSION >= 11060 && 0 - using blaslt = phi::funcs::MatmulWithCublasLt; - - const int8_t* x_data = x_matrix.data(); - const int8_t* y_data = y_matrix.data(); - - std::vector x_dims = {x_matrix.dims()[0], x_matrix.dims()[1]}; - std::vector y_dims = {y_matrix.dims()[0], y_matrix.dims()[1]}; - phi::funcs::MatmulPlanner matmul_planner( - x_dims, - y_dims, - false, - false, - phi::CppTypeToDataType::Type(), - funcs::MatmulFusedType::kMatmul, - /* bias_data */ nullptr, - /* reserve_data */ nullptr, - /* use_addto */ false, - /* no_exchange */ true); - - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(out), - x_matrix.dims()[0], - y_matrix.dims()[1], - x_matrix.dims()[1], - false, - false, - &matmul_planner); - - if (z_dim.size() != 2) { - out->Resize(z_dim); - } -#endif -} -#endif - -#ifdef PADDLE_WITH_CUDA -template -typename std::enable_if::value, - void>::type -DispatchMatmulWithFlattenInt8Kernel(const phi::GPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - MatmulWithFlattenKernelInt8Impl( - dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); -} -#endif - -template -typename std::enable_if::value, - void>::type 
-DispatchMatmulWithFlattenInt8Kernel(const phi::CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - PADDLE_THROW(phi::errors::Unimplemented( - "MatmulWithFlatten with CPU is NOT implemented " - "yet.")); -} - -template -typename std::enable_if::value, void>::type -DispatchMatmulFlattenKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - DispatchMatmulWithFlattenInt8Kernel( - dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); -} - -template -typename std::enable_if::value, void>::type -DispatchMatmulFlattenKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - MatmulWithFlattenKernelImpl( - dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); -} - -template -void MatmulWithFlattenKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - DispatchMatmulFlattenKernel( - dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); -} - -template -void LegacyMatmulKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - bool transpose_x, - bool transpose_y, - float alpha, - DenseTensor* out) { - MatmulKernel(ctx, x, y, transpose_x, transpose_y, out); - if (std::fabs(alpha - 1.f) > 1e-6f) { - ScaleKernel(ctx, *out, Scalar(alpha), Scalar(0), false, out); - } -} - -} // namespace phi diff --git a/backends/metax_gpu/kernels/impl/matmul_kernel_impl_maca.h b/backends/metax_gpu/kernels/impl/matmul_kernel_impl_maca.h deleted file mode 100644 index 9750abae5ca..00000000000 --- a/backends/metax_gpu/kernels/impl/matmul_kernel_impl_maca.h +++ /dev/null @@ -1,1696 +0,0 @@ -// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights -// Reserved. -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
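For the int8 path removed above (MatMulInt8Function and MatmulJudgeDtypeKernel), the strategy is: try the int8 fast path first, and if it declines because the %4 alignment rules on M/N/K are not met (or because the cublasLt branch is compiled out), fall back by widening the operands to float32, running the general matmul, and casting the result back to int32. A rough sketch of that try-then-fallback control flow, with illustrative names and plain loops standing in for the real kernels:

#include <cstdint>
#include <vector>

// Int8 "fast path": only taken when the inner dimension meets a %4
// alignment rule, mirroring the checks in the deleted header; the math
// itself is a plain loop here, just to keep the control flow runnable.
bool TryInt8Matmul(const std::vector<int8_t>& x, const std::vector<int8_t>& y,
                   int M, int K, int N, std::vector<int32_t>* out) {
  if (K % 4 != 0) return false;  // decline: caller must fall back
  out->assign(static_cast<size_t>(M) * N, 0);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        (*out)[m * N + n] += static_cast<int32_t>(x[m * K + k]) *
                             static_cast<int32_t>(y[k * N + n]);
  return true;
}

// Fallback mirroring MatmulJudgeDtypeKernel: widen to float32, run the
// general matmul, then cast the result back to int32.
void Int8MatmulWithFallback(const std::vector<int8_t>& x,
                            const std::vector<int8_t>& y, int M, int K, int N,
                            std::vector<int32_t>* out) {
  if (TryInt8Matmul(x, y, M, K, N, out)) return;
  std::vector<float> acc(static_cast<size_t>(M) * N, 0.f);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        acc[m * N + n] += static_cast<float>(x[m * K + k]) *
                          static_cast<float>(y[k * N + n]);
  out->assign(acc.begin(), acc.end());  // narrow back, as the kernel does
}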
*/ - -#pragma once -// clang-format off -#include "glog/logging.h" - -#include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/autotune/cache_base.h" -#include "paddle/phi/kernels/cast_kernel.h" -#include "../funcs/blas/blas.h" -#ifdef PADDLE_WITH_HIP -#include "paddle/phi/kernels/funcs/blas/blaslt_impl.hip.h" -#else -#include "../funcs/blas/blaslt_impl.cu.h" -#endif -#include "paddle/phi/kernels/funcs/complex_functors.h" -#include "paddle/phi/kernels/scale_kernel.h" -#if defined(PADDLE_WITH_CUDA) -#include "paddle/phi/kernels/funcs/cublaslt.h" -#include "paddle/phi/kernels/gpu/cuda_gemm_kernel.h" -#include "paddle/phi/kernels/transpose_kernel.h" -#elif defined(PADDLE_WITH_HIP) -#include "paddle/phi/kernels/funcs/hipblaslt.h" -#endif -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 && 0 -#include "paddle/phi/kernels/autotune/auto_tune_base.h" -#endif -// clang-format on -namespace phi { - -static void GetBroadcastFromDims(const int x_ndim, - const std::int64_t* x_dims, - const int y_ndim, - const std::int64_t* y_dims, - std::int64_t* x_bd_dims, - std::int64_t* y_bd_dims, - std::int64_t* out_bd_dims) { - const int ndim = (std::max)(x_ndim, y_ndim); - std::fill(x_bd_dims, x_bd_dims + ndim - x_ndim, 1); - std::fill(y_bd_dims, y_bd_dims + ndim - y_ndim, 1); - std::copy(x_dims, x_dims + x_ndim, x_bd_dims + ndim - x_ndim); - std::copy(y_dims, y_dims + y_ndim, y_bd_dims + ndim - y_ndim); - - for (int i = 0; i < ndim; ++i) { - PADDLE_ENFORCE_EQ( - x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] <= 1 || y_bd_dims[i] <= 1, - true, - phi::errors::InvalidArgument( - "Input(X) and Input(Y) has error dim. " - "X_broadcast's shape[%s] must be equal to Y_broadcast's shape[%s], " - "or X_broadcast's shape[%s] <= 1, or Y_broadcast's shape[%s] <= 1, " - "but received X_broadcast's shape[%s] = [%s]" - "received Y_broadcast's shape[%s] = [%s].", - i, - i, - i, - i, - i, - x_bd_dims[i], - i, - y_bd_dims[i])); - if (x_bd_dims[i] == 0 || y_bd_dims[i] == 0) { - out_bd_dims[i] = 0; - } else { - out_bd_dims[i] = (std::max)(x_bd_dims[i], y_bd_dims[i]); - } - } -} - -static int64_t GetIndexMessage(const int n, - const int64_t* dims, - const int64_t* index) { - int64_t sum = 0; - for (int i = 0; i < n; ++i) { - if (dims[i] > 1) { - sum = sum * dims[i] + index[i]; - } - } - return sum; -} - -static void IndexIncreaseFromDims(const int ndim, - const int64_t* dims, - int64_t* index) { - for (int i = ndim - 1; i >= 0; --i) { - ++index[i]; - if (index[i] >= dims[i]) { - index[i] -= dims[i]; - } else { - break; - } - } -} - -// The general implementation with blas. -template -void MatMulFunctionImplWithBlas( - const Context& dev_ctx, - const DenseTensor& X, - const DenseTensor& Y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* Out, - bool trans_x, - bool trans_y, - bool flag = false, - phi::funcs::MatmulPlanner* matmul_planner UNUSED = nullptr) { - const int x_ndim = x_dims.size(); - const int y_ndim = y_dims.size(); - - // Get data ptr - const T* x_data = X.data(); - const T* y_data = Y.data(); - - auto blas = phi::funcs::GetBlas(dev_ctx); - - if (x_ndim == 1 && y_ndim == 1) { - const int M = X.numel(); - const int N = Y.numel(); - PADDLE_ENFORCE_EQ( - M, - N, - phi::errors::InvalidArgument( - "X's numbers must be equal to Y's numbers, " - "when X/Y's dims =1. 
But received X has [%d] elements, " - "received Y has [%d] elements.", - M, - N)); - VLOG(3) << "MatMul's case 1"; - Out->Resize(common::make_ddim({})); - dev_ctx.template Alloc(Out); - blas.GEMM(CblasNoTrans, - CblasTrans, - 1, - 1, - M, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - return; - } - - if (x_ndim == 1) { - const int N = X.numel(); - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - N, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - N, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - std::vector out_dims(y_ndim - 1); - if (trans_y) { - std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); - } else { - std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); - out_dims.back() = y_dims.back(); - } - Out->ResizeAndAllocate(common::make_ddim(out_dims)); - dev_ctx.template Alloc(Out); - if (trans_y) { - const int M = Y.numel() / N; - VLOG(3) << "MatMul's case 2"; - blas.GEMV(false, - M, - N, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - const int M = y_dims[y_ndim - 1]; - const int batch_size = Y.numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul's case 3"; - blas.GEMV(true, - N, - M, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - VLOG(3) << "MatMul's case 4"; - blas.BatchedGEMM(CblasTrans, - CblasNoTrans, - M, - 1, - N, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - batch_size, - M * N, - 0); - } - } - return; - } - - if (y_ndim == 1) { - const int N = Y.numel(); - if (trans_x) { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 2], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 2, - N, - x_ndim - 2, - x_dims[x_ndim - 2])); - } else { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 1], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 1, - N, - x_ndim - 1, - x_dims[x_ndim - 1])); - } - std::vector out_dims(x_ndim - 1); - if (trans_x) { - std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); - out_dims.back() = x_dims.back(); - } else { - std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); - } - Out->ResizeAndAllocate(common::make_ddim(out_dims)); - dev_ctx.template Alloc(Out); - - if (trans_x) { - const int M = x_dims[x_ndim - 1]; - const int batch_size = X.numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul's case 5"; - blas.GEMV(true, - N, - M, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - VLOG(3) << "MatMul's case 6"; - blas.BatchedGEMM(CblasTrans, - CblasNoTrans, - M, - 1, - N, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - batch_size, - M * N, - 0); - } - } else { - const int M = X.numel() / N; - VLOG(3) << "MatMul's case 7"; - blas.GEMV(false, - M, - N, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } - return; - } - - const int M = trans_x ? 
x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; - const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - K, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - K, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; - const int ndim = (std::max)(x_ndim, y_ndim); - std::vector x_broadcast_dims(ndim); - std::vector y_broadcast_dims(ndim); - std::vector out_broadcast_dims(ndim); - - GetBroadcastFromDims(x_ndim - 2, - x_dims.data(), - y_ndim - 2, - y_dims.data(), - x_broadcast_dims.data(), - y_broadcast_dims.data(), - out_broadcast_dims.data()); - out_broadcast_dims[ndim - 2] = M; - out_broadcast_dims[ndim - 1] = N; - - Out->ResizeAndAllocate(common::make_ddim(out_broadcast_dims)); - dev_ctx.template Alloc(Out); - - const int batch_dim = ndim - 2; - // broadcast message - const bool is_broadcast_dims = - !std::equal(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - y_broadcast_dims.cbegin()); - - const std::int64_t x_batch_size = - std::accumulate(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t y_batch_size = - std::accumulate(y_broadcast_dims.cbegin(), - y_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t out_batch_size = - std::accumulate(out_broadcast_dims.cbegin(), - out_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - if (out_batch_size == 0) return; - if (x_batch_size == 1 && y_batch_size == 1) { - VLOG(3) << "MatMul's case 8"; - blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else if (x_batch_size == 1) { - if (M == 1 && trans_y) { - VLOG(3) << "MatMul's case 9"; - blas.GEMV(false, - y_batch_size * N, - K, - static_cast(1), - y_data, - x_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - VLOG(3) << "MatMul's case 10"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - 0, - K * N); - } - } else if (y_batch_size == 1) { - if (!trans_x) { - VLOG(3) << "MatMul's case 11"; - blas.GEMM(CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - x_batch_size * M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out)); - } else { - VLOG(3) << "MatMul's case 12"; - blas.BatchedGEMM(CblasTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - M * K, - 0); - } - } else if (!is_broadcast_dims) { - VLOG(3) << "MatMul's case 13"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? 
CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_data, - y_data, - static_cast(flag), - dev_ctx.template Alloc(Out), - out_batch_size, - M * K, - K * N); - } else { - // in the case, can't use stridedgemm - std::vector x_ptr(out_batch_size); - std::vector y_ptr(out_batch_size); - std::vector out_ptr(out_batch_size); - std::vector index(batch_dim, 0); - for (std::int64_t i = 0; i < out_batch_size; ++i) { - // using the index to get offset - const std::int64_t x_index = - GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); - const std::int64_t y_index = - GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); - - x_ptr[i] = x_data + x_index * M * K; - y_ptr[i] = y_data + y_index * K * N; - out_ptr[i] = dev_ctx.template Alloc(Out) + i * M * N; - IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); - } - VLOG(3) << "MatMul's case 14"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, - M, - N, - K, - static_cast(1), - x_ptr.data(), - y_ptr.data(), - static_cast(flag), - out_ptr.data(), - out_batch_size); - } -} - -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 && 0 -// This is almost a copy from MatMulFunctionImplWithBlas, -// compare cublas with cublasLt kernels when Matmul autotune is on -template -void MatMulFunctionImplWithCublasLt( - const Context& dev_ctx, - const DenseTensor& X, - const DenseTensor& Y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* Out, - bool trans_x, - bool trans_y, - bool flag = false, - phi::funcs::MatmulPlanner* matmul_planner = nullptr) { - const int x_ndim = x_dims.size(); - const int y_ndim = y_dims.size(); - const T* x_data = X.data(); - const T* y_data = Y.data(); - using blaslt = phi::funcs::MatmulWithCublasLt; - - if (x_ndim == 1 && y_ndim == 1) { - const int M = X.numel(); - const int N = Y.numel(); - PADDLE_ENFORCE_EQ( - M, - N, - phi::errors::InvalidArgument( - "X's numbers must be equal to Y's numbers," - "when X/Y's dims =1. But received X has [%d] elements," - "received Y has [%d] elements", - M, - N)); - - // MatMul's case 0 => vector * vector - Out->Resize(common::make_ddim({})); - dev_ctx.template Alloc(Out); - VLOG(3) << "MatMul with blaslt case 1"; - blaslt::Run(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - 1, - 1, - M, - false, - true, - matmul_planner); - return; - } - - if (x_ndim == 1) { - const int N = X.numel(); - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - N, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. 
" - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - N, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - std::vector out_dims(y_ndim - 1); - if (trans_y) { - std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); - } else { - std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); - out_dims.back() = y_dims.back(); - } - Out->ResizeAndAllocate(common::make_ddim(out_dims)); - dev_ctx.template Alloc(Out); - if (trans_y) { - const int M = Y.numel() / N; - VLOG(3) << "MatMul with blaslt 2"; - blaslt::Run(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - false, - false, - matmul_planner); - } else { - const int M = y_dims[y_ndim - 1]; - const int batch_size = Y.numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul with blaslt 3"; - blaslt::Run(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - true, - false, - matmul_planner); - } else { - VLOG(3) << "MatMul with blaslt 4"; - blaslt::RunWithBatch(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - true, - false, - batch_size, - M * N, - 0, - M, - matmul_planner); - } - } - return; - } - - if (y_ndim == 1) { - const int N = Y.numel(); - if (trans_x) { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 2], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 2, - N, - x_ndim - 2, - x_dims[x_ndim - 2])); - } else { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 1], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 1, - N, - x_ndim - 1, - x_dims[x_ndim - 1])); - } - std::vector out_dims(x_ndim - 1); - if (trans_x) { - std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); - out_dims.back() = x_dims.back(); - } else { - std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); - } - Out->ResizeAndAllocate(common::make_ddim(out_dims)); - dev_ctx.template Alloc(Out); - - if (trans_x) { - const int M = x_dims[x_ndim - 1]; - const int batch_size = X.numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul with blaslt 5"; - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - true, - false, - matmul_planner); - } else { - VLOG(3) << "MatMul with blaslt 6"; - blaslt::RunWithBatch(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - true, - false, - batch_size, - M * N, - 0, - M, - matmul_planner); - } - } else { - const int M = X.numel() / N; - VLOG(3) << "MatMul with blaslt 7"; - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - 1, - N, - false, - false, - matmul_planner); - } - return; - } - - const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; - const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - K, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - K, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - const int N = trans_y ? 
y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; - const int ndim = (std::max)(x_ndim, y_ndim); - std::vector x_broadcast_dims(ndim); - std::vector y_broadcast_dims(ndim); - std::vector out_broadcast_dims(ndim); - - GetBroadcastFromDims(x_ndim - 2, - x_dims.data(), - y_ndim - 2, - y_dims.data(), - x_broadcast_dims.data(), - y_broadcast_dims.data(), - out_broadcast_dims.data()); - out_broadcast_dims[ndim - 2] = M; - out_broadcast_dims[ndim - 1] = N; - - Out->ResizeAndAllocate(common::make_ddim(out_broadcast_dims)); - dev_ctx.template Alloc(Out); - - const int batch_dim = ndim - 2; - // broadcast message - const bool is_broadcast_dims = - !std::equal(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - y_broadcast_dims.cbegin()); - - const std::int64_t x_batch_size = - std::accumulate(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t y_batch_size = - std::accumulate(y_broadcast_dims.cbegin(), - y_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t out_batch_size = - std::accumulate(out_broadcast_dims.cbegin(), - out_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - if (out_batch_size == 0) return; - if (x_batch_size == 1 && y_batch_size == 1) { - VLOG(3) << "MatMul with blaslt 8"; - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - N, - K, - trans_x, - trans_y, - matmul_planner); - } else if (x_batch_size == 1) { - if (M == 1 && trans_y) { - VLOG(3) << "MatMul with blaslt 9"; - blaslt::Run(dev_ctx, - y_data, - x_data, - dev_ctx.template Alloc(Out), - y_batch_size * N, - 1, - K, - false, - false, - matmul_planner); - } else { - VLOG(3) << "MatMul with blaslt 10"; - blaslt::RunWithBatch(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - 0, - K * N, - M * N, - matmul_planner); - } - } else if (y_batch_size == 1) { - if (!trans_x) { - VLOG(3) << "MatMul with blaslt 11"; - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - x_batch_size * M, - N, - K, - false, - trans_y, - matmul_planner); - } else { - VLOG(3) << "MatMul with blaslt 12"; - blaslt::RunWithBatch(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - N, - K, - true, - trans_y, - out_batch_size, - M * K, - 0, - M * N, - matmul_planner); - } - } else if (!is_broadcast_dims) { - VLOG(3) << "MatMul with blaslt 13"; - blaslt::RunWithBatch(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(Out), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - M * K, - K * N, - M * N, - matmul_planner); - } else { - // in the case, can't use stridedgemm - std::vector x_ptr(out_batch_size); - std::vector y_ptr(out_batch_size); - std::vector out_ptr(out_batch_size); - std::vector index(batch_dim, 0); - for (std::int64_t i = 0; i < out_batch_size; ++i) { - // using the index to get offset - const std::int64_t x_index = - GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); - const std::int64_t y_index = - GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); - - x_ptr[i] = x_data + x_index * M * K; - y_ptr[i] = y_data + y_index * K * N; - out_ptr[i] = dev_ctx.template Alloc(Out) + i * M * N; - IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); - } - VLOG(3) << "MatMul with blaslt 14"; - blaslt::RunWithBatch(dev_ctx, - x_ptr.data(), - y_ptr.data(), - out_ptr.data(), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - matmul_planner); - } 
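The MatMulDispatcher specialization for phi::GPUContext (which follows, and also appears in the first deleted header) wires both implementations into an autotune tuner: MakeMatmulTuner registers the cuBLAS path, AddCallBack adds the cuBLASLt path, and tuner->Run selects per problem key taken from MatmulPlanner::GetKey(). Note that the guard is written as CUDA_VERSION >= 11060 && 0, so the tuner branch is currently compiled out and the blas path is always used. A toy sketch of that select-and-cache pattern, with illustrative names (this is not the phi::autotune API):

#include <chrono>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

// Time every registered callback once per problem key, then reuse the
// fastest one for that key. The key would come from something like
// MatmulPlanner::GetKey() (shapes, transposes, dtype); here it is an int.
class ToyMatmulTuner {
 public:
  using Fn = std::function<void()>;

  void AddCallBack(Fn fn) { callbacks_.push_back(std::move(fn)); }

  void Run(int64_t key) {
    auto it = best_.find(key);
    if (it == best_.end()) {
      size_t winner = 0;
      double best_ms = 1e30;
      for (size_t i = 0; i < callbacks_.size(); ++i) {
        auto t0 = std::chrono::steady_clock::now();
        callbacks_[i]();  // benchmark this candidate (work is done here)
        auto t1 = std::chrono::steady_clock::now();
        double ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
        if (ms < best_ms) { best_ms = ms; winner = i; }
      }
      best_.emplace(key, winner);
      return;  // result already produced while timing
    }
    callbacks_[it->second]();  // cached winner, no re-timing
  }

 private:
  std::vector<Fn> callbacks_;
  std::unordered_map<int64_t, size_t> best_;
};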
-} -#endif - -template -struct MatMulDispatcher { - void operator()(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y, - bool flag = false) { - MatMulFunctionImplWithBlas( - ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag); - } -}; - -#ifdef PADDLE_WITH_CUDA -template -struct MatMulDispatcher { - void operator()(const phi::GPUContext& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y, - bool flag = false) { -#if CUDA_VERSION >= 11060 && 0 - auto* tuner = phi::autotune::MakeMatmulTuner( - MatMulFunctionImplWithBlas); - tuner->AddCallBack(MatMulFunctionImplWithCublasLt); - phi::funcs::MatmulPlanner matmul_planner(x_dims, - y_dims, - trans_x, - trans_y, - phi::CppTypeToDataType::Type(), - funcs::MatmulFusedType::kMatmul, - /* bias_data */ nullptr, - /* reserve_data */ nullptr, - /* use_addto */ flag, - /* no_exchange */ true); - tuner->Run(ctx, - matmul_planner.GetKey(), - ctx, - x, - y, - x_dims, - y_dims, - out, - trans_x, - trans_y, - flag, - &matmul_planner); -#else - MatMulFunctionImplWithBlas( - ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag); -#endif - } -}; - -#endif // PADDLE_WITH_CUDA - -template -void MatMulFunction(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y, - bool flag = false) { - MatMulDispatcher()( - ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag); -} - -template -bool MatMulInt8Function(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y) { - return false; -} - -#ifdef PADDLE_WITH_CUDA -template <> -bool inline MatMulInt8Function(const phi::GPUContext& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool trans_x, - bool trans_y) { - if (x.dtype() != DataType::INT8 || y.dtype() != DataType::INT8) { - return false; - } -#if CUDA_VERSION >= 11060 && 0 - const int x_ndim = x_dims.size(); - const int y_ndim = y_dims.size(); - const int8_t* x_data = x.data(); - const int8_t* y_data = y.data(); - using blaslt = phi::funcs::MatmulWithCublasLt; - - phi::funcs::MatmulPlanner matmul_planner( - x_dims, - y_dims, - trans_x, - trans_y, - phi::CppTypeToDataType::Type(), - funcs::MatmulFusedType::kMatmul, - /* bias_data */ nullptr, - /* reserve_data */ nullptr, - /* use_addto */ false, - /* no_exchange */ true); - - if (x_ndim == 1 && y_ndim == 1) { - const int M = x.numel(); - const int N = y.numel(); - PADDLE_ENFORCE_EQ( - M, - N, - phi::errors::InvalidArgument( - "X's numbers must be equal to Y's numbers," - "when X/Y's dims =1. But received X has [%d] elements," - "received Y has [%d] elements", - M, - N)); - if (!(M % 4 == 0)) { - return false; - } - - out->Resize(common::make_ddim({})); - ctx.template Alloc(out); - blaslt::Run(ctx, - y_data, - x_data, - ctx.template Alloc(out), - 1, - 1, - M, - false, - true, - &matmul_planner); - return true; - } - if (x_ndim == 1) { - const int N = x.numel(); - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. 
" - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - N, - y_ndim - 1, - y_dims[y_ndim - 1])); - if (!(N % 4 == 0)) { - return false; - } - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - N, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - N, - y_ndim - 2, - y_dims[y_ndim - 2])); - const int M = y.numel() / N; - if (!(M == 1 || M % 4 == 0)) { - return false; - } - } - std::vector out_dims(y_ndim - 1); - if (trans_y) { - std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); - } else { - std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); - out_dims.back() = y_dims.back(); - } - out->ResizeAndAllocate(common::make_ddim(out_dims)); - ctx.template Alloc(out); - if (trans_y) { - const int M = y.numel() / N; - blaslt::Run(ctx, - y_data, - x_data, - ctx.template Alloc(out), - M, - 1, - N, - false, - false, - &matmul_planner); - } else { - const int M = y_dims[y_ndim - 1]; - const int batch_size = y.numel() / (M * N); - if (batch_size == 1) { - blaslt::Run(ctx, - y_data, - x_data, - ctx.template Alloc(out), - M, - 1, - N, - true, - false, - &matmul_planner); - } else { - blaslt::RunWithBatch(ctx, - y_data, - x_data, - ctx.template Alloc(out), - M, - 1, - N, - true, - false, - batch_size, - M * N, - 0, - M, - &matmul_planner); - } - } - return true; - } - - if (y_ndim == 1) { - const int N = y.numel(); - if (trans_x) { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 2], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 2, - N, - x_ndim - 2, - x_dims[x_ndim - 2])); - const int M = x.numel() / N; - if (!((M == 1 || M % 4 == 0))) { - return false; - } - } else { - PADDLE_ENFORCE_EQ( - x_dims[x_ndim - 1], - N, - phi::errors::InvalidArgument("Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 1, - N, - x_ndim - 1, - x_dims[x_ndim - 1])); - if (N % 4 != 0) { - return false; - } - } - std::vector out_dims(x_ndim - 1); - if (trans_x) { - std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); - out_dims.back() = x_dims.back(); - } else { - std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); - } - out->ResizeAndAllocate(common::make_ddim(out_dims)); - ctx.template Alloc(out); - - if (trans_x) { - const int M = x_dims[x_ndim - 1]; - const int batch_size = x.numel() / (M * N); - if (batch_size == 1) { - blaslt::Run(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - 1, - N, - true, - false, - &matmul_planner); - } else { - blaslt::RunWithBatch(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - 1, - N, - true, - false, - batch_size, - M * N, - 0, - M, - &matmul_planner); - } - } else { - const int M = x.numel() / N; - blaslt::Run(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - 1, - N, - false, - false, - &matmul_planner); - } - return true; - } - - const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; - const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; - if (trans_y) { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 1], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. " - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 1, - K, - y_ndim - 1, - y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ( - y_dims[y_ndim - 2], - K, - phi::errors::InvalidArgument("Input(Y) has error dim. 
" - "Y'dims[%d] must be equal to %d, " - "but received Y'dims[%d] is %d.", - y_ndim - 2, - K, - y_ndim - 2, - y_dims[y_ndim - 2])); - } - const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; - const int ndim = (std::max)(x_ndim, y_ndim); - std::vector x_broadcast_dims(ndim); - std::vector y_broadcast_dims(ndim); - std::vector out_broadcast_dims(ndim); - GetBroadcastFromDims(x_ndim - 2, - x_dims.data(), - y_ndim - 2, - y_dims.data(), - x_broadcast_dims.data(), - y_broadcast_dims.data(), - out_broadcast_dims.data()); - out_broadcast_dims[ndim - 2] = M; - out_broadcast_dims[ndim - 1] = N; - - out->ResizeAndAllocate(common::make_ddim(out_broadcast_dims)); - ctx.template Alloc(out); - - const int batch_dim = ndim - 2; - // broadcast message - const bool is_broadcast_dims = - !std::equal(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - y_broadcast_dims.cbegin()); - - const std::int64_t x_batch_size = - std::accumulate(x_broadcast_dims.cbegin(), - x_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t y_batch_size = - std::accumulate(y_broadcast_dims.cbegin(), - y_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - const std::int64_t out_batch_size = - std::accumulate(out_broadcast_dims.cbegin(), - out_broadcast_dims.cbegin() + batch_dim, - 1LL, - std::multiplies()); - if (out_batch_size == 0) return true; - - if (x_batch_size == 1 && M == 1 && trans_y) { - if (!(K % 4 == 0)) { - return false; - } - } else if (!trans_x && !trans_y) { - if (!(N % 4 == 0 || N == 1) || !(K % 4 == 0) || (M == 1 && N == 1)) { - return false; - } - } else if (!trans_x && trans_y) { - if (!(K % 4 == 0)) { - return false; - } - } else if (trans_x && !trans_y) { - if (!(M % 4 == 0 || M == 1) || !(N % 4 == 0 || N == 1)) { - return false; - } - } else { - if (!(M % 4 == 0 || M == 1) || !(K % 4 == 0)) { - return false; - } - } - if (x_batch_size == 1 && y_batch_size == 1) { - blaslt::Run(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - N, - K, - trans_x, - trans_y, - &matmul_planner); - } else if (x_batch_size == 1) { - if (M == 1 && trans_y) { - blaslt::Run(ctx, - y_data, - x_data, - ctx.template Alloc(out), - y_batch_size * N, - 1, - K, - false, - false, - &matmul_planner); - } else { - blaslt::RunWithBatch(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - 0, - K * N, - M * N, - &matmul_planner); - } - } else if (y_batch_size == 1) { - if (!trans_x) { - blaslt::Run(ctx, - x_data, - y_data, - ctx.template Alloc(out), - x_batch_size * M, - N, - K, - false, - trans_y, - &matmul_planner); - } else { - blaslt::RunWithBatch(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - N, - K, - true, - trans_y, - out_batch_size, - M * K, - 0, - M * N, - &matmul_planner); - } - } else if (!is_broadcast_dims) { - blaslt::RunWithBatch(ctx, - x_data, - y_data, - ctx.template Alloc(out), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - M * K, - K * N, - M * N, - &matmul_planner); - } else { - // in the case, can't use stridedgemm - std::vector x_ptr(out_batch_size); - std::vector y_ptr(out_batch_size); - std::vector out_ptr(out_batch_size); - std::vector index(batch_dim, 0); - for (std::int64_t i = 0; i < out_batch_size; ++i) { - // using the index to get offset - const std::int64_t x_index = - GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); - const std::int64_t y_index = - GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); - - x_ptr[i] = x_data 
+ x_index * M * K; - y_ptr[i] = y_data + y_index * K * N; - out_ptr[i] = ctx.template Alloc(out) + i * M * N; - IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); - } - blaslt::RunWithBatch(ctx, - x_ptr.data(), - y_ptr.data(), - out_ptr.data(), - M, - N, - K, - trans_x, - trans_y, - out_batch_size, - &matmul_planner); - } - return true; -#else - return false; -#endif -} -#endif - -template -typename std::enable_if::value>::type -MatmulJudgeDtypeKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool transpose_x, - bool transpose_y) { - bool try_matmul_int8 = MatMulInt8Function( - ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); - if (try_matmul_int8) { - return; - } - auto x_tmp = phi::Cast(ctx, x, phi::DataType::FLOAT32); - auto y_tmp = phi::Cast(ctx, y, phi::DataType::FLOAT32); - DenseTensor out_tmp; - MatMulFunction( - ctx, x_tmp, y_tmp, x_dims, y_dims, &out_tmp, transpose_x, transpose_y); - if (x.dtype() == phi::DataType::INT8) { - phi::CastKernel(ctx, out_tmp, phi::DataType::INT32, out); - return; - } - phi::CastKernel(ctx, out_tmp, x.dtype(), out); -} - -template -typename std::enable_if::value>::type -MatmulJudgeDtypeKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const std::vector& x_dims, - const std::vector& y_dims, - DenseTensor* out, - bool transpose_x, - bool transpose_y) { - MatMulFunction( - ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); -} - -template -void MatmulKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - bool transpose_x, - bool transpose_y, - DenseTensor* out) { - PADDLE_ENFORCE_NE( - common::product(x.dims()), - 0, - phi::errors::InvalidArgument("The Input(X) dims size must not be equal 0," - " but reviced dims size is 0. ")); - PADDLE_ENFORCE_NE( - common::product(y.dims()), - 0, - phi::errors::InvalidArgument("The Input(Y) dims size must not be equal 0," - " but reviced dims size is 0. ")); - const std::vector x_dims = common::vectorize(x.dims()); - const std::vector y_dims = common::vectorize(y.dims()); - MatmulJudgeDtypeKernel( - ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); -} - -template -void MatmulWithFlattenKernelImpl(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - const DenseTensor x_matrix = - x.dims().size() > 2 ? phi::ReshapeToMatrix(x, x_num_col_dims) : x; - const DenseTensor y_matrix = - y.dims().size() > 2 ? 
phi::ReshapeToMatrix(y, y_num_col_dims) : y; - - dev_ctx.template Alloc(out); - auto z_dim = out->dims(); - if (z_dim.size() != 2) { - out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); - } - - auto blas = phi::funcs::GetBlas(dev_ctx); - - blas.MatMul(x_matrix, y_matrix, out); - if (z_dim.size() != 2) { - out->Resize(z_dim); - } -} - -#ifdef PADDLE_WITH_CUDA - -template -void MatmulWithFlattenKernelInt8Impl(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - PADDLE_ENFORCE_EQ( - x.dtype(), - DataType::INT8, - phi::errors::InvalidArgument( - "The type of input(x) used in int8 mul must be (%s) " - "does not match the " - "type of data (%s) currently contained in the container.", - phi::CppTypeToDataType::Type(), - x.dtype())); - PADDLE_ENFORCE_EQ( - y.dtype(), - DataType::INT8, - phi::errors::InvalidArgument( - "The type of input(y) used in int8 mul must be (%s) " - "does not match the " - "type of data (%s) currently contained in the container.", - phi::CppTypeToDataType::Type(), - y.dtype())); - - const DenseTensor x_matrix = - x.dims().size() > 2 ? phi::ReshapeToMatrix(x, x_num_col_dims) : x; - const DenseTensor y_matrix = - y.dims().size() > 2 ? phi::ReshapeToMatrix(y, y_num_col_dims) : y; - - PADDLE_ENFORCE_EQ( - x_matrix.dims()[1], - y_matrix.dims()[0], - phi::errors::InvalidArgument( - "X's numbers of columns must be equal to Y's numbers of rows." - "But received X has [%d] columns," - "received Y has [%d] rows", - x_matrix.dims()[1], - y_matrix.dims()[0])); - - PADDLE_ENFORCE_EQ((y_matrix.dims()[1] % 4 == 0 || y_matrix.dims()[1] == 1), - true, - phi::errors::InvalidArgument( - "The dimension size N used in int8 mul must be 1" - "or a multiple of 4 does not match the size (%d)" - "currently contained in the container.", - y_matrix.dims()[1])); - PADDLE_ENFORCE_EQ((x_matrix.dims()[1] % 4 == 0), - true, - phi::errors::InvalidArgument( - "The dimension size K used in int8 mul must be a" - "multiple of 4 does not match the size (%d) currently" - "contained in the container.", - x_matrix.dims()[1])); - - dev_ctx.template Alloc(out); - auto z_dim = out->dims(); - if (z_dim.size() != 2) { - out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); - } - -#if CUDA_VERSION >= 11060 && 0 - using blaslt = phi::funcs::MatmulWithCublasLt; - - const int8_t* x_data = x_matrix.data(); - const int8_t* y_data = y_matrix.data(); - - std::vector x_dims = {x_matrix.dims()[0], x_matrix.dims()[1]}; - std::vector y_dims = {y_matrix.dims()[0], y_matrix.dims()[1]}; - phi::funcs::MatmulPlanner matmul_planner( - x_dims, - y_dims, - false, - false, - phi::CppTypeToDataType::Type(), - funcs::MatmulFusedType::kMatmul, - /* bias_data */ nullptr, - /* reserve_data */ nullptr, - /* use_addto */ false, - /* no_exchange */ true); - - blaslt::Run(dev_ctx, - x_data, - y_data, - dev_ctx.template Alloc(out), - x_matrix.dims()[0], - y_matrix.dims()[1], - x_matrix.dims()[1], - false, - false, - &matmul_planner); - - if (z_dim.size() != 2) { - out->Resize(z_dim); - } -#endif -} -#endif - -#ifdef PADDLE_WITH_CUDA -template -typename std::enable_if::value, - void>::type -DispatchMatmulWithFlattenInt8Kernel(const phi::GPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - MatmulWithFlattenKernelInt8Impl( - dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); -} -#endif - -template -typename std::enable_if::value, - void>::type 
-DispatchMatmulWithFlattenInt8Kernel(const phi::CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - PADDLE_THROW(phi::errors::Unimplemented( - "MatmulWithFlatten with CPU is NOT implemented " - "yet.")); -} - -template -typename std::enable_if::value, void>::type -DispatchMatmulFlattenKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - DispatchMatmulWithFlattenInt8Kernel( - dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); -} - -template -typename std::enable_if::value, void>::type -DispatchMatmulFlattenKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - MatmulWithFlattenKernelImpl( - dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); -} - -template -void MatmulWithFlattenKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { - DispatchMatmulFlattenKernel( - dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); -} - -} // namespace phi diff --git a/backends/metax_gpu/kernels/impl/metax_weight_quantize_kernel_impl.h b/backends/metax_gpu/kernels/impl/metax_weight_quantize_kernel_impl.h new file mode 100644 index 00000000000..b305ec96a30 --- /dev/null +++ b/backends/metax_gpu/kernels/impl/metax_weight_quantize_kernel_impl.h @@ -0,0 +1,282 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
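// A minimal standalone sketch of the int4 <-> int8 nibble convention used by
// the helpers in this header: two signed 4-bit weights share one byte, the
// first in the low nibble and the second in the high nibble, stored as 4-bit
// two's complement. The function names here are illustrative only, not part
// of the header's API.
#include <cstdint>
#include <utility>

inline int8_t PackInt4Pair(int8_t lo, int8_t hi) {   // lo, hi in [-8, 7]
  return static_cast<int8_t>(((hi & 0x0F) << 4) | (lo & 0x0F));
}

inline std::pair<int8_t, int8_t> UnpackInt4Pair(int8_t packed) {
  int8_t lo = packed & 0x0F;
  int8_t hi = (packed >> 4) & 0x0F;
  if (lo >= 8) lo -= 16;  // sign-extend the 4-bit two's-complement value
  if (hi >= 8) hi -= 16;
  return {lo, hi};
}
// Example: PackInt4Pair(-3, 5) yields 0x5D; UnpackInt4Pair(0x5D) gives {-3, 5}.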
+ +#pragma once + +#include + +#include "paddle/common/enforce.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/phi/kernels/funcs/common_shape.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +void show_2d_cpu_tensor(const DenseTensor& tensor, + const int64_t row_num = 3, + const int64_t col_num = 3) { + const int64_t rows = tensor.dims()[0]; + const int64_t cols = tensor.dims()[1]; + printf("\nTensor shape = [%d, %d]\n", rows, cols); + + const int8_t* cpu_ptr = tensor.data(); + + for (int r = 0; r < row_num; r++) { + for (int c = 0; c < col_num; c++) { + int8_t val = *(cpu_ptr + r * cols + c); + printf("%d ", val); + } + printf("\n"); + } + printf("\n\n"); +} + +void show_2d_gpu_tensor(const CustomContext& dev_ctx, + const DenseTensor& tensor, + const int64_t row_num = 3, + const int64_t col_num = 3) { + phi::CPUPlace cpu_place; + + DenseTensor cpu_tensor; + phi::Copy(dev_ctx, tensor, cpu_place, true, &cpu_tensor); + + const int64_t rows = cpu_tensor.dims()[0]; + const int64_t cols = cpu_tensor.dims()[1]; + printf("\nTensor shape = [%d, %d]\n", rows, cols); + + const int8_t* cpu_ptr = cpu_tensor.data(); + + for (int r = 0; r < row_num; r++) { + for (int c = 0; c < col_num; c++) { + int8_t val = *(cpu_ptr + r * cols + c); + printf("%d ", val); + } + printf("\n"); + } + printf("\n\n"); +} + +void cpu_2d_tensor_transpose(const DenseTensor& input_data, + DenseTensor* transposed_data) { + const int64_t input_data_rows = input_data.dims()[0]; + const int64_t input_data_cols = input_data.dims()[1]; + + const int8_t* input_data_ptr = input_data.data(); + int8_t* transposed_data_ptr = transposed_data->data(); + + for (int64_t r = 0; r < input_data_rows; r++) { + for (int64_t c = 0; c < input_data_cols; c++) { + *(transposed_data_ptr + r + c * input_data_rows) = + *(input_data_ptr + r * input_data_cols + c); + } + } +} + +void cpu_int4_quanted_weight_raw_unpack(const DenseTensor& packed_data, + DenseTensor* unpacked_data) { + const int64_t packed_data_rows = packed_data.dims()[0]; + const int64_t packed_data_cols = packed_data.dims()[1]; + + const int8_t* packed_data_ptr = packed_data.data(); + int8_t* unpacked_data_ptr = unpacked_data->data(); + + for (int64_t c = 0; c < packed_data_cols; c++) { + for (int64_t r = 0; r < packed_data_rows; r++) { + int8_t val = *(packed_data_ptr + r * packed_data_cols + c); + int8_t low_int4 = val & 0x0f; + int8_t hight_int4 = (val >> 4) & 0x0f; + + *(unpacked_data_ptr + (2 * r) * packed_data_cols + c) = + low_int4 >= 8 ? low_int4 - 16 : low_int4; + *(unpacked_data_ptr + (2 * r + 1) * packed_data_cols + c) = + hight_int4 >= 8 ? hight_int4 - 16 : hight_int4; + } + } +} + +void cpu_int4_quanted_weight_col_pack(const DenseTensor& unpacked_data, + DenseTensor* packed_data) { + const int64_t packed_data_rows = packed_data->dims()[0]; + const int64_t packed_data_cols = packed_data->dims()[1]; + + int8_t* packed_data_ptr = packed_data->data(); + const int8_t* unpacked_data_ptr = unpacked_data.data(); + + for (int64_t r = 0; r < packed_data_rows; r++) { + for (int64_t c = 0; c < packed_data_cols; c++) { + int8_t low_int4 = *(unpacked_data_ptr + 2 * r * packed_data_cols + 2 * c); + int8_t hight_int4 = + *(unpacked_data_ptr + 2 * r * packed_data_cols + 2 * c + 1); + + low_int4 = low_int4 < 0 ? low_int4 + 16 : low_int4; + hight_int4 = hight_int4 < 0 ? 
hight_int4 + 16 : hight_int4; + + *(packed_data_ptr + r * packed_data_cols + c) = + ((hight_int4 << 4) & 0xf0) | (low_int4 & 0x0f); + } + } +} + +void cpu_int4_quantized_weight_layout_trans_impl( + const CustomContext& dev_ctx, + const std::vector& shape, + DenseTensor* out) { + const int64_t m = shape[0]; + const int64_t n = shape[1]; + + phi::CPUPlace cpu_place; + + out->Resize({m / 2, n}); + + DenseTensor out_cpu_tensor; + phi::Copy(dev_ctx, (*out), cpu_place, true, &out_cpu_tensor); + + // raw unpack + DenseTensor raw_unpack_tensor; + raw_unpack_tensor.Resize({out_cpu_tensor.dims()[0] * 2, n}); + raw_unpack_tensor.mutable_data(cpu_place); + cpu_int4_quanted_weight_raw_unpack(out_cpu_tensor, &raw_unpack_tensor); + + // transpose + DenseTensor transposed_tensor; + transposed_tensor.Resize( + {raw_unpack_tensor.dims()[1], raw_unpack_tensor.dims()[0]}); + transposed_tensor.mutable_data(cpu_place); + cpu_2d_tensor_transpose(raw_unpack_tensor, &transposed_tensor); + + // col pack + out_cpu_tensor.Resize( + {transposed_tensor.dims()[0], transposed_tensor.dims()[1] / 2}); + cpu_int4_quanted_weight_col_pack(transposed_tensor, &out_cpu_tensor); + + out_cpu_tensor.Resize({n / 2, m}); + out->Resize({n / 2, m}); + phi::Copy(dev_ctx, out_cpu_tensor, dev_ctx.GetPlace(), true, out); +} + +__global__ void int4_quanted_matrix_raw_unpack_kernel(const int8_t* mat, + int8_t* unpack_mat, + int M, + int N) { + int global_idx = blockIdx.x * blockDim.x + threadIdx.x; + + int i = global_idx / N; + int j = global_idx % N; + + if (global_idx >= M * N) { + return; + } + + int8_t val = mat[global_idx]; + int8_t low = val & 0x0F; + int8_t mask = ((low & 0x80) == 0) & ((low & 0x78) != 0); + low -= 16 * mask; + + int8_t high = (val >> 4) & 0x0F; + mask = ((high & 0x80) == 0) & ((high & 0x78) != 0); + high -= 16 * mask; + + int output_global_idx0 = (2 * i) * N + j; + int output_global_idx1 = (2 * i + 1) * N + j; + + unpack_mat[output_global_idx0] = low; + unpack_mat[output_global_idx1] = high; +} + +__global__ void int4_quanted_matrix_col_pack_kernel(const int8_t* mat, + int8_t* pack_mat, + int M, + int N) { + int global_idx = blockIdx.x * blockDim.x + threadIdx.x; + + int i = global_idx / N; + int j = global_idx % N; + + if (global_idx >= M * N) { + return; + } + + int mat_global_idx0 = i * 2 * N + 2 * j; + int mat_global_idx1 = i * 2 * N + 2 * j + 1; + + int8_t low = mat[mat_global_idx0] & 0x0F; + low = low + ((low >> 3) & 1) * 16; + + int8_t high = mat[mat_global_idx1] & 0x0F; + high = high + ((high >> 3) & 1) * 16; + + pack_mat[global_idx] = ((high << 4) & 0xf0) | (low & 0x0f); +} + +void gpu_int4_quantized_weight_layout_trans_impl( + const CustomContext& dev_ctx, + const std::vector& shape, + DenseTensor* out) { + int64_t total_m = shape[0]; + int64_t total_n = shape[1]; + out->Resize({total_m / 2, total_n}); + + DenseTensor unpack_mat(out->type()); + unpack_mat.Resize({total_m, total_n}); + dev_ctx.template Alloc(&unpack_mat); + + constexpr int kBlockSize = 64; + int64_t kGridSize = (out->numel() + kBlockSize - 1) / kBlockSize; + int4_quanted_matrix_raw_unpack_kernel<<>>( + out->data(), + unpack_mat.data(), + out->dims()[0], + out->dims()[1]); + + DenseTensor transposed_tensor; + transposed_tensor.Resize({unpack_mat.dims()[1], unpack_mat.dims()[0]}); + dev_ctx.template Alloc(&transposed_tensor); + std::vector axis = {1, 0}; + funcs::Transpose trans; + trans(dev_ctx, unpack_mat, &transposed_tensor, axis); + + out->Resize({transposed_tensor.dims()[0], transposed_tensor.dims()[1] / 2}); + 
int4_quanted_matrix_col_pack_kernel<<>>( + transposed_tensor.data(), + out->data(), + out->dims()[0], + out->dims()[1]); + + out->Resize({total_n / 2, total_m}); +} + +template +void MetaxQuantizedWeightLayoutTrans(const Context& dev_ctx, + const std::string& algo, + const std::vector& shape, + DenseTensor* out) { + if (algo == "weight_only_int4") { + if (dev_ctx.GetPlace() == phi::CPUPlace()) { + cpu_int4_quantized_weight_layout_trans_impl(dev_ctx, shape, out); + } else { + gpu_int4_quantized_weight_layout_trans_impl(dev_ctx, shape, out); + } + + } else { + PADDLE_FATAL( + "The algo must be in ['weight_only_int4'" + "], but got[%s]", + algo); + } +} + +} // namespace phi diff --git a/backends/metax_gpu/kernels/impl/multi_dot_kernel_impl.h b/backends/metax_gpu/kernels/impl/multi_dot_kernel_impl.h index aaa7fbd8d2c..7ba97234cc1 100644 --- a/backends/metax_gpu/kernels/impl/multi_dot_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/multi_dot_kernel_impl.h @@ -14,9 +14,9 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" namespace phi { template diff --git a/backends/metax_gpu/kernels/impl/mv_kernel_impl.h b/backends/metax_gpu/kernels/impl/mv_kernel_impl.h index a87d431e250..4baee25a099 100644 --- a/backends/metax_gpu/kernels/impl/mv_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/mv_kernel_impl.h @@ -14,7 +14,7 @@ #pragma once -#include "kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" namespace phi { diff --git a/backends/metax_gpu/kernels/impl/solve_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/solve_grad_kernel_impl.h index 860bce2cba5..1dd276dde2f 100644 --- a/backends/metax_gpu/kernels/impl/solve_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/solve_grad_kernel_impl.h @@ -14,11 +14,11 @@ limitations under the License. 
*/ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/expand_as_kernel.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/matrix_solve.h" #include "paddle/phi/kernels/funcs/reduce_function.h" diff --git a/backends/metax_gpu/kernels/impl/triangular_solve_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/triangular_solve_grad_kernel_impl.h index 08138853099..ad656b7a6c8 100644 --- a/backends/metax_gpu/kernels/impl/triangular_solve_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/triangular_solve_grad_kernel_impl.h @@ -14,10 +14,10 @@ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/for_range.h" diff --git a/backends/metax_gpu/kernels/metax_kernel/batch_fc_grad_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/batch_fc_grad_kernel_register.cu index 51f8f6792e2..c31d82920b3 100644 --- a/backends/metax_gpu/kernels/metax_kernel/batch_fc_grad_kernel_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/batch_fc_grad_kernel_register.cu @@ -14,10 +14,10 @@ #include -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { diff --git a/backends/metax_gpu/kernels/metax_kernel/block_attn.h b/backends/metax_gpu/kernels/metax_kernel/block_attn.h index 1e1eb2c0961..a5b88e34be1 100644 --- a/backends/metax_gpu/kernels/metax_kernel/block_attn.h +++ b/backends/metax_gpu/kernels/metax_kernel/block_attn.h @@ -14,11 +14,11 @@ #pragma once -#include "kernels/funcs/quant_dequant.h" #include "kernels/metax_kernel/mmha_util.cu.h" #include "paddle/common/flags.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/kernels/funcs/quant_dequant.h" COMMON_DECLARE_bool(use_xqa_optim); COMMON_DECLARE_bool(blha_use_fp32_qk_sum); diff --git a/backends/metax_gpu/kernels/metax_kernel/elementwise.h b/backends/metax_gpu/kernels/metax_kernel/elementwise.h index 52a7709424b..b9f3d8af1c9 100644 --- a/backends/metax_gpu/kernels/metax_kernel/elementwise.h +++ b/backends/metax_gpu/kernels/metax_kernel/elementwise.h @@ -14,9 +14,9 @@ limitations under the License. 
*/ #pragma once -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/eigen/common.h" diff --git a/backends/metax_gpu/kernels/metax_kernel/metax_context.h b/backends/metax_gpu/kernels/metax_kernel/metax_context.h index 7386811a236..18f1e30f191 100644 --- a/backends/metax_gpu/kernels/metax_kernel/metax_context.h +++ b/backends/metax_gpu/kernels/metax_kernel/metax_context.h @@ -17,9 +17,9 @@ #include #include -#include "kernels/funcs/blas/cublasLt.h" #include "paddle/phi/backends/context_pool.h" #include "paddle/phi/backends/custom/custom_context.h" +#include "paddle/phi/backends/dynload/cublasLt.h" #include "paddle/phi/backends/gpu/forwards.h" #include "paddle/phi/backends/gpu/gpu_decls.h" #include "paddle/phi/backends/gpu/gpu_helper.h" @@ -28,8 +28,6 @@ #include "paddle/phi/core/attribute.h" #include "paddle/phi/core/device_context.h" -cublasLtHandle_t GetBlasLtHandle(); - namespace phi { class DnnWorkspaceHandle { public: diff --git a/backends/metax_gpu/kernels/metax_kernel/mmha_util.cu.h b/backends/metax_gpu/kernels/metax_kernel/mmha_util.cu.h index aa352e600b5..187b0fc534a 100644 --- a/backends/metax_gpu/kernels/metax_kernel/mmha_util.cu.h +++ b/backends/metax_gpu/kernels/metax_kernel/mmha_util.cu.h @@ -49,10 +49,10 @@ #pragma once -#if defined(__CUDACC__) && CUDA_VERSION >= 11000 +// #if defined(__CUDACC__) && CUDA_VERSION >= 11000 #define ENABLE_BF16 #include -#endif +// #endif #ifdef PADDLE_WITH_HIP #include @@ -72,8 +72,8 @@ namespace cub = hipcub; #endif #include "paddle/phi/common/datatype_traits.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/funcs/math_function.h" - #ifdef PADDLE_WITH_HIP /// integral_constant template @@ -130,7 +130,7 @@ struct Float4_ { float2 y; }; -#if defined(ENABLE_BF16) || defined(PADDLE_WITH_HIP) +// #if defined(ENABLE_BF16) || defined(PADDLE_WITH_HIP) struct bf16_4_t { __nv_bfloat162 x; __nv_bfloat162 y; @@ -142,7 +142,7 @@ struct bf16_8_t { __nv_bfloat162 z; __nv_bfloat162 w; }; -#endif +// #endif //----------------------------------- template diff --git a/backends/metax_gpu/kernels/metax_kernel/mv_grad_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/mv_grad_kernel_register.cu index 895484324a9..8cf069c0f4b 100644 --- a/backends/metax_gpu/kernels/metax_kernel/mv_grad_kernel_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/mv_grad_kernel_register.cu @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/mv_grad_kernel.h" namespace phi { diff --git a/backends/metax_gpu/kernels/metax_kernel/quant_dequant.h b/backends/metax_gpu/kernels/metax_kernel/quant_dequant.h index a37fc8c5c57..80d325530f5 100644 --- a/backends/metax_gpu/kernels/metax_kernel/quant_dequant.h +++ b/backends/metax_gpu/kernels/metax_kernel/quant_dequant.h @@ -16,12 +16,12 @@ limitations under the License. 
*/ #include -#include "kernels/funcs/blas/blas.h" #include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/common/transform.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" namespace phi { diff --git a/backends/metax_gpu/kernels/metax_kernel/rank_attention_grad_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/rank_attention_grad_kernel_register.cu index bee25a721fa..ba33e68aa5e 100644 --- a/backends/metax_gpu/kernels/metax_kernel/rank_attention_grad_kernel_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/rank_attention_grad_kernel_register.cu @@ -17,8 +17,8 @@ #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/core/kernel_registry.h" -// #include "paddle/phi/kernels/funcs/blas/blas.h" -#include "kernels/funcs/blas/blas.h" +// #include "paddle/phi/paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/rank_attention.cu.h" diff --git a/backends/metax_gpu/kernels/metax_kernel/rank_attention_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/rank_attention_kernel_register.cu index b6a4d2d76e9..eeb9c938888 100644 --- a/backends/metax_gpu/kernels/metax_kernel/rank_attention_kernel_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/rank_attention_kernel_register.cu @@ -17,8 +17,8 @@ #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/core/kernel_registry.h" -// #include "paddle/phi/kernels/funcs/blas/blas.h" -#include "kernels/funcs/blas/blas.h" +// #include "paddle/phi/paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/rank_attention.cu.h" diff --git a/backends/metax_gpu/kernels/metax_kernel/slogdeterminant_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/slogdeterminant_kernel_register.cu index de263c91c4d..3e9a5683ae4 100644 --- a/backends/metax_gpu/kernels/metax_kernel/slogdeterminant_kernel_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/slogdeterminant_kernel_register.cu @@ -20,12 +20,12 @@ #include #include "glog/logging.h" -#include "kernels/funcs/blas/blas.h" #include "paddle/phi/common/complex.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/determinant_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/impl/determinant_kernel_impl.h" #include "paddle/phi/kernels/slogdeterminant_kernel.h" diff --git a/backends/metax_gpu/kernels/metax_kernel/softmax_kernel_grad_register.cu b/backends/metax_gpu/kernels/metax_kernel/softmax_kernel_grad_register.cu index 9b981029fc0..407180deca8 100644 --- a/backends/metax_gpu/kernels/metax_kernel/softmax_kernel_grad_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/softmax_kernel_grad_register.cu @@ -45,5 +45,6 @@ PD_REGISTER_PLUGIN_KERNEL(softmax_grad, ALL_LAYOUT, phi::SoftmaxGradGPUDNNKernel, float, + double, phi::dtype::float16, phi::dtype::bfloat16) {} diff --git a/backends/metax_gpu/kernels/metax_kernel/triangular_solve_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/triangular_solve_kernel_register.cu index 
5ff3211fe87..ed1ed259437 100644 --- a/backends/metax_gpu/kernels/metax_kernel/triangular_solve_kernel_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/triangular_solve_kernel_register.cu @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "kernels/funcs/blas/blas.h" #include "paddle/common/ddim.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/triangular_solve_kernel.h" diff --git a/backends/metax_gpu/kernels/metax_kernel/weight_only_linear_kernel.cu b/backends/metax_gpu/kernels/metax_kernel/weight_only_linear_kernel.cu index d2f39ccf751..65cf99d3065 100644 --- a/backends/metax_gpu/kernels/metax_kernel/weight_only_linear_kernel.cu +++ b/backends/metax_gpu/kernels/metax_kernel/weight_only_linear_kernel.cu @@ -166,7 +166,7 @@ void WeightOnlyLinearKernel(const Context& dev_ctx, mctlassGemmScaleOp_w4a16_nobias::epilogueParams( reinterpret_cast(bias_data)), mctlassGemmScaleOp_w4a16_nobias::quantscaleParams( - 1, + 2, group_size, reinterpret_cast(weight_scale_data)), reinterpret_cast(x_data), @@ -191,7 +191,7 @@ void WeightOnlyLinearKernel(const Context& dev_ctx, mctlassGemmScaleOp_w4a16_bias::epilogueParams( reinterpret_cast(bias_data)), mctlassGemmScaleOp_w4a16_bias::quantscaleParams( - 1, + 2, group_size, reinterpret_cast(weight_scale_data)), reinterpret_cast(x_data), diff --git a/backends/metax_gpu/kernels/metax_kernel/weight_quantize_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/weight_quantize_kernel_register.cu index 44ac7f2fddc..8d72ed2138e 100644 --- a/backends/metax_gpu/kernels/metax_kernel/weight_quantize_kernel_register.cu +++ b/backends/metax_gpu/kernels/metax_kernel/weight_quantize_kernel_register.cu @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
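// A CPU-only sketch (illustrative names, not the patch's API) of the
// weight_only_int4 layout transform that the registration change below wires
// in via MetaxQuantizedWeightLayoutTrans: unpack two signed int4 values per
// byte of an (m/2, n) buffer, logically transpose the resulting (m, n)
// matrix, then repack adjacent values so the result can be viewed as an
// (n/2, m) buffer. The real implementation also has a GPU path built from
// small unpack/pack kernels plus funcs::Transpose.
#include <cstdint>
#include <vector>

std::vector<int8_t> Int4LayoutTransCPU(const std::vector<int8_t>& packed,
                                       int64_t m, int64_t n) {
  // Stage 1: unpack -- byte (r, c) holds original rows 2r (low nibble) and
  // 2r + 1 (high nibble) at column c, stored as 4-bit two's complement.
  std::vector<int8_t> unpacked(m * n);
  for (int64_t r = 0; r < m / 2; ++r) {
    for (int64_t c = 0; c < n; ++c) {
      int8_t v = packed[r * n + c];
      int8_t lo = v & 0x0F;
      int8_t hi = (v >> 4) & 0x0F;
      unpacked[(2 * r) * n + c] = lo >= 8 ? lo - 16 : lo;
      unpacked[(2 * r + 1) * n + c] = hi >= 8 ? hi - 16 : hi;
    }
  }
  // Stages 2+3: transpose to (n, m) and pack adjacent columns of the
  // transposed matrix into one byte; the (n, m/2) byte buffer is later
  // reinterpreted as (n/2, m).
  std::vector<int8_t> repacked(n * (m / 2));
  for (int64_t r = 0; r < n; ++r) {
    for (int64_t c = 0; c < m / 2; ++c) {
      int8_t lo = unpacked[(2 * c) * n + r];      // == transposed(r, 2c)
      int8_t hi = unpacked[(2 * c + 1) * n + r];  // == transposed(r, 2c + 1)
      repacked[r * (m / 2) + c] =
          static_cast<int8_t>(((hi & 0x0F) << 4) | (lo & 0x0F));
    }
  }
  return repacked;
}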
+#include "../impl/metax_weight_quantize_kernel_impl.h" #include "paddle/common/enforce.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/datatype_traits.h" @@ -120,7 +121,6 @@ void WeightQuantizeKernel(const Context& dev_ctx, weight_shape, arch, algo); - out->Resize({m, n}); #ifdef PADDLE_WITH_HIP DenseTensor x_int_tmp(out->type()); x_int_tmp.Resize({m, n / 2}); @@ -141,6 +141,7 @@ void WeightQuantizeKernel(const Context& dev_ctx, // arch, // algo); #endif + MetaxQuantizedWeightLayoutTrans(dev_ctx, algo, weight_shape, out); } else if (algo == "w4a8") { weight_permute_gpu_w4a8(dev_ctx, x.data(), diff --git a/backends/metax_gpu/patch/paddle.patch b/backends/metax_gpu/patch/paddle.patch index 7ba32b5b399..6578029129e 100755 --- a/backends/metax_gpu/patch/paddle.patch +++ b/backends/metax_gpu/patch/paddle.patch @@ -31,6 +31,56 @@ index bff0f2bf70..9376b5781f 100644 #include "paddle/phi/core/os_info.h" #include "paddle/phi/core/platform/device/gpu/gpu_info.h" #include "paddle/phi/core/platform/profiler/utils.h" +diff --git a/paddle/phi/backends/dynload/cublas.h b/paddle/phi/backends/dynload/cublas.h +index 62beb53cfe..0b0ac09fc0 100644 +--- a/paddle/phi/backends/dynload/cublas.h ++++ b/paddle/phi/backends/dynload/cublas.h +@@ -49,7 +49,12 @@ extern void *cublas_dso_handle; + std::call_once(cublas_dso_flag, []() { \ + cublas_dso_handle = phi::dynload::GetCublasDsoHandle(); \ + }); \ +- static void *p_##__name = dlsym(cublas_dso_handle, #__name); \ ++ std::string replaced_name = #__name; \ ++ replaced_name = replaced_name.replace(0, 2, "mc"); \ ++ int index = replaced_name.find("_", 0); \ ++ if (index != -1) replaced_name = replaced_name.substr(0, index); \ ++ static void* p_##__name = \ ++ dlsym(cublas_dso_handle, replaced_name.c_str()); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ +diff --git a/paddle/phi/backends/dynload/cublasLt.h b/paddle/phi/backends/dynload/cublasLt.h +index 8b2e08c777..ca926df151 100644 +--- a/paddle/phi/backends/dynload/cublasLt.h ++++ b/paddle/phi/backends/dynload/cublasLt.h +@@ -46,12 +46,14 @@ extern void *cublasLt_dso_handle; + std::call_once(cublasLt_dso_flag, []() { \ + cublasLt_dso_handle = phi::dynload::GetCublasLtDsoHandle(); \ + }); \ +- static void *p_##__name = dlsym(cublasLt_dso_handle, #__name); \ ++ std::string replaced_name = #__name; \ ++ replaced_name = replaced_name.replace(0, 2, "mc"); \ ++ static void* p_##__name = \ ++ dlsym(cublasLt_dso_handle, replaced_name.c_str()); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ + extern DynLoad__##__name __name +- + // APIs available after CUDA 11.1 + #if CUDA_VERSION >= 11010 || defined(PADDLE_WITH_CUSTOM_DEVICE) + #define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ +@@ -79,8 +81,8 @@ extern void *cublasLt_dso_handle; + __macro(cublasLtMatmulAlgoConfigGetAttribute); \ + __macro(cublasLtMatmulAlgoGetIds); \ + __macro(cublasLtMatmulAlgoCapGetAttribute); \ +- __macro(cublasLtMatmulAlgoCheck); \ +- __macro(cublasLtGetCudartVersion); ++ __macro(cublasLtMatmulAlgoCheck); ++ // __macro(cublasLtGetCudartVersion); + #else + #define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasLtCreate); \ diff --git a/paddle/phi/backends/dynload/cudnn.h b/paddle/phi/backends/dynload/cudnn.h index c0080f0a5e..458ca3e2e8 100644 --- a/paddle/phi/backends/dynload/cudnn.h @@ -210,6 +260,29 @@ index 8ec3cf2792..6f5460df00 100644 return reinterpret_cast(p_##__name)(args...); \ } \ }; \ +diff --git a/paddle/phi/backends/dynload/dynamic_loader.cc 
b/paddle/phi/backends/dynload/dynamic_loader.cc +index 859f696896..87b5100a1b 100644 +--- a/paddle/phi/backends/dynload/dynamic_loader.cc ++++ b/paddle/phi/backends/dynload/dynamic_loader.cc +@@ -18,7 +18,6 @@ limitations under the License. */ + #include + #include + #include +-#include "paddle/phi/backends/dynload/cupti_lib_path.h" + #include "paddle/phi/common/port.h" + #include "paddle/phi/core/enforce.h" + +@@ -108,6 +107,10 @@ COMMON_DECLARE_string(win_cuda_bin_dir); + #define SPARSELT_LIB_NAME "libcusparseLt.so" + #endif + ++#ifndef CUPTI_LIB_PATH ++#define CUPTI_LIB_PATH "@CUPTI_LIBRARY_PATH@" ++#endif ++ + #ifdef PADDLE_WITH_HIP + + PHI_DEFINE_string(miopen_dir, diff --git a/paddle/phi/backends/dynload/nvjpeg.h b/paddle/phi/backends/dynload/nvjpeg.h index c5309e7e11..3328571380 100644 --- a/paddle/phi/backends/dynload/nvjpeg.h @@ -346,21 +419,10 @@ index 4ff2e528a9..23f7f4b583 100644 for (int offset = warpSize / 2; offset > 0; offset /= 2) diff --git a/paddle/phi/core/enforce.h b/paddle/phi/core/enforce.h -index 024a7de73e..1e4cdf16be 100644 +index 024a7de73e..66b373d698 100644 --- a/paddle/phi/core/enforce.h +++ b/paddle/phi/core/enforce.h -@@ -45,7 +45,9 @@ limitations under the License. */ - #endif - - #ifdef PADDLE_WITH_CUDA --#include "paddle/phi/backends/dynload/cublas.h" -+// #include "paddle/phi/backends/dynload/../../../../../cublas.h" -+#include "../backends/metax_gpu/kernels/funcs/blas/cublas.h" -+// #include "paddle/phi/backends/dynload/cublas.h" - #include "paddle/phi/backends/dynload/cudnn.h" - #include "paddle/phi/backends/dynload/curand.h" - #include "paddle/phi/backends/dynload/cusolver.h" -@@ -97,7 +99,7 @@ inline bool is_error(bool stat) { return !stat; } +@@ -97,7 +97,7 @@ inline bool is_error(bool stat) { return !stat; } void ThrowWarnInternal(const std::string& message); @@ -369,75 +431,225 @@ index 024a7de73e..1e4cdf16be 100644 // For cuda, the assertions can affect performance and it is therefore // recommended to disable them in production code // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion -@@ -109,7 +111,7 @@ void ThrowWarnInternal(const std::string& message); +@@ -109,7 +109,7 @@ void ThrowWarnInternal(const std::string& message); __LINE__, \ #_IS_NOT_ERROR, \ ##__VA_ARGS__); \ - asm("trap;"); \ -+ __builtin_trap(); \ ++ __builtin_trap(); \ } \ } while (0) #elif defined(__HIPCC__) -@@ -757,4 +759,4 @@ inline void retry_sleep(unsigned millisecond) { +diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +index ae7b67de6d..fbe9f67737 100644 +--- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h ++++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +@@ -368,7 +368,7 @@ struct CUBlas { + cudaDataType_t Ctype, + int ldc, + int batchCount, +- cudaDataType_t computeType) { ++ cublasComputeType_t computeType) { + #if CUDA_VERSION >= 8000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + #if CUDA_VERSION >= 9000 +@@ -476,7 +476,7 @@ struct CUBlas { + void *C, + cudaDataType_t Ctype, + int ldc, +- cudaDataType_t computeType) { ++ cublasComputeType_t computeType) { + #if CUDA_VERSION >= 8000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + #if CUDA_VERSION >= 9000 +@@ -532,7 +532,7 @@ struct CUBlas { + void *C, + cudaDataType_t Ctype, + int64_t ldc, +- cudaDataType_t computeType) { ++ cublasComputeType_t computeType) { + #if CUDA_VERSION >= 12030 && defined(__linux__) + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = dev_ctx->tensor_core_available(); +@@ -759,7 +759,7 @@ 
struct CUBlas { + void *C, + cudaDataType_t Ctype, + int ldc, +- cudaDataType_t computeType) { ++ cublasComputeType_t computeType) { + #if CUDA_VERSION >= 8000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + #if CUDA_VERSION >= 9000 +@@ -815,7 +815,7 @@ struct CUBlas { + void *C, + cudaDataType_t Ctype, + int64_t ldc, +- cudaDataType_t computeType) { ++ cublasComputeType_t computeType) { + #if CUDA_VERSION >= 12030 && defined(__linux__) + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = dev_ctx->tensor_core_available(); +@@ -1154,7 +1154,7 @@ struct CUBlas { + void *C, + cudaDataType_t Ctype, + int ldc, +- cudaDataType_t computeType) { ++ cublasComputeType_t computeType) { + #if CUDA_VERSION >= 8000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + #if CUDA_VERSION >= 9000 +@@ -1210,7 +1210,7 @@ struct CUBlas { + void *C, + cudaDataType_t Ctype, + int64_t ldc, +- cudaDataType_t computeType) { ++ cublasComputeType_t computeType) { + #if CUDA_VERSION >= 12030 && defined(__linux__) + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = dev_ctx->tensor_core_available(); +@@ -1484,7 +1484,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_R_16F, + N, +- CUDA_R_32F); ++ CUBLAS_COMPUTE_32F); + #else + PADDLE_THROW(common::errors::Unimplemented( + "GEMM_EX_64 is not supported on cuda < 12.3")); +@@ -1508,7 +1508,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_R_16F, + static_cast(N), +- CUDA_R_32F); ++ CUBLAS_COMPUTE_32F); + } + #else + // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm +@@ -1694,7 +1694,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_R_16F, + N, +- CUDA_R_32F); ++ CUBLAS_COMPUTE_32F); + #else + PADDLE_THROW(common::errors::Unimplemented( + "GEMM_EX_64 is not supported on cuda < 12.3")); +@@ -1719,7 +1719,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_R_16F, + static_cast(N), +- CUDA_R_32F); ++ CUBLAS_COMPUTE_32F); + #else + // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm + dev_ctx_.CublasCall([&](cublasHandle_t handle) { +@@ -1831,7 +1831,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_R_16BF, + static_cast(N), +- CUDA_R_32F, ++ CUBLAS_COMPUTE_32F, + algo)); + }); + } +@@ -1932,7 +1932,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_R_16BF, + static_cast(N), +- CUDA_R_32F, ++ CUBLAS_COMPUTE_32F, + algo)); + }); + } +@@ -2026,7 +2026,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_C_32F, + static_cast(N), +- CUDA_C_32F); ++ CUBLAS_COMPUTE_32F); - } // namespace enforce - using namespace enforce; // NOLINT --} // namespace phi -+} // namespace phi -\ No newline at end of file -diff --git a/paddle/phi/core/platform/device/gpu/gpu_types.h b/paddle/phi/core/platform/device/gpu/gpu_types.h -index c646e487d0..325122175c 100644 ---- a/paddle/phi/core/platform/device/gpu/gpu_types.h -+++ b/paddle/phi/core/platform/device/gpu/gpu_types.h -@@ -25,8 +25,9 @@ #else - #include - --#include "paddle/phi/backends/dynload/cublas.h" --#include "paddle/phi/backends/dynload/cublasLt.h" -+// #include "paddle/phi/backends/dynload/cublas.h" -+#include "kernels/funcs/blas/cublas.h" -+// #include "paddle/phi/backends/dynload/cublasLt.h" - #include "paddle/phi/backends/dynload/cudnn.h" - #endif + dev_ctx_.CublasCall([&](cublasHandle_t handle) { +@@ -2111,7 +2111,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_C_64F, + N, +- CUDA_C_64F); ++ CUBLAS_COMPUTE_64F); + #else + 
PADDLE_THROW(common::errors::Unimplemented( + "GEMM_EX_64 is not supported on cuda < 12.3")); +@@ -2136,7 +2136,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + C, + CUDA_C_64F, + static_cast(N), +- CUDA_C_64F); ++ CUBLAS_COMPUTE_64F); + #else // CUDA_VERSION >= 8000 + // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm + dev_ctx_.CublasCall([&](cublasHandle_t handle) { +@@ -3129,7 +3129,7 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CUDA_R_16F, + ldc, + batchCount, +- CUDA_R_32F); ++ CUBLAS_COMPUTE_32F); + } -@@ -90,7 +91,7 @@ DECLARE_TYPE_FOR_GPU(gpuStreamCaptureMode, - - // TODO(Ming Huang): Since there is no blasLt handler, - // use rocblas_handle for workaround. --DECLARE_TYPE_FOR_GPU(blasLtHandle_t, cublasLtHandle_t, rocblas_handle); -+// DECLARE_TYPE_FOR_GPU(blasLtHandle_t, cublasLtHandle_t, rocblas_handle); - - #undef DECLARE_TYPE_FOR_GPU - -diff --git a/paddle/phi/core/platform/device_context.h b/paddle/phi/core/platform/device_context.h -index 2d02eb370b..8a7233e34e 100644 ---- a/paddle/phi/core/platform/device_context.h -+++ b/paddle/phi/core/platform/device_context.h -@@ -25,8 +25,8 @@ limitations under the License. */ - #include "paddle/phi/core/platform/device/gpu/gpu_types.h" - #include "paddle/phi/core/platform/device_type.h" - #ifdef PADDLE_WITH_CUDA --#include "paddle/phi/backends/dynload/cublas.h" --#include "paddle/phi/backends/dynload/cublasLt.h" -+#include "kernels/funcs/blas/cublas.h" -+#include "kernels/funcs/blas/cublasLt.h" - #include "paddle/phi/backends/dynload/cudnn.h" - #include "paddle/phi/backends/dynload/cusolver.h" - #include "paddle/phi/backends/dynload/cusparse.h" -diff --git a/paddle/phi/kernels/cpu/index_select_impl.h b/paddle/phi/kernels/cpu/index_select_impl.h -index d69eb67d6f..1d8b6e9375 100644 ---- a/paddle/phi/kernels/cpu/index_select_impl.h -+++ b/paddle/phi/kernels/cpu/index_select_impl.h -@@ -18,7 +18,7 @@ - - #include "paddle/phi/core/dense_tensor.h" - #include "paddle/phi/core/tensor_utils.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/eigen/common.h" - #include "paddle/phi/kernels/funcs/math_function.h" + template <> +diff --git a/paddle/phi/kernels/funcs/blas/blaslt_gemm_search.h b/paddle/phi/kernels/funcs/blas/blaslt_gemm_search.h +index e63b3d2f6e..95d7e6f204 100644 +--- a/paddle/phi/kernels/funcs/blas/blaslt_gemm_search.h ++++ b/paddle/phi/kernels/funcs/blas/blaslt_gemm_search.h +@@ -628,7 +628,13 @@ class CublasLtAlgoCache { + infile >> cublaslt_version; + VLOG(1) << "cublaslt_version " << cublaslt_version; + +- if (dynload::cublasLtGetCudartVersion() != cublaslt_version) { ++ // if (dynload::cublasLtGetCudartVersion() != cublaslt_version) { ++ // LOG(INFO) << algo_caches_file_ ++ // << " is not compatible with current cublaslt_version " ++ // << real_cublaslt_version; ++ // return; ++ // } ++ if (3000 != cublaslt_version) { + LOG(INFO) << algo_caches_file_ + << " is not compatible with current cublaslt_version " + << real_cublaslt_version; +@@ -655,7 +661,8 @@ class CublasLtAlgoCache { + if (dev == 0) { + std::ofstream outfile; + outfile.open(algo_caches_file_, std::ios::out | std::ios::trunc); +- outfile << dynload::cublasLtGetCudartVersion() << std::endl; ++ // outfile << dynload::cublasLtGetCudartVersion() << std::endl; ++ outfile << 3000 << std::endl; + + for (const auto& [seed, algo] : algo_caches_) { + outfile << seed << " "; +diff --git a/paddle/phi/kernels/funcs/cublaslt.h 
b/paddle/phi/kernels/funcs/cublaslt.h +index fbbf57c25a..f690db59e9 100644 +--- a/paddle/phi/kernels/funcs/cublaslt.h ++++ b/paddle/phi/kernels/funcs/cublaslt.h +@@ -42,19 +42,11 @@ class CublasLtHelper { + CublasLtHelper(int m, int k, int n, cublasLtHandle_t handle) + : handle_(handle), alpha_(1), beta_(0), m_(m), k_(k), n_(n) { + cublasStatus_t status; +-#if CUBLAS_VER_MAJOR < 11 +- cudaDataType_t cudaComputeType = CUDA_R_32I; +-#else + cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; +-#endif + + // matmul desc +-#if CUBLAS_VER_MAJOR < 11 +- status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType); +-#else + status = dyl::cublasLtMatmulDescCreate( + &matmul_desc_, cudaComputeType, CUDA_R_32I); +-#endif + PADDLE_ENFORCE_EQ( + status, diff --git a/paddle/phi/kernels/funcs/embedding_grad.h b/paddle/phi/kernels/funcs/embedding_grad.h index 461e6e2474..48a64ae9ce 100644 --- a/paddle/phi/kernels/funcs/embedding_grad.h @@ -453,38 +665,6 @@ index 461e6e2474..48a64ae9ce 100644 #endif dim3 threads(kWarpSize, kBlockDimY); dim3 grids(static_cast((D + kWarpSize - 1) / kWarpSize)); -diff --git a/paddle/phi/kernels/funcs/fc_functor.cu b/paddle/phi/kernels/funcs/fc_functor.cu -index cb35feee32..64f5bd24ac 100644 ---- a/paddle/phi/kernels/funcs/fc_functor.cu -+++ b/paddle/phi/kernels/funcs/fc_functor.cu -@@ -16,12 +16,12 @@ limitations under the License. */ - - #include "paddle/phi/backends/all_context.h" - #include "paddle/phi/kernels/funcs/aligned_vector.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/fc_functor.h" - - #include "paddle/phi/backends/gpu/gpu_launch_config.h" - #include "paddle/phi/core/dense_tensor.h" --#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" -+// #include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" - #include "paddle/phi/kernels/funcs/quant_dequant.h" - #include "paddle/phi/kernels/matmul_kernel.h" - -diff --git a/paddle/phi/kernels/funcs/gru_compute.cu b/paddle/phi/kernels/funcs/gru_compute.cu -index 88663ec880..98b93072a3 100644 ---- a/paddle/phi/kernels/funcs/gru_compute.cu -+++ b/paddle/phi/kernels/funcs/gru_compute.cu -@@ -12,7 +12,7 @@ limitations under the License. */ - #include "paddle/phi/kernels/funcs/gru_compute.h" - - #include "paddle/phi/backends/gpu/gpu_context.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/detail/gru_gpu_kernel.h" - #include "paddle/phi/kernels/funcs/detail/gru_kernel.h" - diff --git a/paddle/phi/kernels/funcs/layer_norm_impl.cu.h b/paddle/phi/kernels/funcs/layer_norm_impl.cu.h index 4eae698648..5c047723ea 100644 --- a/paddle/phi/kernels/funcs/layer_norm_impl.cu.h @@ -503,19 +683,6 @@ index 4eae698648..5c047723ea 100644 #endif return block_dim >= kMaxBlockDim ? 
kMaxBlockDim : lwarpSize; } -diff --git a/paddle/phi/kernels/funcs/math/context_project.h b/paddle/phi/kernels/funcs/math/context_project.h -index 15e1a4a3c3..e4780538d7 100644 ---- a/paddle/phi/kernels/funcs/math/context_project.h -+++ b/paddle/phi/kernels/funcs/math/context_project.h -@@ -18,7 +18,7 @@ - #include - - #include "paddle/phi/core/tensor_utils.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/im2col.h" - - namespace phi { diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h index e5361b836e..5ad238df08 100644 --- a/paddle/phi/kernels/funcs/math_cuda_utils.h @@ -559,51 +726,6 @@ index e5361b836e..5ad238df08 100644 return val; } -diff --git a/paddle/phi/kernels/funcs/matrix_inverse.cu b/paddle/phi/kernels/funcs/matrix_inverse.cu -index e101224970..a52eb6096f 100644 ---- a/paddle/phi/kernels/funcs/matrix_inverse.cu -+++ b/paddle/phi/kernels/funcs/matrix_inverse.cu -@@ -15,11 +15,13 @@ limitations under the License. */ - #include "paddle/phi/kernels/funcs/matrix_inverse.h" - - #include "paddle/phi/common/memory_utils.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - - namespace phi { - namespace funcs { - -+ -+ - template - void MatrixInverseFunctor::operator()(const Context& dev_ctx, - const DenseTensor& a, -diff --git a/paddle/phi/kernels/funcs/matrix_solve.cu b/paddle/phi/kernels/funcs/matrix_solve.cu -index 558d363b39..05da04b517 100644 ---- a/paddle/phi/kernels/funcs/matrix_solve.cu -+++ b/paddle/phi/kernels/funcs/matrix_solve.cu -@@ -16,7 +16,7 @@ limitations under the License. */ - #include "paddle/phi/backends/gpu/cuda/cudnn_workspace_helper.h" - #include "paddle/phi/common/memory_utils.h" - #include "paddle/phi/core/tensor_utils.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/math_function.h" - #include "paddle/phi/kernels/funcs/scatter.cu.h" - -diff --git a/paddle/phi/kernels/funcs/multihead_matmul_functor.cu b/paddle/phi/kernels/funcs/multihead_matmul_functor.cu -index 047f52bd91..a05b34d3ba 100644 ---- a/paddle/phi/kernels/funcs/multihead_matmul_functor.cu -+++ b/paddle/phi/kernels/funcs/multihead_matmul_functor.cu -@@ -27,7 +27,7 @@ namespace cub = hipcub; - - #include "paddle/phi/kernels/funcs/multihead_matmul_functor.h" - --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/math_cuda_utils.h" - - namespace phi { diff --git a/paddle/phi/kernels/funcs/top_k_function_cuda.h b/paddle/phi/kernels/funcs/top_k_function_cuda.h index e30d440ff3..108edda7ca 100644 --- a/paddle/phi/kernels/funcs/top_k_function_cuda.h @@ -873,31 +995,17 @@ index e30d440ff3..108edda7ca 100644 } // namespace funcs } // namespace phi +// -diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h b/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h -index 32db61532f..0220316bc3 100644 ---- a/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h -+++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h -@@ -15,7 +15,7 @@ - #pragma once - - #if defined(PADDLE_WITH_CUDA) --#include "paddle/phi/backends/dynload/cublasLt.h" -+// #include "paddle/phi/backends/dynload/cublasLt.h" - #endif - - #include "glog/logging.h" diff --git a/paddle/phi/kernels/fusion/gpu/fused_layernorm_residual_dropout_bias.h b/paddle/phi/kernels/fusion/gpu/fused_layernorm_residual_dropout_bias.h 
-index 9d4bb18d55..ea42cc10a9 100644 +index 9d4bb18d55..80405c2b78 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_layernorm_residual_dropout_bias.h +++ b/paddle/phi/kernels/fusion/gpu/fused_layernorm_residual_dropout_bias.h -@@ -638,9 +638,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( +@@ -638,9 +638,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( RandVec(&state, rand); #pragma unroll for (int jt = 0; jt < VecSize; jt++) { -#ifndef PADDLE_WITH_HIP -#pragma unroll -#endif -+// #pragma unroll mask_vec[it][jt] = static_cast(rand[jt] >= dropout_prob); } } @@ -928,7 +1036,7 @@ index b2d15a59f8..f64582e85a 100644 namespace phi { namespace fusion { diff --git a/paddle/phi/kernels/gpu/depthwise_conv.h b/paddle/phi/kernels/gpu/depthwise_conv.h -index f0cca0f701..02ea957240 100644 +index 2edac5eba5..4f265e3db7 100644 --- a/paddle/phi/kernels/gpu/depthwise_conv.h +++ b/paddle/phi/kernels/gpu/depthwise_conv.h @@ -29,8 +29,8 @@ namespace cub = hipcub; @@ -942,19 +1050,6 @@ index f0cca0f701..02ea957240 100644 namespace phi { // To determine use cudnn or not. -diff --git a/paddle/phi/kernels/gpu/dot_kernel.cu b/paddle/phi/kernels/gpu/dot_kernel.cu -index af27ac89ab..ee0edc6b8e 100644 ---- a/paddle/phi/kernels/gpu/dot_kernel.cu -+++ b/paddle/phi/kernels/gpu/dot_kernel.cu -@@ -15,7 +15,7 @@ - #include "paddle/phi/kernels/dot_kernel.h" - #include "paddle/phi/backends/gpu/gpu_context.h" - #include "paddle/phi/core/kernel_registry.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/eigen/common.h" - - #include "paddle/phi/kernels/full_kernel.h" diff --git a/paddle/phi/kernels/gpu/gelu_funcs.h b/paddle/phi/kernels/gpu/gelu_funcs.h index 29fa252e96..4ae72b0935 100644 --- a/paddle/phi/kernels/gpu/gelu_funcs.h @@ -1007,7 +1102,7 @@ index 63c35dd4ee..15da9aea45 100644 namespace phi { diff --git a/paddle/phi/kernels/gpu/lstsq_kernel.cu b/paddle/phi/kernels/gpu/lstsq_kernel.cu -index 1bdbe1564c..f753b54bc6 100644 +index c7f27b2924..4cf6204ac7 100644 --- a/paddle/phi/kernels/gpu/lstsq_kernel.cu +++ b/paddle/phi/kernels/gpu/lstsq_kernel.cu @@ -21,7 +21,7 @@ @@ -1019,84 +1114,6 @@ index 1bdbe1564c..f753b54bc6 100644 #include "paddle/phi/kernels/impl/qr_kernel_impl.h" #include "paddle/phi/kernels/impl/tril_triu_kernel_impl.h" #include "paddle/phi/kernels/lstsq_kernel.h" -diff --git a/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h b/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h -index 9bc5326c90..79b57a8203 100644 ---- a/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h -+++ b/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h -@@ -21,7 +21,7 @@ limitations under the License. */ - #include "paddle/phi/common/amp_type_traits.h" - #include "paddle/phi/kernels/addmm_grad_kernel.h" - #include "paddle/phi/kernels/full_kernel.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/eigen/common.h" - #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" - #include "paddle/phi/kernels/funcs/for_range.h" -diff --git a/paddle/phi/kernels/impl/baddbmm_grad_kernel_impl.h b/paddle/phi/kernels/impl/baddbmm_grad_kernel_impl.h -index cf80666b4e..ca76e055fb 100644 ---- a/paddle/phi/kernels/impl/baddbmm_grad_kernel_impl.h -+++ b/paddle/phi/kernels/impl/baddbmm_grad_kernel_impl.h -@@ -19,7 +19,7 @@ limitations under the License. 
*/ - - #include "paddle/phi/common/amp_type_traits.h" - #include "paddle/phi/kernels/baddbmm_grad_kernel.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/eigen/common.h" - #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" - #include "paddle/phi/kernels/funcs/for_range.h" -diff --git a/paddle/phi/kernels/impl/baddbmm_kernel_impl.h b/paddle/phi/kernels/impl/baddbmm_kernel_impl.h -index 2789cb59a2..b91b076f7f 100644 ---- a/paddle/phi/kernels/impl/baddbmm_kernel_impl.h -+++ b/paddle/phi/kernels/impl/baddbmm_kernel_impl.h -@@ -20,7 +20,7 @@ limitations under the License. */ - - #include "paddle/phi/common/amp_type_traits.h" - #include "paddle/phi/kernels/baddbmm_kernel.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/eigen/common.h" - #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" - -diff --git a/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h b/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h -index 9a21c23666..86413d1577 100644 ---- a/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h -+++ b/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h -@@ -19,7 +19,7 @@ - #include "paddle/phi/kernels/conv_transpose_grad_kernel.h" - #include "paddle/phi/kernels/cpu/conv_util.h" - #include "paddle/phi/kernels/full_kernel.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/concat_and_split_functor.h" - #include "paddle/phi/kernels/funcs/im2col.h" - #include "paddle/phi/kernels/funcs/slice.h" -diff --git a/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h b/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h -index 4459a931da..837c8682b8 100644 ---- a/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h -+++ b/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h -@@ -18,7 +18,7 @@ - #include "paddle/phi/core/dense_tensor.h" - #include "paddle/phi/kernels/empty_kernel.h" - #include "paddle/phi/kernels/full_kernel.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/deformable_conv_functor.h" - - namespace phi { -diff --git a/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h b/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h -index ad9e9197dd..5478d9817d 100644 ---- a/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h -+++ b/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h -@@ -18,7 +18,7 @@ - #include "paddle/phi/core/dense_tensor.h" - #include "paddle/phi/kernels/empty_kernel.h" - #include "paddle/phi/kernels/full_kernel.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/deformable_conv_functor.h" - #include "paddle/phi/kernels/transpose_kernel.h" - #include "paddle/utils/optional.h" diff --git a/paddle/phi/kernels/impl/gammaincc_kernel_impl.h b/paddle/phi/kernels/impl/gammaincc_kernel_impl.h index e6b3960f6d..564125f1f6 100644 --- a/paddle/phi/kernels/impl/gammaincc_kernel_impl.h @@ -1112,80 +1129,3 @@ index e6b3960f6d..564125f1f6 100644 if ((x <= T{0}) || (a <= T{0})) return (T{1.0}); -diff --git a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h -index 410fb3c560..009ce03440 100644 ---- a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h -+++ 
b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h
-@@ -54,7 +54,7 @@ HOSTDEVICE T digamma_positive_domain(T x) {
- 
- template <typename T>
- HOSTDEVICE T digamma(T x) {
-- static T pi = T{3.14159265358979323846};
-+ const static T pi = T{3.14159265358979323846};
- 
- if (x == T{0.0}) {
- T inf = std::numeric_limits<T>::infinity();
-diff --git a/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h b/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h
-index 5ebbc8d2db..c7b6c338e2 100644
---- a/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h
-+++ b/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h
-@@ -15,8 +15,9 @@ limitations under the License. */
- #include 
- #include 
- #include "paddle/phi/common/datatype_traits.h"
--#include "paddle/phi/kernels/funcs/cublaslt.h"
--#include "paddle/phi/kernels/funcs/quant_dequant.h"
-+#include "kernels/funcs/blas/cublaslt.h"
-+#include "kernels/funcs/quant_dequant.h"
-+#include "kernels/metax_kernel/metax_context.h"
- 
- #pragma once
- 
-@@ -668,7 +669,7 @@ void LLMGemm(const phi::GPUContext& dev_ctx,
- 
- {
- auto helper =
-- std::make_unique<CublasLtHelper>(m, k, n, dev_ctx.cublaslt_handle());
-+ std::make_unique<CublasLtHelper>(m, k, n, GetBlasLtHandle());
- helper->GEMM(quant_input.data(),
- weight->data(),
- int_out.data(),
-diff --git a/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h b/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h
-index 1f319c4ae3..9186eb6906 100644
---- a/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h
-+++ b/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h
-@@ -15,7 +15,7 @@ limitations under the License. */
- #pragma once
- 
- #include "paddle/phi/core/dense_tensor.h"
--#include "paddle/phi/kernels/funcs/blas/blas.h"
-+#include "kernels/funcs/blas/blas.h"
- #include "paddle/phi/kernels/funcs/matrix_inverse.h"
- 
- namespace phi {
-diff --git a/paddle/phi/kernels/impl/matrix_power_kernel_impl.h b/paddle/phi/kernels/impl/matrix_power_kernel_impl.h
-index 6f03f76eeb..5fe2c3e7dc 100644
---- a/paddle/phi/kernels/impl/matrix_power_kernel_impl.h
-+++ b/paddle/phi/kernels/impl/matrix_power_kernel_impl.h
-@@ -15,7 +15,7 @@ limitations under the License. 
*/ - #pragma once - - #include "paddle/phi/core/dense_tensor.h" --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/for_range.h" - #include "paddle/phi/kernels/funcs/matrix_inverse.h" - -diff --git a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h -index 4099d8b506..baef2cd643 100644 ---- a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h -+++ b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h -@@ -14,7 +14,7 @@ - - #pragma once - --#include "paddle/phi/kernels/funcs/blas/blas.h" -+#include "kernels/funcs/blas/blas.h" - #include "paddle/phi/kernels/funcs/eigen/common.h" - #include "paddle/phi/kernels/funcs/math_function.h" - diff --git a/backends/metax_gpu/runtime/runtime.cc b/backends/metax_gpu/runtime/runtime.cc index 36fbd88c2ea..edbe937e7ba 100644 --- a/backends/metax_gpu/runtime/runtime.cc +++ b/backends/metax_gpu/runtime/runtime.cc @@ -36,12 +36,12 @@ #include #include "glog/logging.h" -#include "kernels/funcs/blas/cublasLt.h" #include "paddle/fluid/platform/profiler/cuda_tracer.h" #include "paddle/fluid/platform/profiler/cupti_data_process.h" #include "paddle/phi/api/profiler/trace_event_collector.h" #include "paddle/phi/backends/device_base.h" #include "paddle/phi/backends/device_ext.h" +#include "paddle/phi/backends/dynload/cublasLt.h" #include "paddle/phi/backends/dynload/cupti.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/allocator.h" diff --git a/backends/metax_gpu/tests/ignore.txt b/backends/metax_gpu/tests/ignore.txt index be0357e5319..2b0fae559e6 100644 --- a/backends/metax_gpu/tests/ignore.txt +++ b/backends/metax_gpu/tests/ignore.txt @@ -24,9 +24,9 @@ test_conv3d_layer test_conv3d_transpose_part2_op test_fused_conv2d_add_act_op test_swiglu_metax -test_set_value_op -test_pad_op test_squared_l2_norm_op -test_concat_op test_dygraph_spectral_norm test_bincount_op +test_adamw_op +test_einsum_op +test_complex_matmul