From 51da6b78d001551ed3a8a74f7b06fdbcc1ee0d0a Mon Sep 17 00:00:00 2001 From: jiaxinWang-metax <189149612@qq.com> Date: Tue, 16 Sep 2025 14:33:27 +0800 Subject: [PATCH 01/12] modify cmake for warpctc and warprnnt --- backends/metax_gpu/CMakeLists.txt | 4 +- backends/metax_gpu/cmake/warpctc.cmake | 7 +- backends/metax_gpu/cmake/warprnnt.cmake | 8 ++- .../fused_conv2d_add_act_kernel_register.cu | 3 +- backends/metax_gpu/kernels/impl/warpctc.h | 64 ------------------- .../kernels/impl/warpctc_grad_kernel_impl.h | 2 +- .../kernels/impl/warpctc_kernel_impl.h | 16 ++--- backends/metax_gpu/kernels/impl/warprnnt.h | 63 ------------------ .../kernels/impl/warprnnt_kernel_impl.h | 14 ++-- backends/metax_gpu/kernels/metax_context.cc | 20 +++++- backends/metax_gpu/kernels/metax_context.h | 1 + 11 files changed, 51 insertions(+), 151 deletions(-) delete mode 100644 backends/metax_gpu/kernels/impl/warpctc.h delete mode 100644 backends/metax_gpu/kernels/impl/warprnnt.h diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt index cca23ab42f5..787aae13e40 100755 --- a/backends/metax_gpu/CMakeLists.txt +++ b/backends/metax_gpu/CMakeLists.txt @@ -736,7 +736,7 @@ add_library( target_include_directories( ${TARGET_NAME} PRIVATE ${PADDLE_SOURCE_DIR} ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/kernels - ${CUDA_INCLUDE_DIRS} ${PADDLE_SOURCE_DIR}/third_party/pybind/include + ${CUDA_INCLUDE_DIRS} ${WARPCTC_INCLUDE_DIR} ${WARPRNNT_INCLUDE_DIR} ${PADDLE_SOURCE_DIR}/third_party/pybind/include ${PADDLE_SOURCE_DIR}/paddle/phi/api/include/compat) target_link_libraries( @@ -749,6 +749,8 @@ target_link_libraries( protobuf external_error_proto dgc + ${WARPCTC_LIBRARIES} + ${WARPRNNT_LIBRARIES} ${PADDLE_CORE_LIB}) target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmccl.so) target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmcFlashAttn.so) diff --git a/backends/metax_gpu/cmake/warpctc.cmake b/backends/metax_gpu/cmake/warpctc.cmake index 71c892a6cfa..9edc92f0a94 100644 --- a/backends/metax_gpu/cmake/warpctc.cmake +++ b/backends/metax_gpu/cmake/warpctc.cmake @@ -145,5 +145,8 @@ get_filename_component(WARPCTC_LIBRARY_PATH ${WARPCTC_LIBRARIES} DIRECTORY) include_directories(${WARPCTC_INCLUDE_DIR}) # For warpctc code to include its # headers. -add_library(warpctc INTERFACE) -add_dependencies(warpctc extern_warpctc) +add_library(warpctc SHARED IMPORTED GLOBAL) +set_target_properties(warpctc PROPERTIES + IMPORTED_LOCATION ${WARPCTC_LIBRARIES} + INTERFACE_INCLUDE_DIRECTORIES ${WARPCTC_INCLUDE_DIR} +) \ No newline at end of file diff --git a/backends/metax_gpu/cmake/warprnnt.cmake b/backends/metax_gpu/cmake/warprnnt.cmake index 54a7ad6be86..527f2e55a1b 100644 --- a/backends/metax_gpu/cmake/warprnnt.cmake +++ b/backends/metax_gpu/cmake/warprnnt.cmake @@ -137,6 +137,8 @@ get_filename_component(WARPRNNT_LIBRARY_PATH ${WARPRNNT_LIBRARIES} DIRECTORY) include_directories(${WARPRNNT_INCLUDE_DIR}) # For warprnnt code to include its # headers. 
-add_library(warprnnt INTERFACE) -# set_property(TARGET warprnnt PROPERTY IMPORTED_LOCATION ${WARPRNNT_LIBRARIES}) -add_dependencies(warprnnt extern_warprnnt) +add_library(warprnnt SHARED IMPORTED GLOBAL) +set_target_properties(warprnnt PROPERTIES + IMPORTED_LOCATION ${WARPRNNT_LIBRARIES} + INTERFACE_INCLUDE_DIRECTORIES ${WARPRNNT_INCLUDE_DIR} +) \ No newline at end of file diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_conv2d_add_act_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/fused_conv2d_add_act_kernel_register.cu index ee4f105cbc5..6cf22a1918b 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/fused_conv2d_add_act_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/fused_conv2d_add_act_kernel_register.cu @@ -32,6 +32,7 @@ #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/utils/optional.h" +#include "kernels/metax_context.h" namespace phi { namespace fusion { @@ -308,7 +309,7 @@ class CudnnConvDescManager { int groups, cudnnDataType_t dtype) { auto* desc = new phi::backends::gpu::ConvolutionDescriptor(); - desc->set(dtype, paddings, strides, dilations, true, groups); + desc->set(dtype, paddings, strides, dilations, phi::AllowTF32Cudnn(), groups); return desc; } diff --git a/backends/metax_gpu/kernels/impl/warpctc.h b/backends/metax_gpu/kernels/impl/warpctc.h deleted file mode 100644 index ba5da472ade..00000000000 --- a/backends/metax_gpu/kernels/impl/warpctc.h +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include // NOLINT - -#include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/common/port.h" -#include "third_party/warpctc/include/ctc.h" - -namespace phi { -namespace dynload { - -extern std::once_flag warpctc_dso_flag; -extern void* warpctc_dso_handle; - -/** - * The following macro definition can generate structs - * (for each function) to dynamic load warpctc routine - * via operator overloading. - */ -#define DYNAMIC_LOAD_WARPCTC_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) 
{ \ - using warpctcFunc = decltype(&::__name); \ - std::call_once(warpctc_dso_flag, []() { \ - warpctc_dso_handle = phi::dynload::GetWarpCTCDsoHandle(); \ - }); \ - static void* p_##__name = dlsym(warpctc_dso_handle, #__name); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ - }; \ - extern DynLoad__##__name __name - -#define DECLARE_DYNAMIC_LOAD_WARPCTC_WRAP(__name) \ - DYNAMIC_LOAD_WARPCTC_WRAP(__name) - -#define WARPCTC_ROUTINE_EACH(__macro) \ - __macro(get_warpctc_version); \ - __macro(ctcGetStatusString); \ - __macro(compute_ctc_loss); \ - __macro(compute_ctc_loss_double); \ - __macro(get_workspace_size); \ - __macro(get_workspace_size_double) - -WARPCTC_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_WARPCTC_WRAP); - -#undef DYNAMIC_LOAD_WARPCTC_WRAP - -} // namespace dynload -} // namespace phi diff --git a/backends/metax_gpu/kernels/impl/warpctc_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/warpctc_grad_kernel_impl.h index 51f4ce86890..dc9bc376e63 100644 --- a/backends/metax_gpu/kernels/impl/warpctc_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/warpctc_grad_kernel_impl.h @@ -16,7 +16,7 @@ #include -#include "kernels/impl/warpctc.h" +#include "third_party/warpctc/include/ctc.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/eigen/common.h" diff --git a/backends/metax_gpu/kernels/impl/warpctc_kernel_impl.h b/backends/metax_gpu/kernels/impl/warpctc_kernel_impl.h index 9794ba1b3c0..e0b15feca03 100644 --- a/backends/metax_gpu/kernels/impl/warpctc_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/warpctc_kernel_impl.h @@ -16,7 +16,7 @@ #include -#include "kernels/impl/warpctc.h" +#include "third_party/warpctc/include/ctc.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/lod_utils.h" #include "paddle/phi/core/tensor_utils.h" @@ -58,7 +58,7 @@ class ComputeCtcLossFunctor { float* costs, void* workspace, ctcOptions options) { - return phi::dynload::compute_ctc_loss(activations, + return compute_ctc_loss(activations, gradients, flat_labels, label_lengths, @@ -84,7 +84,7 @@ class ComputeCtcLossFunctor { double* costs, void* workspace, ctcOptions options) { - return phi::dynload::compute_ctc_loss_double( + return compute_ctc_loss_double( activations, gradients, flat_labels, @@ -141,14 +141,14 @@ class WarpCTCFunctor { ctcStatus_t status = CTC_STATUS_UNKNOWN_ERROR; if (sizeof(T) == 4) { status = - phi::dynload::get_workspace_size(cpu_label_lengths, + get_workspace_size(cpu_label_lengths, cpu_input_lengths, static_cast(sequence_width), static_cast(num_sequences), options_, &workspace_bytes); } else { - status = phi::dynload::get_workspace_size_double( + status = get_workspace_size_double( cpu_label_lengths, cpu_input_lengths, static_cast(sequence_width), @@ -162,7 +162,7 @@ class WarpCTCFunctor { errors::PreconditionNotMet( "warp-ctc [version %d] Error in get_workspace_size: %s", warpctc_version_, - phi::dynload::ctcGetStatusString(status))); + ctcGetStatusString(status))); PADDLE_ENFORCE_GT( workspace_bytes, 0UL, @@ -197,12 +197,12 @@ class WarpCTCFunctor { errors::PreconditionNotMet( "warp-ctc [version %d] Error in get_workspace_size: %s", warpctc_version_, - phi::dynload::ctcGetStatusString(status))); + ctcGetStatusString(status))); } protected: void init(const Context& dev_ctx, const size_t blank) { - warpctc_version_ = phi::dynload::get_warpctc_version(); + warpctc_version_ = get_warpctc_version(); if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || 
dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) { diff --git a/backends/metax_gpu/kernels/impl/warprnnt.h b/backends/metax_gpu/kernels/impl/warprnnt.h deleted file mode 100644 index 50b0dfc0efc..00000000000 --- a/backends/metax_gpu/kernels/impl/warprnnt.h +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include // NOLINT - -#include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/common/port.h" -#include "third_party/warprnnt/include/rnnt.h" - -namespace phi { -namespace dynload { - -extern std::once_flag warprnnt_dso_flag; -extern void* warprnnt_dso_handle; - -/** - * The following macro definition can generate structs - * (for each function) to dynamic load warprnnt routine - * via operator overloading. - */ -#define DYNAMIC_LOAD_WARPRNNT_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ - using warprnntFunc = decltype(&::__name); \ - std::call_once(warprnnt_dso_flag, []() { \ - warprnnt_dso_handle = phi::dynload::GetWarpRNNTDsoHandle(); \ - }); \ - static void* p_##__name = dlsym(warprnnt_dso_handle, #__name); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ - }; \ - extern DynLoad__##__name __name - -#define DECLARE_DYNAMIC_LOAD_WARPRNNT_WRAP(__name) \ - DYNAMIC_LOAD_WARPRNNT_WRAP(__name) - -#define WARPRNNT_ROUTINE_EACH(__macro) \ - __macro(get_warprnnt_version); \ - __macro(rnntGetStatusString); \ - __macro(compute_rnnt_loss); \ - __macro(compute_rnnt_loss_fp64); \ - __macro(get_rnnt_workspace_size); - -WARPRNNT_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_WARPRNNT_WRAP); - -#undef DYNAMIC_LOAD_WARPRNNT_WRAP - -} // namespace dynload -} // namespace phi diff --git a/backends/metax_gpu/kernels/impl/warprnnt_kernel_impl.h b/backends/metax_gpu/kernels/impl/warprnnt_kernel_impl.h index bb4311f5912..457fdcb9bff 100644 --- a/backends/metax_gpu/kernels/impl/warprnnt_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/warprnnt_kernel_impl.h @@ -16,7 +16,7 @@ #include -#include "kernels/impl/warprnnt.h" +#include "third_party/warprnnt/include/rnnt.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" @@ -55,7 +55,7 @@ class ComputeRnntLossFunctor { float* costs, void* workspace, rnntOptions options) { - return phi::dynload::compute_rnnt_loss(activations, + return compute_rnnt_loss(activations, gradients, label, label_lengths, @@ -81,7 +81,7 @@ class ComputeRnntLossFunctor { double* costs, void* workspace, rnntOptions options) { - return phi::dynload::compute_rnnt_loss_fp64(activations, + return compute_rnnt_loss_fp64(activations, gradients, label, label_lengths, @@ -149,7 +149,7 @@ class WarpRNNTFunctor { } size_t workspace_bytes = 0; - status = phi::dynload::get_rnnt_workspace_size( + status = get_rnnt_workspace_size( maxT, maxU, B, gpu, &workspace_bytes, sizeof(T)); PADDLE_ENFORCE_EQ( @@ -158,7 +158,7 @@ class 
WarpRNNTFunctor { errors::PreconditionNotMet( "warp-rnnt [version %d] Error in get_rnnt_workspace_size: %s", warprnnt_version_, - phi::dynload::rnntGetStatusString(status))); + rnntGetStatusString(status))); PADDLE_ENFORCE_GT( workspace_bytes, 0UL, @@ -190,7 +190,7 @@ class WarpRNNTFunctor { errors::PreconditionNotMet( "warp-rnnt [version %d] Error in get_workspace_size: %s", warprnnt_version_, - phi::dynload::rnntGetStatusString(status))); + rnntGetStatusString(status))); } protected: @@ -200,7 +200,7 @@ class WarpRNNTFunctor { const size_t blank, const float fastemit_lambda, const int num_threads) { - warprnnt_version_ = phi::dynload::get_warprnnt_version(); + warprnnt_version_ = get_warprnnt_version(); options_.maxT = maxT; options_.maxU = maxU; diff --git a/backends/metax_gpu/kernels/metax_context.cc b/backends/metax_gpu/kernels/metax_context.cc index 4df4d88b0b4..f0c92f00565 100644 --- a/backends/metax_gpu/kernels/metax_context.cc +++ b/backends/metax_gpu/kernels/metax_context.cc @@ -15,7 +15,25 @@ #include "kernels/metax_context.h" namespace phi { -bool AllowTF32Cudnn() { return false; } +const bool allow_tf32_cublas = []() -> bool { + const char* v = std::getenv("ALLOW_TF32_CUBLAS"); + if (v) { + return std::atoi(v); + } + return false; +}(); + +const bool allow_tf32_cudnn = []() -> bool { + const char* v = std::getenv("ALLOW_TF32_CUDNN"); + if (v) { + return std::atoi(v); + } + return false; +}(); + +bool AllowTF32Cublas() { return allow_tf32_cublas; } +bool AllowTF32Cudnn() { return allow_tf32_cudnn; } + void DnnWorkspaceHandle::RunFuncSync( const std::function& cudnn_func, size_t required_workspace_bytes, diff --git a/backends/metax_gpu/kernels/metax_context.h b/backends/metax_gpu/kernels/metax_context.h index 5974aadcc41..683a6df7017 100644 --- a/backends/metax_gpu/kernels/metax_context.h +++ b/backends/metax_gpu/kernels/metax_context.h @@ -128,6 +128,7 @@ inline void InitCusolverDnHandle(cusolverDnHandle_t* handle, } } +bool AllowTF32Cublas(); bool AllowTF32Cudnn(); inline cusolverDnHandle_t GetCusolverDnHandle(gpuStream_t stream, Place place) { std::call_once(flag_cusolver_dn_, [&]() { From 1abea54b5448a01a83115ca1e09fa90132df9e59 Mon Sep 17 00:00:00 2001 From: jiaxinWang-metax <189149612@qq.com> Date: Tue, 16 Sep 2025 15:23:46 +0800 Subject: [PATCH 02/12] modify conv for tf32 and fp32 --- .../cuda_kernels/conv_grad_kernel_register.cu | 1555 ----------------- .../fused_conv2d_add_act_kernel_register.cu | 1 - .../kernels/gpudnn/conv_kernel_register.cu | 2 +- .../kernels/gpudnn/conv_transpose_kernel.cu | 2 +- 4 files changed, 2 insertions(+), 1558 deletions(-) delete mode 100644 backends/metax_gpu/kernels/cuda_kernels/conv_grad_kernel_register.cu diff --git a/backends/metax_gpu/kernels/cuda_kernels/conv_grad_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/conv_grad_kernel_register.cu deleted file mode 100644 index 885137675b4..00000000000 --- a/backends/metax_gpu/kernels/cuda_kernels/conv_grad_kernel_register.cu +++ /dev/null @@ -1,1555 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "glog/logging.h" -#include "kernels/gpudnn/conv_gpudnn.h" -#include "paddle/phi/backends/context_pool.h" -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/conv_grad_kernel.h" -#ifdef PADDLE_WITH_HIP -#include "paddle/phi/kernels/gpudnn/conv_miopen_helper.h" -#else -#include "kernels/gpudnn/conv_cudnn_v7.h" -#endif - -#include "kernels/impl/conv_cudnn_impl.h" -#include "paddle/phi/backends/gpu/cuda/cudnn_workspace_helper.h" -#include "paddle/phi/common/bfloat16.h" -#include "paddle/phi/common/float16.h" -#include "paddle/phi/kernels/cpu/conv_util.h" -#include "paddle/phi/kernels/full_kernel.h" -#include "paddle/phi/kernels/funcs/batch_norm_utils.h" -#include "paddle/phi/kernels/funcs/padding.h" -#ifdef PADDLE_WITH_CUDNN_FRONTEND -// clang-format off -#include "paddle/phi/backends/dynload/cudnn_frontend.h" -#include "paddle/phi/kernels/gpudnn/conv_cudnn_frontend.h" -// clang-format on -#endif - -namespace phi { - -template -void ConvCudnnGradKernelImplV7( - const DenseTensor* transformed_input, - const DenseTensor* transformed_filter_channel, - const DenseTensor* transformed_output_grad_channel, - DenseTensor* input_grad, - DenseTensor* filter_grad, - const Context& dev_ctx, - const std::vector& strides, - const std::vector& padding_common, - const std::vector& dilations, - phi::backends::gpu::DataLayout compute_format, - phi::backends::gpu::DataLayout layout, - bool use_addto, - bool exhaustive_search, - bool deterministic, - int groups, - DenseTensor* transformed_input_grad, - DenseTensor* transformed_filter_grad_channel) { - const T* input_data = transformed_input->data(); - const T* output_grad_data = transformed_output_grad_channel->data(); - const T* filter_data = transformed_filter_channel->data(); - T* filter_grad_data = nullptr; - T* input_grad_data = nullptr; - T* transformed_input_grad_data = nullptr; - - // auto handle = dev_ctx.cudnn_handle(); - auto handle = GetDnnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); - // auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - auto workspace_handle = GetDnnWorkspace( - const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream()); - auto dtype = phi::backends::gpu::CudnnDataType::type; - auto layout_tensor = phi::backends::gpu::GetCudnnTensorFormat(layout); - - ConvArgs args1{handle, - transformed_input_grad, - transformed_filter_channel, - transformed_output_grad_channel, - strides, - padding_common, - dilations, - dtype, - groups, - layout}; - ConvArgs args2{handle, - transformed_input, - transformed_filter_grad_channel, - transformed_output_grad_channel, - strides, - padding_common, - dilations, - dtype, - groups, - layout}; - - int i_n, i_c, i_d, i_h, i_w; - int o_n, o_c, o_d, o_h, o_w; - if (compute_format == phi::backends::gpu::DataLayout::kNHWC) { - GetNCDHW(transformed_input->dims(), - phi::backends::gpu::DataLayout::kNHWC, - &i_n, - &i_c, - &i_d, - &i_h, - &i_w); - GetNCDHW(transformed_output_grad_channel->dims(), - phi::backends::gpu::DataLayout::kNHWC, - &o_n, - &o_c, - &o_d, - &o_h, - &o_w); - } else { - GetNCDHW(transformed_input->dims(), - phi::backends::gpu::DataLayout::kNCHW, - &i_n, - &i_c, - &i_d, - &i_h, - &i_w); - GetNCDHW(transformed_output_grad_channel->dims(), - phi::backends::gpu::DataLayout::kNCHW, - &o_n, - &o_c, - &o_d, - &o_h, - &o_w); - } - - int group_offset_in = i_c / 
groups * i_h * i_w * i_d; - int group_offset_out = o_c / groups * o_h * o_w * o_d; - int group_offset_filter = transformed_filter_channel->numel() / groups; - -// ------------------- cudnn backward algorithm --------------------- -#ifdef PADDLE_WITH_HIP - SearchResult bwd_result; - SearchResult filter_result; -#else - SearchResult bwd_result; - SearchResult filter_result; -#endif - size_t workspace_size = 0; - int iwo_groups = groups; - int c_groups = 1; - -#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 0, 1) - iwo_groups = 1; - c_groups = groups; - groups = 1; -#endif - - if (input_grad) { - // ------------------- cudnn descriptors --------------------- - input_grad_data = input_grad->data(); - transformed_input_grad_data = transformed_input_grad->data(); - - args1.idesc.set(*transformed_input_grad, layout_tensor); - args1.wdesc.set(*transformed_filter_channel, layout_tensor, iwo_groups); - args1.odesc.set(*transformed_output_grad_channel, layout_tensor); - args1.cdesc.set(dtype, padding_common, strides, dilations, true, c_groups); - -#ifdef PADDLE_WITH_HIP - using search1 = SearchAlgorithm; - workspace_size = std::max(workspace_size, search1::GetWorkspaceSize(args1)); - bwd_result.algo = search1::Find( - args1, exhaustive_search, deterministic, workspace_size, dev_ctx); -#else - using search1 = SearchAlgorithm; - bwd_result = - search1::Find(dev_ctx, args1, exhaustive_search, deterministic); - workspace_size = std::max(workspace_size, bwd_result.workspace_size); -#endif - } - - if (filter_grad) { - // ------------------- cudnn descriptors --------------------- - filter_grad_data = transformed_filter_grad_channel->data(); - - args2.idesc.set(*transformed_input, layout_tensor); - args2.wdesc.set( - *transformed_filter_grad_channel, layout_tensor, iwo_groups); - args2.odesc.set(*transformed_output_grad_channel, layout_tensor); - args2.cdesc.set(dtype, padding_common, strides, dilations, true, c_groups); -#ifdef PADDLE_WITH_HIP - using search2 = SearchAlgorithm; - workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2)); - filter_result.algo = search2::Find( - args2, exhaustive_search, deterministic, workspace_size, dev_ctx); -#else - using search2 = SearchAlgorithm; - filter_result = - search2::Find(dev_ctx, args2, exhaustive_search, deterministic); - VLOG(3) << "filter algo: " << filter_result.algo << ", time " - << filter_result.time; - workspace_size = std::max(workspace_size, filter_result.workspace_size); -#endif - } - - // ------------------- cudnn conv backward data --------------------- - ScalingParamType alpha = 1.0f; -#ifdef PADDLE_WITH_HIP - // MIOPEN ONLY support beta to be 0.0f - ScalingParamType beta = 0.0f; -#else - ScalingParamType beta = use_addto ? 1.0f : 0.0f; - -#endif - VLOG(4) << "Conv_grad: use_addto = " << use_addto; - - if (input_grad) { -// When beta is 0, it is unnecessary to reset input_grad. -// When beta is 1, the output cannot be reset since addt strategy used. 
-#ifdef PADDLE_WITH_HIP - if (use_addto) { - DenseTensor temp_tensor(transformed_input_grad->type()); - temp_tensor.Resize(transformed_input_grad->dims()); - T* temp_tensor_data = dev_ctx.template Alloc(&temp_tensor); - workspace_handle.RunFunc( - [&](void* cudnn_workspace_ptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::miopenConvolutionBackwardData(handle, - &alpha, - args1.odesc.desc(), - output_grad_data, - args1.wdesc.desc(), - filter_data, - args1.cdesc.desc(), - bwd_result.algo, - &beta, - args1.idesc.desc(), - temp_tensor_data, - cudnn_workspace_ptr, - workspace_size)); - }, - workspace_size); - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::miopenOpTensor(handle, - miopenTensorOpAdd, - &alpha, - args1.idesc.desc(), - transformed_input_grad_data, - &alpha, - args1.idesc.desc(), - temp_tensor_data, - &beta, - args1.idesc.desc(), - transformed_input_grad_data)); - } else { - workspace_handle.RunFunc( - [&](void* cudnn_workspace_ptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::miopenConvolutionBackwardData( - handle, - &alpha, - args1.odesc.desc(), - output_grad_data, - args1.wdesc.desc(), - filter_data, - args1.cdesc.desc(), - bwd_result.algo, - &beta, - args1.idesc.desc(), - transformed_input_grad_data, - cudnn_workspace_ptr, - workspace_size)); - }, - workspace_size); - } -#else - ConvRunner::Apply(dev_ctx, - args1, - bwd_result, - output_grad_data, - filter_data, - transformed_input_grad_data, - groups, - group_offset_in, - group_offset_filter, - group_offset_out, - workspace_size, - &workspace_handle, - use_addto); -#endif - } - - // ------------------- cudnn conv backward filter --------------------- - if (filter_grad) { -// Because beta is zero, it is unnecessary to reset filter_grad. -#ifdef PADDLE_WITH_HIP - workspace_handle.RunFunc( - [&](void* cudnn_workspace_ptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::miopenConvolutionBackwardWeights( - handle, - &alpha, - args2.odesc.desc(), - output_grad_data, - args2.idesc.desc(), - input_data, - args2.cdesc.desc(), - filter_result.algo, - &beta, - args2.wdesc.desc(), - filter_grad_data, - cudnn_workspace_ptr, - workspace_size)); - }, - workspace_size); -#else - ConvRunner::Apply(dev_ctx, - args2, - filter_result, - output_grad_data, - input_data, - filter_grad_data, - groups, - group_offset_in, - group_offset_filter, - group_offset_out, - workspace_size, - &workspace_handle, - false); -#endif - } -} - -#ifdef PADDLE_WITH_CUDNN_FRONTEND -template -void ConvCudnnGradKernelImplV8( - const DenseTensor* transformed_input, - const DenseTensor* transformed_filter_channel, - const DenseTensor* transformed_output_grad_channel, - DenseTensor* input_grad, - DenseTensor* filter_grad, - const Context& dev_ctx, - const std::vector& strides, - const std::vector& padding_common, - const std::vector& dilations, - phi::backends::gpu::DataLayout layout, - bool use_addto, - bool exhaustive_search, - bool deterministic, - int groups, - DenseTensor* transformed_input_grad, - DenseTensor* transformed_filter_grad_channel) { - PADDLE_ENFORCE_EQ( - groups, - 1, - common::errors::Unimplemented( - "Group concolution using CUDNNv8 API is unsupported for now")); - - cudnnHandle_t handle = const_cast( - GetDnnHandle(dev_ctx.stream(), dev_ctx.GetPlace());); - // auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - auto workspace_handle = GetDnnWorkspace( - const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream()); - auto dtype = phi::backends::gpu::CudnnDataType::type; - auto layout_format = phi::backends::gpu::GetCudnnTensorFormat(layout); - - if 
(input_grad) { - CudnnConvBwdDataV8(transformed_output_grad_channel, - transformed_filter_channel, - handle, - &workspace_handle, - strides, - padding_common, - dilations, - dtype, - layout_format, - use_addto, - exhaustive_search, - deterministic, - transformed_input_grad); - } - - if (filter_grad) { - CudnnConvBwdFilterV8(transformed_input, - transformed_output_grad_channel, - handle, - &workspace_handle, - strides, - padding_common, - dilations, - dtype, - layout_format, - use_addto, - exhaustive_search, - deterministic, - transformed_filter_grad_channel); - } -} -#endif - -template -void ConvCudnnGradKernel(const Context& dev_ctx, - const DenseTensor& input, - const DenseTensor& filter, - const DenseTensor& output_grad, - const std::vector& strides_t, - const std::vector& paddings_t, - const std::string& padding_algorithm, - const std::vector& dilations_t, - int groups, - const std::string& data_format, - DenseTensor* input_grad, - DenseTensor* filter_grad) { - // 0-size - if (input.numel() == 0 || filter.numel() == 0) { - if (input_grad) dev_ctx.template Alloc(input_grad); - if (filter_grad) { - phi::Full( - dev_ctx, - phi::IntArray(common::vectorize(filter_grad->dims())), - 0, - filter_grad); - } - return; - } - if (input_grad) { - dev_ctx.template Alloc(input_grad); - } - if (filter_grad) { - dev_ctx.template Alloc(filter_grad); - } - - // bool has_use_addto = dev_ctx.HasDnnAttr("use_addto"); - bool has_use_addto = "true"; - VLOG(4) << "GPUContext contains `use_addto`: " << has_use_addto; - // bool use_addto = has_use_addto - // ? PADDLE_GET_CONST(bool, "true") - // : false; - bool use_addto = "true"; - std::vector dilations = dilations_t; - std::vector strides = strides_t; - std::vector paddings = paddings_t; - - // bool has_exhaustive_search = dev_ctx.HasDnnAttr("exhaustive_search"); - bool has_exhaustive_search = "true"; - VLOG(4) << "GPUContext contains `exhaustive_search`: " - << has_exhaustive_search; - // bool exhaustive_search_attr = - // has_exhaustive_search - // ? PADDLE_GET_CONST(bool, "true") - // : false; - bool exhaustive_search_attr = "true"; - bool exhaustive_search = - FLAGS_cudnn_exhaustive_search || exhaustive_search_attr; - bool deterministic = FLAGS_cudnn_deterministic; - auto exhaustive_deterministic = exhaustive_search && deterministic; - PADDLE_ENFORCE_EQ(exhaustive_deterministic, - false, - common::errors::InvalidArgument( - "Can't set exhaustive_search True and " - "FLAGS_cudnn_deterministic True at same time.")); - - const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); - - auto dtype = phi::backends::gpu::CudnnDataType::type; - -#ifdef PADDLE_WITH_HIP - // HIP MIOPEN ONLY SUPPORT NCHW format - auto compute_format = phi::backends::gpu::DataLayout::kNCHW; -#else -#if CUDNN_VERSION_MIN(8, 1, 0) - const bool compute_in_nhwc = - (dtype == CUDNN_DATA_HALF || dtype == CUDNN_DATA_BFLOAT16) && - IsVoltaOrLater(dev_ctx); -#else - const bool compute_in_nhwc = - dtype == CUDNN_DATA_HALF && IsVoltaOrLater(dev_ctx); -#endif - auto compute_format = compute_in_nhwc && channel_last - ? phi::backends::gpu::DataLayout::kNHWC - : phi::backends::gpu::DataLayout::kNCHW; -#endif - VLOG(3) << "Compute ConvGradOp with cuDNN:" - << " data_format=" << data_format << " compute_format=" - << (compute_format == phi::backends::gpu::DataLayout::kNHWC ? 
"NHWC" - : "NCHW"); - - // transform Tensor - DenseTensor transformed_input_channel(input.type()); - DenseTensor transformed_output_grad_channel(output_grad.type()); - DenseTensor transformed_input_grad_channel(input.type()); - DenseTensor transformed_filter_channel(filter.type()); - DenseTensor transformed_filter_grad_channel(filter.type()); - - if (channel_last && compute_format == phi::backends::gpu::DataLayout::kNCHW) { - VLOG(3) << "Transform input, output_grad, input_grad and tensor from " - "NHWC to NCHW."; - ResizeToChannelFirst( - dev_ctx, &input, &transformed_input_channel); - TransToChannelFirst( - dev_ctx, &input, &transformed_input_channel); - - ResizeToChannelFirst( - dev_ctx, &output_grad, &transformed_output_grad_channel); - TransToChannelFirst( - dev_ctx, &output_grad, &transformed_output_grad_channel); - - if (input_grad) { - ResizeToChannelFirst( - dev_ctx, input_grad, &transformed_input_grad_channel); - // NOTE(zhiqiu): If inplace_addto strategy is enabled, we need to copy - // the data of input_grad to transformed_input_grad_channel. - if (use_addto) { - TransToChannelFirst( - dev_ctx, input_grad, &transformed_input_grad_channel); - } - } - } else { - transformed_input_channel.ShareDataWith(input); - transformed_output_grad_channel.ShareDataWith(output_grad); - if (input_grad) { - transformed_input_grad_channel.ShareDataWith(*input_grad); - } - } - - if (compute_format == phi::backends::gpu::DataLayout::kNHWC) { - VLOG(3) << "Transform filter and filter_grad tensor from NCHW to NHWC."; - ResizeToChannelLast( - dev_ctx, &filter, &transformed_filter_channel); - TransToChannelLast( - dev_ctx, &filter, &transformed_filter_channel); - - if (filter_grad) { - ResizeToChannelLast( - dev_ctx, filter_grad, &transformed_filter_grad_channel); - } - } else { - transformed_filter_channel.ShareDataWith(filter); - if (filter_grad) { - transformed_filter_grad_channel.ShareDataWith(*filter_grad); - } - } - - // update paddings - auto in_dims = transformed_input_channel.dims(); - auto filter_dims = transformed_filter_channel.dims(); - DDim in_data_dims; - DDim filter_data_dims; - if (compute_format == phi::backends::gpu::DataLayout::kNCHW) { - in_data_dims = slice_ddim(in_dims, 2, in_dims.size()); - filter_data_dims = slice_ddim(filter_dims, 2, filter_dims.size()); - } else { - in_data_dims = slice_ddim(in_dims, 1, in_dims.size() - 1); - filter_data_dims = slice_ddim(filter_dims, 1, filter_dims.size() - 1); - } - std::vector ksize = common::vectorize(filter_data_dims); - UpdatePaddingAndDilation( - &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize); - - // cuDNN only supports padding the same amount on every dimension. - // So we create a new padded input tensor. 
- int data_dim = strides.size(); // 2d or 3d - bool is_sys_pad = funcs::IsSymmetricPadding(paddings, data_dim); - Tensor transformed_input(input.type()); - Tensor transformed_input_grad(input.type()); - std::vector padding_common(data_dim, 0); - std::vector input_pad(transformed_input_channel.dims().size() * 2, 0); - - if (!is_sys_pad) { - // get pad - std::vector padding_diff(data_dim); - std::vector new_input_shape_vec(data_dim + 2); - new_input_shape_vec[0] = transformed_input_channel.dims()[0]; - if (compute_format == phi::backends::gpu::DataLayout::kNCHW) { - new_input_shape_vec[1] = transformed_input_channel.dims()[1]; - } else { - new_input_shape_vec[data_dim + 1] = - transformed_input_channel.dims()[data_dim + 1]; - } - - for (size_t i = 0; i < data_dim; ++i) { - padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); - padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); - if (compute_format == phi::backends::gpu::DataLayout::kNCHW) { - new_input_shape_vec[i + 2] = - transformed_input_channel.dims()[i + 2] + padding_diff[i]; - } else { - new_input_shape_vec[i + 1] = - transformed_input_channel.dims()[i + 1] + padding_diff[i]; - } - if (compute_format == phi::backends::gpu::DataLayout::kNCHW) { - input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; - input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; - } else { - input_pad[2 * i + 2] = paddings[2 * i] - padding_common[i]; - input_pad[2 * i + 2 + 1] = paddings[2 * i + 1] - padding_common[i]; - } - } - DDim new_input_shape(common::make_ddim(new_input_shape_vec)); - transformed_input.Resize(new_input_shape); - dev_ctx.template Alloc(&transformed_input); - - transformed_input_grad.Resize(new_input_shape); - - if (input_grad) { - dev_ctx.template Alloc(&transformed_input_grad); - } - // pad for input - const int rank = transformed_input_channel.dims().size(); - T pad_value(0.0); - switch (rank) { - case 4: { - funcs::PadFunction(dev_ctx, - input_pad, - transformed_input_channel, - pad_value, - &transformed_input); - } break; - case 5: { - funcs::PadFunction(dev_ctx, - input_pad, - transformed_input_channel, - pad_value, - &transformed_input); - } break; - default: - PADDLE_THROW(common::errors::InvalidArgument( - "ConvOp only support tensors with 4 or 5 dimensions.")); - } - } else { - transformed_input.ShareDataWith(transformed_input_channel); - if (input_grad) { - transformed_input_grad.ShareDataWith(transformed_input_grad_channel); - } - if (paddings.size() == data_dim) { - for (size_t i = 0; i < data_dim; ++i) { - padding_common[i] = paddings[i]; - } - } else { - for (size_t i = 0; i < data_dim; ++i) { - padding_common[i] = paddings[2 * i]; - } - } - } - phi::backends::gpu::DataLayout layout = - compute_format == phi::backends::gpu::DataLayout::kNHWC - ? phi::backends::gpu::DataLayout::kNHWC - : phi::backends::gpu::DataLayout::kNCHW; - if (transformed_input.dims().size() == 5) { - layout = compute_format == phi::backends::gpu::DataLayout::kNHWC - ? 
phi::backends::gpu::DataLayout::kNDHWC - : phi::backends::gpu::DataLayout::kNCDHW; - } - CUDNN_ENFORCE_TENSOR_SIZE_SUPPORTED(transformed_input); - CUDNN_ENFORCE_TENSOR_SIZE_SUPPORTED(transformed_filter_channel); - CUDNN_ENFORCE_TENSOR_SIZE_SUPPORTED(transformed_output_grad_channel); - -#ifdef PADDLE_WITH_CUDNN_FRONTEND - if (dynload::IsCudnnFrontendEnabled() && (groups == 1)) - ConvCudnnGradKernelImplV8(&transformed_input, - &transformed_filter_channel, - &transformed_output_grad_channel, - input_grad, - filter_grad, - dev_ctx, - strides, - padding_common, - dilations, - layout, - use_addto, - exhaustive_search, - deterministic, - groups, - &transformed_input_grad, - &transformed_filter_grad_channel); - else - ConvCudnnGradKernelImplV7(&transformed_input, - &transformed_filter_channel, - &transformed_output_grad_channel, - input_grad, - filter_grad, - dev_ctx, - strides, - padding_common, - dilations, - compute_format, - layout, - use_addto, - exhaustive_search, - deterministic, - groups, - &transformed_input_grad, - &transformed_filter_grad_channel); -#else - ConvCudnnGradKernelImplV7(&transformed_input, - &transformed_filter_channel, - &transformed_output_grad_channel, - input_grad, - filter_grad, - dev_ctx, - strides, - padding_common, - dilations, - compute_format, - layout, - use_addto, - exhaustive_search, - deterministic, - groups, - &transformed_input_grad, - &transformed_filter_grad_channel); -#endif - - if (input_grad) { - if (!is_sys_pad) { - std::vector starts(transformed_input_channel.dims().size(), 0); - std::vector axes(transformed_input_channel.dims().size(), 0); - - for (size_t i = 0; i < transformed_input_channel.dims().size(); ++i) { - starts[i] = input_pad[2 * i]; - axes[i] = i; - } - - dev_ctx.template Alloc(&transformed_input_grad_channel); - if (transformed_input_channel.dims().size() == 4) { - RemovePaddingSlice(dev_ctx, - &transformed_input_grad, - &transformed_input_grad_channel, - starts, - axes); - } else { - RemovePaddingSlice(dev_ctx, - &transformed_input_grad, - &transformed_input_grad_channel, - starts, - axes); - } - } - - if (channel_last && - compute_format == phi::backends::gpu::DataLayout::kNCHW) { - TransToChannelLast( - dev_ctx, &transformed_input_grad_channel, input_grad); - } - } - - if (filter_grad) { - if (compute_format == phi::backends::gpu::DataLayout::kNHWC) { - TransToChannelFirst( - dev_ctx, &transformed_filter_grad_channel, filter_grad); - } - } -} - -template -void Conv3DCudnnGradKernel(const Context& dev_ctx, - const DenseTensor& input, - const DenseTensor& filter, - const DenseTensor& out_grad, - const std::vector& strides, - const std::vector& paddings, - const std::string& padding_algorithm, - int groups, - const std::vector& dilations, - const std::string& data_format, - DenseTensor* input_grad, - DenseTensor* filter_grad) { - ConvCudnnGradKernel(dev_ctx, - input, - filter, - out_grad, - strides, - paddings, - padding_algorithm, - dilations, - groups, - data_format, - input_grad, - filter_grad); -} - -template -void ConvCudnnGradGradKernel( - const Context& dev_ctx, - const DenseTensor& input, - const DenseTensor& filter, - const DenseTensor& out_grad, - const paddle::optional& input_grad_grad, - const paddle::optional& filter_grad_grad, - const std::vector& strides, - const std::vector& paddings_t, - const std::string& padding_algorithm, - const std::vector& dilations_t, - int groups, - const std::string& data_format, - DenseTensor* input_grad, - DenseTensor* filter_grad, - DenseTensor* out_grad_grad) { - auto X = &input; - auto W = 
&filter; - auto dO = &out_grad; - auto ddX = input_grad_grad.get_ptr(); - auto ddW = filter_grad_grad.get_ptr(); - - auto ddO = out_grad_grad; - auto dW = filter_grad; - auto dX = input_grad; - if (ddO) { - dev_ctx.template Alloc(ddO); - phi::funcs::SetConstant set_zero; - set_zero(dev_ctx, ddO, static_cast(0)); - } - if (dW) { - dev_ctx.template Alloc(dW); - } - if (dX) { - dev_ctx.template Alloc(dX); - } - - // const T* x = X->data(); - const T* dy = dO->data(); - const T* w = W->data(); - - const T* ddx = nullptr; - const T* ddw = nullptr; - T *dw, *dx, *ddy; - dw = dx = ddy = nullptr; - T* transformed_dx = nullptr; - std::vector dilations = dilations_t; - - // bool has_exhaustive_search = dev_ctx.HasDnnAttr("exhaustive_search"); - // VLOG(4) << "GPUContext contains `exhaustive_search`: " - // << has_exhaustive_search; - // bool exhaustive_search_attr = - // has_exhaustive_search - // ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("exhaustive_search")) - // : false; - bool exhaustive_search_attr = "true"; - bool exhaustive_search = - FLAGS_cudnn_exhaustive_search || exhaustive_search_attr; - bool deterministic = FLAGS_cudnn_deterministic; - auto exhaustive_deterministic = exhaustive_search && deterministic; - PADDLE_ENFORCE_EQ(exhaustive_deterministic, - false, - common::errors::InvalidArgument( - "Can't set exhaustive_search True and " - "FLAGS_cudnn_deterministic True at same time.")); - - std::vector paddings = paddings_t; - - const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); - - // transform Tensors to channel first----------- - DenseTensor transformed_X_channel(X->type()); - DenseTensor transformed_dO_channel(dO->type()); - DenseTensor transformed_ddX_channel(X->type()); - - DenseTensor transformed_ddO_channel(dO->type()); - DenseTensor transformed_dX_channel(X->type()); - - if (channel_last) { - ResizeToChannelFirst(dev_ctx, X, &transformed_X_channel); - TransToChannelFirst(dev_ctx, X, &transformed_X_channel); - - ResizeToChannelFirst(dev_ctx, dO, &transformed_dO_channel); - TransToChannelFirst(dev_ctx, dO, &transformed_dO_channel); - - if (ddX) { - ResizeToChannelFirst(dev_ctx, ddX, &transformed_ddX_channel); - TransToChannelFirst(dev_ctx, ddX, &transformed_ddX_channel); - } - - if (ddO) { - ResizeToChannelFirst(dev_ctx, ddO, &transformed_ddO_channel); - } - if (dX) { - ResizeToChannelFirst(dev_ctx, dX, &transformed_dX_channel); - dev_ctx.template Alloc(&transformed_dX_channel); - } - - } else { - transformed_X_channel = *X; - transformed_dO_channel = *dO; - if (ddX) { - transformed_ddX_channel = *ddX; - } - if (ddO) { - transformed_ddO_channel.ShareDataWith(*ddO); - } - if (dX) { - transformed_dX_channel.ShareDataWith(*dX); - } - } - - auto in_dims = transformed_X_channel.dims(); - auto filter_dims = W->dims(); - DDim in_data_dims = slice_ddim(in_dims, 2, in_dims.size()); - DDim filter_data_dims = slice_ddim(filter_dims, 2, filter_dims.size()); - std::vector ksize = common::vectorize(filter_data_dims); - UpdatePaddingAndDilation( - &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize); - - int data_dim = strides.size(); // 2d or 3d - bool is_sys_pad = funcs::IsSymmetricPadding(paddings, data_dim); - DenseTensor transformed_X(X->type()); - DenseTensor transformed_ddX(X->type()); - - DenseTensor transformed_dX(X->type()); - - std::vector padding_common(data_dim, 0); - std::vector input_pad(X->dims().size() * 2, 0); - - if (!is_sys_pad) { - // get pad - std::vector padding_diff(data_dim); - std::vector new_input_shape_vec(data_dim + 2); - 
new_input_shape_vec[0] = transformed_X_channel.dims()[0]; - new_input_shape_vec[1] = transformed_X_channel.dims()[1]; - - for (size_t i = 0; i < data_dim; ++i) { - padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); - padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); - new_input_shape_vec[i + 2] = - transformed_X_channel.dims()[i + 2] + padding_diff[i]; - input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; - input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; - } - DDim new_input_shape(common::make_ddim(new_input_shape_vec)); - transformed_X.Resize(new_input_shape); - transformed_ddX.Resize(new_input_shape); - transformed_dX.Resize(new_input_shape); - - dev_ctx.template Alloc(&transformed_X); - - if (ddX) { - dev_ctx.template Alloc(&transformed_ddX); - } - if (dX) { - dev_ctx.template Alloc(&transformed_dX); - } - - // pad for input - const int rank = X->dims().size(); - T pad_value(0.0); - switch (rank) { - case 4: { - funcs::PadFunction(dev_ctx, - input_pad, - transformed_X_channel, - pad_value, - &transformed_X); - if (ddX) { - funcs::PadFunction(dev_ctx, - input_pad, - transformed_ddX_channel, - pad_value, - &transformed_ddX); - } - } break; - case 5: { - funcs::PadFunction(dev_ctx, - input_pad, - transformed_X_channel, - pad_value, - &transformed_X); - if (ddX) { - funcs::PadFunction(dev_ctx, - input_pad, - transformed_ddX_channel, - pad_value, - &transformed_ddX); - } - } break; - default: - PADDLE_THROW(common::errors::InvalidArgument( - "ConvOp only support tensors with 4 or 5 dimensions.")); - } - - } else { - transformed_X.ShareDataWith(transformed_X_channel); - if (ddX) { - transformed_ddX.ShareDataWith(transformed_ddX_channel); - } - if (dX) { - transformed_dX.ShareDataWith(transformed_dX_channel); - } - - if (paddings.size() == data_dim) { - for (size_t i = 0; i < data_dim; ++i) { - padding_common[i] = paddings[i]; - } - } else { - for (size_t i = 0; i < data_dim; ++i) { - padding_common[i] = paddings[2 * i]; - } - } - } - - const T* x = transformed_X.data(); - - int iwo_group = groups; - int c_group = 1; -#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 0, 1) - iwo_group = 1; - c_group = groups; - groups = 1; -#endif - auto dtype = phi::backends::gpu::CudnnDataType::type; - - // auto handle = dev_ctx.cudnn_handle(); - auto handle = GetDnnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); - auto layout = phi::backends::gpu::GetCudnnTensorFormat( - phi::backends::gpu::DataLayout::kNCHW); - - ConvArgs args1{handle, - &transformed_ddX, - W, - &transformed_ddO_channel, - strides, - padding_common, - dilations, - dtype, - groups, - phi::backends::gpu::DataLayout::kNCHW}; - ConvArgs args2{handle, - &transformed_X, - ddW, - &transformed_ddO_channel, - strides, - padding_common, - dilations, - dtype, - groups, - phi::backends::gpu::DataLayout::kNCHW}; - ConvArgs args3{handle, - &transformed_ddX, - dW, - &transformed_dO_channel, - strides, - padding_common, - dilations, - dtype, - groups, - phi::backends::gpu::DataLayout::kNCHW}; - ConvArgs args4{handle, - &transformed_dX, - ddW, - &transformed_dO_channel, - strides, - padding_common, - dilations, - dtype, - groups, - phi::backends::gpu::DataLayout::kNCHW}; - -#ifdef PADDLE_WITH_HIP - SearchResult fwd_result1; - SearchResult fwd_result2; - SearchResult data_result; - SearchResult filter_result; -#else - SearchResult fwd_result1; - SearchResult fwd_result2; - SearchResult data_result; - SearchResult filter_result; -#endif - - // ddo = conv(ddI, W) + conv(I, ddW) - size_t 
workspace_size = 0; - - T* transformed_ddy_channel = nullptr; - if (ddO) { - ddy = ddO->data(); - transformed_ddy_channel = transformed_ddO_channel.data(); - if (ddX) { - args1.idesc.set(transformed_ddX, iwo_group); - args1.wdesc.set(*W, layout, iwo_group); - args1.odesc.set(transformed_ddO_channel, iwo_group); - args1.cdesc.set(dtype, padding_common, strides, dilations, true, c_group); - -#ifdef PADDLE_WITH_HIP - using search1 = SearchAlgorithm; - workspace_size = search1::GetWorkspaceSize(args1); - fwd_result1.algo = search1::Find( - args1, exhaustive_search, false, workspace_size, dev_ctx); -#else - using search1 = SearchAlgorithm; - fwd_result1 = search1::Find(dev_ctx, args1, exhaustive_search, false); - workspace_size = search1::GetWorkspaceSize(args1, fwd_result1.algo); -#endif - } - - if (ddW) { - ddw = ddW->data(); - args2.idesc.set(transformed_X, iwo_group); - args2.wdesc.set(*ddW, layout, iwo_group); - args2.odesc.set(transformed_ddO_channel, iwo_group); - args2.cdesc.set(dtype, padding_common, strides, dilations, true, c_group); - -#ifdef PADDLE_WITH_HIP - using search2 = SearchAlgorithm; - workspace_size = - std::max(workspace_size, search2::GetWorkspaceSize(args2)); - fwd_result2.algo = search2::Find( - args2, exhaustive_search, false, workspace_size, dev_ctx); -#else - using search2 = SearchAlgorithm; - fwd_result2 = search2::Find(dev_ctx, args2, exhaustive_search, false); - workspace_size = std::max( - workspace_size, search2::GetWorkspaceSize(args2, fwd_result2.algo)); -#endif - } - } - - if (dW && ddX) { - dw = dW->data(); - args3.idesc.set(transformed_ddX, iwo_group); - args3.wdesc.set(*dW, layout, iwo_group); - args3.odesc.set(transformed_dO_channel, iwo_group); - args3.cdesc.set(dtype, padding_common, strides, dilations, true, c_group); - -#ifdef PADDLE_WITH_HIP - using search3 = SearchAlgorithm; - workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3)); - filter_result.algo = search3::Find( - args3, exhaustive_search, deterministic, workspace_size, dev_ctx); -#else - using search3 = SearchAlgorithm; - filter_result = - search3::Find(dev_ctx, args3, exhaustive_search, deterministic); - workspace_size = std::max( - workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo)); -#endif - } - - if (ddW && dX) { - transformed_dx = transformed_dX.data(); - - args4.idesc.set(transformed_dX, iwo_group); - args4.wdesc.set(*ddW, layout, iwo_group); - args4.odesc.set(transformed_dO_channel, iwo_group); - args4.cdesc.set(dtype, padding_common, strides, dilations, true, c_group); - -#ifdef PADDLE_WITH_HIP - using search4 = SearchAlgorithm; - workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4)); - data_result.algo = search4::Find( - args4, exhaustive_search, deterministic, workspace_size, dev_ctx); -#else - using search4 = SearchAlgorithm; - data_result = - search4::Find(dev_ctx, args4, exhaustive_search, deterministic); - workspace_size = std::max( - workspace_size, search4::GetWorkspaceSize(args4, data_result.algo)); -#endif - } - - int i_n, i_c, i_d, i_h, i_w; - GetNCDHW( - transformed_X.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w); - - int o_n, o_c, o_d, o_h, o_w; - GetNCDHW(transformed_dO_channel.dims(), - DataLayout::kNCHW, - &o_n, - &o_c, - &o_d, - &o_h, - &o_w); - - int group_offset_in = i_c / groups * i_h * i_w * i_d; - int group_offset_out = o_c / groups * o_h * o_w * o_d; - int group_offset_filter = W->numel() / groups; - - ScalingParamType alpha = 1.0f; - ScalingParamType beta = 0.0f; - - // NOTE(zhiqiu): 
inplace addto is not supported in double grad yet. - // ScalingParamType beta = dev_ctx.Attr("use_addto") ? 1.0f : - // 0.0f; - // VLOG(4) << "Conv_grad_grad: use_addto = " << - // dev_ctx.Attr("use_addto"); - // auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - auto workspace_handle = GetDnnWorkspace( - const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream()); - - if (ddO) { - if (ddX) { - ddx = transformed_ddX.data(); -#ifdef PADDLE_WITH_HIP - workspace_handle.RunFunc( - [&](void* workspace_ptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::miopenConvolutionForward(handle, - &alpha, - args1.idesc.desc(), - ddx, - args1.wdesc.desc(), - w, - args1.cdesc.desc(), - fwd_result1.algo, - &beta, - args1.odesc.desc(), - transformed_ddy_channel, - workspace_ptr, - workspace_size)); - }, - workspace_size); -#else - ConvRunner::Apply(dev_ctx, - args1, - fwd_result1, - ddx, - w, - transformed_ddy_channel, - groups, - group_offset_in, - group_offset_filter, - group_offset_out, - workspace_size, - &workspace_handle, - false); -#endif - } - if (ddW) { -#ifdef PADDLE_WITH_HIP - // MIOPEN ONLY support beta to be 0.0f - workspace_handle.RunFunc( - [&](void* workspace_ptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::miopenConvolutionForward(handle, - &alpha, - args2.idesc.desc(), - x, - args2.wdesc.desc(), - ddw, - args2.cdesc.desc(), - fwd_result2.algo, - &beta, - args2.odesc.desc(), - transformed_ddy_channel, - workspace_ptr, - workspace_size)); - }, - workspace_size); -#else - ConvRunner::Apply(dev_ctx, - args2, - fwd_result2, - x, - ddw, - transformed_ddy_channel, - groups, - group_offset_in, - group_offset_filter, - group_offset_out, - workspace_size, - &workspace_handle, - true); -#endif - } - if (channel_last) { - TransToChannelLast(dev_ctx, &transformed_ddO_channel, ddO); - } - } - T* transformed_dy_channel = transformed_dO_channel.data(); - if (dW && ddX) { - ddx = transformed_ddX.data(); -#ifdef PADDLE_WITH_HIP - workspace_handle.RunFunc( - [&](void* workspace_ptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::miopenConvolutionBackwardWeights( - handle, - &alpha, - args3.odesc.desc(), - transformed_dy_channel, - args3.idesc.desc(), - ddx, - args3.cdesc.desc(), - filter_result.algo, - &beta, - args3.wdesc.desc(), - dw, - workspace_ptr, - workspace_size)); - }, - workspace_size); -#else - ConvRunner::Apply(dev_ctx, - args3, - filter_result, - transformed_dy_channel, - ddx, - dw, - groups, - group_offset_in, - group_offset_filter, - group_offset_out, - workspace_size, - &workspace_handle, - false); -#endif - } - - if (dX && ddW) { - ddw = ddW->data(); -#ifdef PADDLE_WITH_HIP - workspace_handle.RunFunc( - [&](void* workspace_ptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::miopenConvolutionBackwardData( - handle, - &alpha, - args4.odesc.desc(), - transformed_dy_channel, - args4.wdesc.desc(), - ddw, - args4.cdesc.desc(), - data_result.algo, - &beta, - args4.idesc.desc(), - transformed_dx, - workspace_ptr, - workspace_size)); - }, - workspace_size); -#else - ConvRunner::Apply(dev_ctx, - args4, - data_result, - transformed_dy_channel, - ddw, - transformed_dx, - groups, - group_offset_in, - group_offset_filter, - group_offset_out, - workspace_size, - &workspace_handle, - false); -#endif - - if (!is_sys_pad) { - // reverse padded input - std::vector starts(X->dims().size(), 0); - std::vector axes(X->dims().size(), 0); - - for (size_t i = 0; i < X->dims().size(); ++i) { - starts[i] = input_pad[2 * i]; - axes[i] = i; - } - if (X->dims().size() == 4) { - RemovePaddingSlice( - dev_ctx, 
&transformed_dX, &transformed_dX_channel, starts, axes); - } else { - RemovePaddingSlice( - dev_ctx, &transformed_dX, &transformed_dX_channel, starts, axes); - } - } - if (channel_last) { - TransToChannelLast(dev_ctx, &transformed_dX_channel, dX); - } - } -} - -template -void DepthwiseConvDoubleGradGPUDNNKernel( - const Context& dev_ctx, - const DenseTensor& input, - const DenseTensor& filter, - const DenseTensor& out_grad, - const paddle::optional& input_grad_grad, - const paddle::optional& filter_grad_grad, - const std::vector& strides, - const std::vector& paddings_t, - const std::string& padding_algorithm, - int groups, - const std::vector& dilations_t, - const std::string& data_format, - DenseTensor* input_grad, - DenseTensor* filter_grad, - DenseTensor* out_grad_grad) { - ConvCudnnGradGradKernel(dev_ctx, - input, - filter, - out_grad, - input_grad_grad, - filter_grad_grad, - strides, - paddings_t, - padding_algorithm, - dilations_t, - groups, - data_format, - input_grad, - filter_grad, - out_grad_grad); -} - -template -void Conv3DCudnnDoubleGradKernel( - const Context& dev_ctx, - const DenseTensor& input, - const DenseTensor& filter, - const DenseTensor& out_grad, - const paddle::optional& input_grad_grad, - const paddle::optional& filter_grad_grad, - const std::vector& strides, - const std::vector& paddings_t, - const std::string& padding_algorithm, - int groups, - const std::vector& dilations_t, - const std::string& data_format, - DenseTensor* input_grad, - DenseTensor* filter_grad, - DenseTensor* out_grad_grad) { - ConvCudnnGradGradKernel(dev_ctx, - input, - filter, - out_grad, - input_grad_grad, - filter_grad_grad, - strides, - paddings_t, - padding_algorithm, - dilations_t, - groups, - data_format, - input_grad, - filter_grad, - out_grad_grad); -} - -} // namespace phi - -#ifdef PADDLE_WITH_HIP -PD_REGISTER_PLUGIN_KERNEL(conv2d_grad, - metax_gpu, - ALL_LAYOUT, - phi::ConvCudnnGradKernel, - float, - phi::dtype::float16) {} - -PD_REGISTER_PLUGIN_KERNEL(conv3d_grad, - metax_gpu, - ALL_LAYOUT, - phi::Conv3DCudnnGradKernel, - float, - phi::dtype::float16) {} -PD_REGISTER_PLUGIN_KERNEL(conv2d_double_grad, - metax_gpu, - ALL_LAYOUT, - phi::ConvCudnnGradGradKernel, - float, - phi::dtype::float16) {} - -PD_REGISTER_PLUGIN_KERNEL(conv3d_double_grad, - metax_gpu, - ALL_LAYOUT, - phi::Conv3DCudnnDoubleGradKernel, - float, - phi::dtype::float16) {} - -PD_REGISTER_PLUGIN_KERNEL(depthwise_conv2d_double_grad, - GPU, - ALL_LAYOUT, - phi::DepthwiseConvDoubleGradGPUDNNKernel, - float, - phi::dtype::float16) {} -#else -#if CUDNN_VERSION_MIN(8, 1, 0) -PD_REGISTER_PLUGIN_KERNEL(conv2d_grad, - metax_gpu, - ALL_LAYOUT, - phi::ConvCudnnGradKernel, - float, - double, - phi::dtype::float16, - phi::dtype::bfloat16) {} - -PD_REGISTER_PLUGIN_KERNEL(conv3d_grad, - metax_gpu, - ALL_LAYOUT, - phi::Conv3DCudnnGradKernel, - float, - double, - phi::dtype::float16, - phi::dtype::bfloat16) {} -PD_REGISTER_PLUGIN_KERNEL(conv2d_double_grad, - metax_gpu, - ALL_LAYOUT, - phi::ConvCudnnGradGradKernel, - float, - double, - phi::dtype::float16, - phi::dtype::bfloat16) {} - -PD_REGISTER_PLUGIN_KERNEL(conv3d_double_grad, - metax_gpu, - ALL_LAYOUT, - phi::Conv3DCudnnDoubleGradKernel, - float, - double, - phi::dtype::float16, - phi::dtype::bfloat16) {} - -PD_REGISTER_PLUGIN_KERNEL(depthwise_conv2d_double_grad, - metax_gpu, - ALL_LAYOUT, - phi::DepthwiseConvDoubleGradGPUDNNKernel, - float, - double, - phi::dtype::float16, - phi::dtype::bfloat16) {} -#else -PD_REGISTER_PLUGIN_KERNEL(conv2d_grad, - metax_gpu, - ALL_LAYOUT, - 
phi::ConvCudnnGradKernel, - float, - double, - phi::dtype::float16) {} - -PD_REGISTER_PLUGIN_KERNEL(conv3d_grad, - metax_gpu, - ALL_LAYOUT, - phi::Conv3DCudnnGradKernel, - float, - double, - phi::dtype::float16) {} - -PD_REGISTER_PLUGIN_KERNEL(conv2d_double_grad, - metax_gpu, - ALL_LAYOUT, - phi::ConvCudnnGradGradKernel, - float, - double, - phi::dtype::float16) {} - -PD_REGISTER_PLUGIN_KERNEL(conv3d_double_grad, - metax_gpu, - ALL_LAYOUT, - phi::Conv3DCudnnDoubleGradKernel, - float, - double, - phi::dtype::float16) {} - -PD_REGISTER_PLUGIN_KERNEL(depthwise_conv2d_double_grad, - metax_gpu, - ALL_LAYOUT, - phi::DepthwiseConvDoubleGradGPUDNNKernel, - float, - double, - phi::dtype::float16) {} -#endif - -#endif diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_conv2d_add_act_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/fused_conv2d_add_act_kernel_register.cu index 6cf22a1918b..48809ceefa4 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/fused_conv2d_add_act_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/fused_conv2d_add_act_kernel_register.cu @@ -32,7 +32,6 @@ #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/utils/optional.h" -#include "kernels/metax_context.h" namespace phi { namespace fusion { diff --git a/backends/metax_gpu/kernels/gpudnn/conv_kernel_register.cu b/backends/metax_gpu/kernels/gpudnn/conv_kernel_register.cu index bdff5fa9f93..bf129fed05c 100644 --- a/backends/metax_gpu/kernels/gpudnn/conv_kernel_register.cu +++ b/backends/metax_gpu/kernels/gpudnn/conv_kernel_register.cu @@ -81,7 +81,7 @@ void ConvCudnnKernelImplV7(const DenseTensor* transformed_input, args.cdesc.set( dtype, padding_common, strides, dilations, phi::AllowTF32Cudnn(), groups); #else - args.cdesc.set(dtype, padding_common, strides, dilations, true); + args.cdesc.set(dtype, padding_common, strides, dilations, phi::AllowTF32Cudnn()); #endif #if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(7, 0, 1) diff --git a/backends/metax_gpu/kernels/gpudnn/conv_transpose_kernel.cu b/backends/metax_gpu/kernels/gpudnn/conv_transpose_kernel.cu index aa1cc80d06d..928201c705f 100644 --- a/backends/metax_gpu/kernels/gpudnn/conv_transpose_kernel.cu +++ b/backends/metax_gpu/kernels/gpudnn/conv_transpose_kernel.cu @@ -93,7 +93,7 @@ void ConvTransposeCudnnKernelImplV7(const DenseTensor* transformed_x, args.idesc.set(*transformed_out, iwo_groups); args.wdesc.set(*filter, layout_tensor, iwo_groups); args.odesc.set(*transformed_x, iwo_groups); - args.cdesc.set(dtype, padding_common, strides, dilations_, false, c_groups); + args.cdesc.set(dtype, padding_common, strides, dilations_, phi::AllowTF32Cudnn(), c_groups); #ifdef PADDLE_WITH_HIP SearchResult bwd_result; From f26987f3dae76f5643986f1016066ef5d8b0e891 Mon Sep 17 00:00:00 2001 From: jiaxinWang-metax <189149612@qq.com> Date: Tue, 16 Sep 2025 15:48:11 +0800 Subject: [PATCH 03/12] modify conv kernel --- .../gpudnn/conv_grad_kernel_register.cu | 1585 +++++++++++++++++ 1 file changed, 1585 insertions(+) create mode 100644 backends/metax_gpu/kernels/gpudnn/conv_grad_kernel_register.cu diff --git a/backends/metax_gpu/kernels/gpudnn/conv_grad_kernel_register.cu b/backends/metax_gpu/kernels/gpudnn/conv_grad_kernel_register.cu new file mode 100644 index 00000000000..e4acb2f95b6 --- /dev/null +++ b/backends/metax_gpu/kernels/gpudnn/conv_grad_kernel_register.cu @@ -0,0 +1,1585 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "glog/logging.h" +#include "kernels/gpudnn/conv_gpudnn.h" +#include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/conv_grad_kernel.h" +#ifdef PADDLE_WITH_HIP +#include "paddle/phi/kernels/gpudnn/conv_miopen_helper.h" +#else +#include "kernels/gpudnn/conv_cudnn_v7.h" +#endif + +#include "kernels/impl/conv_cudnn_impl.h" +#include "paddle/phi/backends/gpu/cuda/cudnn_workspace_helper.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/kernels/cpu/conv_util.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/batch_norm_utils.h" +#include "paddle/phi/kernels/funcs/padding.h" +#ifdef PADDLE_WITH_CUDNN_FRONTEND +// clang-format off +#include "paddle/phi/backends/dynload/cudnn_frontend.h" +#include "paddle/phi/kernels/gpudnn/conv_cudnn_frontend.h" +// clang-format on +#endif + +namespace phi { + +template +void ConvCudnnGradKernelImplV7( + const DenseTensor* transformed_input, + const DenseTensor* transformed_filter_channel, + const DenseTensor* transformed_output_grad_channel, + DenseTensor* input_grad, + DenseTensor* filter_grad, + const Context& dev_ctx, + const std::vector& strides, + const std::vector& padding_common, + const std::vector& dilations, + phi::backends::gpu::DataLayout compute_format, + phi::backends::gpu::DataLayout layout, + bool use_addto, + bool exhaustive_search, + bool deterministic, + int groups, + DenseTensor* transformed_input_grad, + DenseTensor* transformed_filter_grad_channel) { + const T* input_data = transformed_input->data(); + const T* output_grad_data = transformed_output_grad_channel->data(); + const T* filter_data = transformed_filter_channel->data(); + T* filter_grad_data = nullptr; + T* input_grad_data = nullptr; + T* transformed_input_grad_data = nullptr; + + // auto handle = dev_ctx.cudnn_handle(); + auto handle = GetDnnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); + // auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + auto workspace_handle = GetDnnWorkspace( + const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream()); + auto dtype = phi::backends::gpu::CudnnDataType::type; + auto layout_tensor = phi::backends::gpu::GetCudnnTensorFormat(layout); + + ConvArgs args1{handle, + transformed_input_grad, + transformed_filter_channel, + transformed_output_grad_channel, + strides, + padding_common, + dilations, + dtype, + groups, + layout}; + ConvArgs args2{handle, + transformed_input, + transformed_filter_grad_channel, + transformed_output_grad_channel, + strides, + padding_common, + dilations, + dtype, + groups, + layout}; + + int i_n, i_c, i_d, i_h, i_w; + int o_n, o_c, o_d, o_h, o_w; + if (compute_format == phi::backends::gpu::DataLayout::kNHWC) { + GetNCDHW(transformed_input->dims(), + phi::backends::gpu::DataLayout::kNHWC, + &i_n, + &i_c, + 
&i_d, + &i_h, + &i_w); + GetNCDHW(transformed_output_grad_channel->dims(), + phi::backends::gpu::DataLayout::kNHWC, + &o_n, + &o_c, + &o_d, + &o_h, + &o_w); + } else { + GetNCDHW(transformed_input->dims(), + phi::backends::gpu::DataLayout::kNCHW, + &i_n, + &i_c, + &i_d, + &i_h, + &i_w); + GetNCDHW(transformed_output_grad_channel->dims(), + phi::backends::gpu::DataLayout::kNCHW, + &o_n, + &o_c, + &o_d, + &o_h, + &o_w); + } + + int group_offset_in = i_c / groups * i_h * i_w * i_d; + int group_offset_out = o_c / groups * o_h * o_w * o_d; + int group_offset_filter = transformed_filter_channel->numel() / groups; + +// ------------------- cudnn backward algorithm --------------------- +#ifdef PADDLE_WITH_HIP + SearchResult bwd_result; + SearchResult filter_result; +#else + SearchResult bwd_result; + SearchResult filter_result; +#endif + size_t workspace_size = 0; + int iwo_groups = groups; + int c_groups = 1; + +#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 0, 1) + iwo_groups = 1; + c_groups = groups; + groups = 1; +#endif + + if (input_grad) { + // ------------------- cudnn descriptors --------------------- + input_grad_data = input_grad->data(); + transformed_input_grad_data = transformed_input_grad->data(); + + args1.idesc.set(*transformed_input_grad, layout_tensor); + args1.wdesc.set(*transformed_filter_channel, layout_tensor, iwo_groups); + args1.odesc.set(*transformed_output_grad_channel, layout_tensor); + args1.cdesc.set(dtype, + padding_common, + strides, + dilations, + phi::AllowTF32Cudnn(), + c_groups); + +#ifdef PADDLE_WITH_HIP + using search1 = SearchAlgorithm; + workspace_size = std::max(workspace_size, search1::GetWorkspaceSize(args1)); + bwd_result.algo = search1::Find( + args1, exhaustive_search, deterministic, workspace_size, dev_ctx); +#else + using search1 = SearchAlgorithm; + bwd_result = + search1::Find(dev_ctx, args1, exhaustive_search, deterministic); + workspace_size = std::max(workspace_size, bwd_result.workspace_size); +#endif + } + + if (filter_grad) { + // ------------------- cudnn descriptors --------------------- + filter_grad_data = transformed_filter_grad_channel->data(); + + args2.idesc.set(*transformed_input, layout_tensor); + args2.wdesc.set( + *transformed_filter_grad_channel, layout_tensor, iwo_groups); + args2.odesc.set(*transformed_output_grad_channel, layout_tensor); + args2.cdesc.set(dtype, + padding_common, + strides, + dilations, + phi::AllowTF32Cudnn(), + c_groups); +#ifdef PADDLE_WITH_HIP + using search2 = SearchAlgorithm; + workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2)); + filter_result.algo = search2::Find( + args2, exhaustive_search, deterministic, workspace_size, dev_ctx); +#else + using search2 = SearchAlgorithm; + filter_result = + search2::Find(dev_ctx, args2, exhaustive_search, deterministic); + VLOG(3) << "filter algo: " << filter_result.algo << ", time " + << filter_result.time; + workspace_size = std::max(workspace_size, filter_result.workspace_size); +#endif + } + + // ------------------- cudnn conv backward data --------------------- + ScalingParamType alpha = 1.0f; +#ifdef PADDLE_WITH_HIP + // MIOPEN ONLY support beta to be 0.0f + ScalingParamType beta = 0.0f; +#else + ScalingParamType beta = use_addto ? 1.0f : 0.0f; + +#endif + VLOG(4) << "Conv_grad: use_addto = " << use_addto; + + if (input_grad) { +// When beta is 0, it is unnecessary to reset input_grad. +// When beta is 1, the output cannot be reset since addt strategy used. 
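+// On the HIP path below, MIOpen only supports beta == 0, so when use_addto is
+// set the backward-data result is first written to a temporary tensor and then
+// accumulated into input_grad via miopenOpTensor.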
+#ifdef PADDLE_WITH_HIP
+    if (use_addto) {
+      DenseTensor temp_tensor(transformed_input_grad->type());
+      temp_tensor.Resize(transformed_input_grad->dims());
+      T* temp_tensor_data = dev_ctx.template Alloc(&temp_tensor);
+      workspace_handle.RunFunc(
+          [&](void* cudnn_workspace_ptr) {
+            PADDLE_ENFORCE_GPU_SUCCESS(
+                phi::dynload::miopenConvolutionBackwardData(handle,
+                                                            &alpha,
+                                                            args1.odesc.desc(),
+                                                            output_grad_data,
+                                                            args1.wdesc.desc(),
+                                                            filter_data,
+                                                            args1.cdesc.desc(),
+                                                            bwd_result.algo,
+                                                            &beta,
+                                                            args1.idesc.desc(),
+                                                            temp_tensor_data,
+                                                            cudnn_workspace_ptr,
+                                                            workspace_size));
+          },
+          workspace_size);
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          phi::dynload::miopenOpTensor(handle,
+                                       miopenTensorOpAdd,
+                                       &alpha,
+                                       args1.idesc.desc(),
+                                       transformed_input_grad_data,
+                                       &alpha,
+                                       args1.idesc.desc(),
+                                       temp_tensor_data,
+                                       &beta,
+                                       args1.idesc.desc(),
+                                       transformed_input_grad_data));
+    } else {
+      workspace_handle.RunFunc(
+          [&](void* cudnn_workspace_ptr) {
+            PADDLE_ENFORCE_GPU_SUCCESS(
+                phi::dynload::miopenConvolutionBackwardData(
+                    handle,
+                    &alpha,
+                    args1.odesc.desc(),
+                    output_grad_data,
+                    args1.wdesc.desc(),
+                    filter_data,
+                    args1.cdesc.desc(),
+                    bwd_result.algo,
+                    &beta,
+                    args1.idesc.desc(),
+                    transformed_input_grad_data,
+                    cudnn_workspace_ptr,
+                    workspace_size));
+          },
+          workspace_size);
+    }
+#else
+    ConvRunner::Apply(dev_ctx,
+                                                    args1,
+                                                    bwd_result,
+                                                    output_grad_data,
+                                                    filter_data,
+                                                    transformed_input_grad_data,
+                                                    groups,
+                                                    group_offset_in,
+                                                    group_offset_filter,
+                                                    group_offset_out,
+                                                    workspace_size,
+                                                    &workspace_handle,
+                                                    use_addto);
+#endif
+  }
+
+  // ------------------- cudnn conv backward filter ---------------------
+  if (filter_grad) {
+// Because beta is zero, it is unnecessary to reset filter_grad.
+#ifdef PADDLE_WITH_HIP
+    workspace_handle.RunFunc(
+        [&](void* cudnn_workspace_ptr) {
+          PADDLE_ENFORCE_GPU_SUCCESS(
+              phi::dynload::miopenConvolutionBackwardWeights(
+                  handle,
+                  &alpha,
+                  args2.odesc.desc(),
+                  output_grad_data,
+                  args2.idesc.desc(),
+                  input_data,
+                  args2.cdesc.desc(),
+                  filter_result.algo,
+                  &beta,
+                  args2.wdesc.desc(),
+                  filter_grad_data,
+                  cudnn_workspace_ptr,
+                  workspace_size));
+        },
+        workspace_size);
+#else
+    ConvRunner::Apply(dev_ctx,
+                                                    args2,
+                                                    filter_result,
+                                                    output_grad_data,
+                                                    input_data,
+                                                    filter_grad_data,
+                                                    groups,
+                                                    group_offset_in,
+                                                    group_offset_filter,
+                                                    group_offset_out,
+                                                    workspace_size,
+                                                    &workspace_handle,
+                                                    false);
+#endif
+  }
+}
+
+#ifdef PADDLE_WITH_CUDNN_FRONTEND
+template 
+void ConvCudnnGradKernelImplV8(
+    const DenseTensor* transformed_input,
+    const DenseTensor* transformed_filter_channel,
+    const DenseTensor* transformed_output_grad_channel,
+    DenseTensor* input_grad,
+    DenseTensor* filter_grad,
+    const Context& dev_ctx,
+    const std::vector& strides,
+    const std::vector& padding_common,
+    const std::vector& dilations,
+    phi::backends::gpu::DataLayout layout,
+    bool use_addto,
+    bool exhaustive_search,
+    bool deterministic,
+    int groups,
+    DenseTensor* transformed_input_grad,
+    DenseTensor* transformed_filter_grad_channel) {
+  PADDLE_ENFORCE_EQ(
+      groups,
+      1,
+      common::errors::Unimplemented(
+          "Group convolution using CUDNNv8 API is unsupported for now"));
+
+  cudnnHandle_t handle = const_cast(
+      GetDnnHandle(dev_ctx.stream(), dev_ctx.GetPlace()));
+  // auto workspace_handle = dev_ctx.cudnn_workspace_handle();
+  auto workspace_handle = GetDnnWorkspace(
+      const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream());
+  auto dtype = phi::backends::gpu::CudnnDataType::type;
+  auto layout_format = phi::backends::gpu::GetCudnnTensorFormat(layout);
+
+  if (input_grad) {
+    CudnnConvBwdDataV8(transformed_output_grad_channel,
+                          transformed_filter_channel,
+                          handle,
+                          &workspace_handle,
+                          strides,
+                          padding_common,
+                          dilations,
+                          dtype,
+                          layout_format,
+                          use_addto,
+                          exhaustive_search,
+                          deterministic,
+                          transformed_input_grad);
+  }
+
+  if (filter_grad) {
+    CudnnConvBwdFilterV8(transformed_input,
+                            transformed_output_grad_channel,
+                            handle,
+                            &workspace_handle,
+                            strides,
+                            padding_common,
+                            dilations,
+                            dtype,
+                            layout_format,
+                            use_addto,
+                            exhaustive_search,
+                            deterministic,
+                            transformed_filter_grad_channel);
+  }
+}
+#endif
+
+template 
+void ConvCudnnGradKernel(const Context& dev_ctx,
+                         const DenseTensor& input,
+                         const DenseTensor& filter,
+                         const DenseTensor& output_grad,
+                         const std::vector& strides_t,
+                         const std::vector& paddings_t,
+                         const std::string& padding_algorithm,
+                         const std::vector& dilations_t,
+                         int groups,
+                         const std::string& data_format,
+                         DenseTensor* input_grad,
+                         DenseTensor* filter_grad) {
+  // 0-size
+  if (input.numel() == 0 || filter.numel() == 0) {
+    if (input_grad) dev_ctx.template Alloc(input_grad);
+    if (filter_grad) {
+      phi::Full(
+          dev_ctx,
+          phi::IntArray(common::vectorize(filter_grad->dims())),
+          0,
+          filter_grad);
+    }
+    return;
+  }
+  if (input_grad) {
+    dev_ctx.template Alloc(input_grad);
+  }
+  if (filter_grad) {
+    dev_ctx.template Alloc(filter_grad);
+  }
+
+  // bool has_use_addto = dev_ctx.HasDnnAttr("use_addto");
+  bool has_use_addto = true;
+  VLOG(4) << "GPUContext contains `use_addto`: " << has_use_addto;
+  // bool use_addto = has_use_addto
+  //                      ? PADDLE_GET_CONST(bool, "true")
+  //                      : false;
+  bool use_addto = true;
+  std::vector dilations = dilations_t;
+  std::vector strides = strides_t;
+  std::vector paddings = paddings_t;
+
+  // bool has_exhaustive_search = dev_ctx.HasDnnAttr("exhaustive_search");
+  bool has_exhaustive_search = true;
+  VLOG(4) << "GPUContext contains `exhaustive_search`: "
+          << has_exhaustive_search;
+  // bool exhaustive_search_attr =
+  //     has_exhaustive_search
+  //         ? PADDLE_GET_CONST(bool, "true")
+  //         : false;
+  bool exhaustive_search_attr = true;
+  bool exhaustive_search =
+      FLAGS_cudnn_exhaustive_search || exhaustive_search_attr;
+  bool deterministic = FLAGS_cudnn_deterministic;
+  auto exhaustive_deterministic = exhaustive_search && deterministic;
+  PADDLE_ENFORCE_EQ(exhaustive_deterministic,
+                    false,
+                    common::errors::InvalidArgument(
+                        "Can't set exhaustive_search True and "
+                        "FLAGS_cudnn_deterministic True at same time."));
+
+  const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
+
+  auto dtype = phi::backends::gpu::CudnnDataType::type;
+
+#ifdef PADDLE_WITH_HIP
+  // HIP MIOPEN ONLY SUPPORT NCHW format
+  auto compute_format = phi::backends::gpu::DataLayout::kNCHW;
+#else
+#if CUDNN_VERSION_MIN(8, 1, 0)
+  const bool compute_in_nhwc =
+      (dtype == CUDNN_DATA_HALF || dtype == CUDNN_DATA_BFLOAT16) &&
+      IsVoltaOrLater(dev_ctx);
+#else
+  const bool compute_in_nhwc =
+      dtype == CUDNN_DATA_HALF && IsVoltaOrLater(dev_ctx);
+#endif
+  auto compute_format = compute_in_nhwc && channel_last
+                            ? phi::backends::gpu::DataLayout::kNHWC
+                            : phi::backends::gpu::DataLayout::kNCHW;
+#endif
+  VLOG(3) << "Compute ConvGradOp with cuDNN:"
+          << " data_format=" << data_format << " compute_format="
+          << (compute_format == phi::backends::gpu::DataLayout::kNHWC ? 
"NHWC" + : "NCHW"); + + // transform Tensor + DenseTensor transformed_input_channel(input.type()); + DenseTensor transformed_output_grad_channel(output_grad.type()); + DenseTensor transformed_input_grad_channel(input.type()); + DenseTensor transformed_filter_channel(filter.type()); + DenseTensor transformed_filter_grad_channel(filter.type()); + + if (channel_last && compute_format == phi::backends::gpu::DataLayout::kNCHW) { + VLOG(3) << "Transform input, output_grad, input_grad and tensor from " + "NHWC to NCHW."; + ResizeToChannelFirst( + dev_ctx, &input, &transformed_input_channel); + TransToChannelFirst( + dev_ctx, &input, &transformed_input_channel); + + ResizeToChannelFirst( + dev_ctx, &output_grad, &transformed_output_grad_channel); + TransToChannelFirst( + dev_ctx, &output_grad, &transformed_output_grad_channel); + + if (input_grad) { + ResizeToChannelFirst( + dev_ctx, input_grad, &transformed_input_grad_channel); + // NOTE(zhiqiu): If inplace_addto strategy is enabled, we need to copy + // the data of input_grad to transformed_input_grad_channel. + if (use_addto) { + TransToChannelFirst( + dev_ctx, input_grad, &transformed_input_grad_channel); + } + } + } else { + transformed_input_channel.ShareDataWith(input); + transformed_output_grad_channel.ShareDataWith(output_grad); + if (input_grad) { + transformed_input_grad_channel.ShareDataWith(*input_grad); + } + } + + if (compute_format == phi::backends::gpu::DataLayout::kNHWC) { + VLOG(3) << "Transform filter and filter_grad tensor from NCHW to NHWC."; + ResizeToChannelLast( + dev_ctx, &filter, &transformed_filter_channel); + TransToChannelLast( + dev_ctx, &filter, &transformed_filter_channel); + + if (filter_grad) { + ResizeToChannelLast( + dev_ctx, filter_grad, &transformed_filter_grad_channel); + } + } else { + transformed_filter_channel.ShareDataWith(filter); + if (filter_grad) { + transformed_filter_grad_channel.ShareDataWith(*filter_grad); + } + } + + // update paddings + auto in_dims = transformed_input_channel.dims(); + auto filter_dims = transformed_filter_channel.dims(); + DDim in_data_dims; + DDim filter_data_dims; + if (compute_format == phi::backends::gpu::DataLayout::kNCHW) { + in_data_dims = slice_ddim(in_dims, 2, in_dims.size()); + filter_data_dims = slice_ddim(filter_dims, 2, filter_dims.size()); + } else { + in_data_dims = slice_ddim(in_dims, 1, in_dims.size() - 1); + filter_data_dims = slice_ddim(filter_dims, 1, filter_dims.size() - 1); + } + std::vector ksize = common::vectorize(filter_data_dims); + UpdatePaddingAndDilation( + &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize); + + // cuDNN only supports padding the same amount on every dimension. + // So we create a new padded input tensor. 
+ int data_dim = strides.size(); // 2d or 3d + bool is_sys_pad = funcs::IsSymmetricPadding(paddings, data_dim); + Tensor transformed_input(input.type()); + Tensor transformed_input_grad(input.type()); + std::vector padding_common(data_dim, 0); + std::vector input_pad(transformed_input_channel.dims().size() * 2, 0); + + if (!is_sys_pad) { + // get pad + std::vector padding_diff(data_dim); + std::vector new_input_shape_vec(data_dim + 2); + new_input_shape_vec[0] = transformed_input_channel.dims()[0]; + if (compute_format == phi::backends::gpu::DataLayout::kNCHW) { + new_input_shape_vec[1] = transformed_input_channel.dims()[1]; + } else { + new_input_shape_vec[data_dim + 1] = + transformed_input_channel.dims()[data_dim + 1]; + } + + for (size_t i = 0; i < data_dim; ++i) { + padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); + padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); + if (compute_format == phi::backends::gpu::DataLayout::kNCHW) { + new_input_shape_vec[i + 2] = + transformed_input_channel.dims()[i + 2] + padding_diff[i]; + } else { + new_input_shape_vec[i + 1] = + transformed_input_channel.dims()[i + 1] + padding_diff[i]; + } + if (compute_format == phi::backends::gpu::DataLayout::kNCHW) { + input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; + input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; + } else { + input_pad[2 * i + 2] = paddings[2 * i] - padding_common[i]; + input_pad[2 * i + 2 + 1] = paddings[2 * i + 1] - padding_common[i]; + } + } + DDim new_input_shape(common::make_ddim(new_input_shape_vec)); + transformed_input.Resize(new_input_shape); + dev_ctx.template Alloc(&transformed_input); + + transformed_input_grad.Resize(new_input_shape); + + if (input_grad) { + dev_ctx.template Alloc(&transformed_input_grad); + } + // pad for input + const int rank = transformed_input_channel.dims().size(); + T pad_value(0.0); + switch (rank) { + case 4: { + funcs::PadFunction(dev_ctx, + input_pad, + transformed_input_channel, + pad_value, + &transformed_input); + } break; + case 5: { + funcs::PadFunction(dev_ctx, + input_pad, + transformed_input_channel, + pad_value, + &transformed_input); + } break; + default: + PADDLE_THROW(common::errors::InvalidArgument( + "ConvOp only support tensors with 4 or 5 dimensions.")); + } + } else { + transformed_input.ShareDataWith(transformed_input_channel); + if (input_grad) { + transformed_input_grad.ShareDataWith(transformed_input_grad_channel); + } + if (paddings.size() == data_dim) { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[i]; + } + } else { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[2 * i]; + } + } + } + phi::backends::gpu::DataLayout layout = + compute_format == phi::backends::gpu::DataLayout::kNHWC + ? phi::backends::gpu::DataLayout::kNHWC + : phi::backends::gpu::DataLayout::kNCHW; + if (transformed_input.dims().size() == 5) { + layout = compute_format == phi::backends::gpu::DataLayout::kNHWC + ? 
phi::backends::gpu::DataLayout::kNDHWC + : phi::backends::gpu::DataLayout::kNCDHW; + } + CUDNN_ENFORCE_TENSOR_SIZE_SUPPORTED(transformed_input); + CUDNN_ENFORCE_TENSOR_SIZE_SUPPORTED(transformed_filter_channel); + CUDNN_ENFORCE_TENSOR_SIZE_SUPPORTED(transformed_output_grad_channel); + +#ifdef PADDLE_WITH_CUDNN_FRONTEND + if (dynload::IsCudnnFrontendEnabled() && (groups == 1)) + ConvCudnnGradKernelImplV8(&transformed_input, + &transformed_filter_channel, + &transformed_output_grad_channel, + input_grad, + filter_grad, + dev_ctx, + strides, + padding_common, + dilations, + layout, + use_addto, + exhaustive_search, + deterministic, + groups, + &transformed_input_grad, + &transformed_filter_grad_channel); + else + ConvCudnnGradKernelImplV7(&transformed_input, + &transformed_filter_channel, + &transformed_output_grad_channel, + input_grad, + filter_grad, + dev_ctx, + strides, + padding_common, + dilations, + compute_format, + layout, + use_addto, + exhaustive_search, + deterministic, + groups, + &transformed_input_grad, + &transformed_filter_grad_channel); +#else + ConvCudnnGradKernelImplV7(&transformed_input, + &transformed_filter_channel, + &transformed_output_grad_channel, + input_grad, + filter_grad, + dev_ctx, + strides, + padding_common, + dilations, + compute_format, + layout, + use_addto, + exhaustive_search, + deterministic, + groups, + &transformed_input_grad, + &transformed_filter_grad_channel); +#endif + + if (input_grad) { + if (!is_sys_pad) { + std::vector starts(transformed_input_channel.dims().size(), 0); + std::vector axes(transformed_input_channel.dims().size(), 0); + + for (size_t i = 0; i < transformed_input_channel.dims().size(); ++i) { + starts[i] = input_pad[2 * i]; + axes[i] = i; + } + + dev_ctx.template Alloc(&transformed_input_grad_channel); + if (transformed_input_channel.dims().size() == 4) { + RemovePaddingSlice(dev_ctx, + &transformed_input_grad, + &transformed_input_grad_channel, + starts, + axes); + } else { + RemovePaddingSlice(dev_ctx, + &transformed_input_grad, + &transformed_input_grad_channel, + starts, + axes); + } + } + + if (channel_last && + compute_format == phi::backends::gpu::DataLayout::kNCHW) { + TransToChannelLast( + dev_ctx, &transformed_input_grad_channel, input_grad); + } + } + + if (filter_grad) { + if (compute_format == phi::backends::gpu::DataLayout::kNHWC) { + TransToChannelFirst( + dev_ctx, &transformed_filter_grad_channel, filter_grad); + } + } +} + +template +void Conv3DCudnnGradKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& filter, + const DenseTensor& out_grad, + const std::vector& strides, + const std::vector& paddings, + const std::string& padding_algorithm, + int groups, + const std::vector& dilations, + const std::string& data_format, + DenseTensor* input_grad, + DenseTensor* filter_grad) { + ConvCudnnGradKernel(dev_ctx, + input, + filter, + out_grad, + strides, + paddings, + padding_algorithm, + dilations, + groups, + data_format, + input_grad, + filter_grad); +} + +template +void ConvCudnnGradGradKernel( + const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& filter, + const DenseTensor& out_grad, + const paddle::optional& input_grad_grad, + const paddle::optional& filter_grad_grad, + const std::vector& strides, + const std::vector& paddings_t, + const std::string& padding_algorithm, + const std::vector& dilations_t, + int groups, + const std::string& data_format, + DenseTensor* input_grad, + DenseTensor* filter_grad, + DenseTensor* out_grad_grad) { + auto X = &input; + auto W = 
&filter;
+  auto dO = &out_grad;
+  auto ddX = input_grad_grad.get_ptr();
+  auto ddW = filter_grad_grad.get_ptr();
+
+  auto ddO = out_grad_grad;
+  auto dW = filter_grad;
+  auto dX = input_grad;
+  if (ddO) {
+    dev_ctx.template Alloc(ddO);
+    phi::funcs::SetConstant set_zero;
+    set_zero(dev_ctx, ddO, static_cast(0));
+  }
+  if (dW) {
+    dev_ctx.template Alloc(dW);
+  }
+  if (dX) {
+    dev_ctx.template Alloc(dX);
+  }
+
+  // const T* x = X->data();
+  const T* dy = dO->data();
+  const T* w = W->data();
+
+  const T* ddx = nullptr;
+  const T* ddw = nullptr;
+  T *dw, *dx, *ddy;
+  dw = dx = ddy = nullptr;
+  T* transformed_dx = nullptr;
+  std::vector dilations = dilations_t;
+
+  // bool has_exhaustive_search = dev_ctx.HasDnnAttr("exhaustive_search");
+  // VLOG(4) << "GPUContext contains `exhaustive_search`: "
+  //         << has_exhaustive_search;
+  // bool exhaustive_search_attr =
+  //     has_exhaustive_search
+  //         ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("exhaustive_search"))
+  //         : false;
+  bool exhaustive_search_attr = true;
+  bool exhaustive_search =
+      FLAGS_cudnn_exhaustive_search || exhaustive_search_attr;
+  bool deterministic = FLAGS_cudnn_deterministic;
+  auto exhaustive_deterministic = exhaustive_search && deterministic;
+  PADDLE_ENFORCE_EQ(exhaustive_deterministic,
+                    false,
+                    common::errors::InvalidArgument(
+                        "Can't set exhaustive_search True and "
+                        "FLAGS_cudnn_deterministic True at same time."));
+
+  std::vector paddings = paddings_t;
+
+  const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
+
+  // transform Tensors to channel first-----------
+  DenseTensor transformed_X_channel(X->type());
+  DenseTensor transformed_dO_channel(dO->type());
+  DenseTensor transformed_ddX_channel(X->type());
+
+  DenseTensor transformed_ddO_channel(dO->type());
+  DenseTensor transformed_dX_channel(X->type());
+
+  if (channel_last) {
+    ResizeToChannelFirst(dev_ctx, X, &transformed_X_channel);
+    TransToChannelFirst(dev_ctx, X, &transformed_X_channel);
+
+    ResizeToChannelFirst(dev_ctx, dO, &transformed_dO_channel);
+    TransToChannelFirst(dev_ctx, dO, &transformed_dO_channel);
+
+    if (ddX) {
+      ResizeToChannelFirst(dev_ctx, ddX, &transformed_ddX_channel);
+      TransToChannelFirst(dev_ctx, ddX, &transformed_ddX_channel);
+    }
+
+    if (ddO) {
+      ResizeToChannelFirst(dev_ctx, ddO, &transformed_ddO_channel);
+    }
+    if (dX) {
+      ResizeToChannelFirst(dev_ctx, dX, &transformed_dX_channel);
+      dev_ctx.template Alloc(&transformed_dX_channel);
+    }
+
+  } else {
+    transformed_X_channel = *X;
+    transformed_dO_channel = *dO;
+    if (ddX) {
+      transformed_ddX_channel = *ddX;
+    }
+    if (ddO) {
+      transformed_ddO_channel.ShareDataWith(*ddO);
+    }
+    if (dX) {
+      transformed_dX_channel.ShareDataWith(*dX);
+    }
+  }
+
+  auto in_dims = transformed_X_channel.dims();
+  auto filter_dims = W->dims();
+  DDim in_data_dims = slice_ddim(in_dims, 2, in_dims.size());
+  DDim filter_data_dims = slice_ddim(filter_dims, 2, filter_dims.size());
+  std::vector ksize = common::vectorize(filter_data_dims);
+  UpdatePaddingAndDilation(
+      &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
+
+  int data_dim = strides.size();  // 2d or 3d
+  bool is_sys_pad = funcs::IsSymmetricPadding(paddings, data_dim);
+  DenseTensor transformed_X(X->type());
+  DenseTensor transformed_ddX(X->type());
+
+  DenseTensor transformed_dX(X->type());
+
+  std::vector padding_common(data_dim, 0);
+  std::vector input_pad(X->dims().size() * 2, 0);
+
+  if (!is_sys_pad) {
+    // get pad
+    std::vector padding_diff(data_dim);
+    std::vector new_input_shape_vec(data_dim + 2);
+    
new_input_shape_vec[0] = transformed_X_channel.dims()[0]; + new_input_shape_vec[1] = transformed_X_channel.dims()[1]; + + for (size_t i = 0; i < data_dim; ++i) { + padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); + padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); + new_input_shape_vec[i + 2] = + transformed_X_channel.dims()[i + 2] + padding_diff[i]; + input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; + input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; + } + DDim new_input_shape(common::make_ddim(new_input_shape_vec)); + transformed_X.Resize(new_input_shape); + transformed_ddX.Resize(new_input_shape); + transformed_dX.Resize(new_input_shape); + + dev_ctx.template Alloc(&transformed_X); + + if (ddX) { + dev_ctx.template Alloc(&transformed_ddX); + } + if (dX) { + dev_ctx.template Alloc(&transformed_dX); + } + + // pad for input + const int rank = X->dims().size(); + T pad_value(0.0); + switch (rank) { + case 4: { + funcs::PadFunction(dev_ctx, + input_pad, + transformed_X_channel, + pad_value, + &transformed_X); + if (ddX) { + funcs::PadFunction(dev_ctx, + input_pad, + transformed_ddX_channel, + pad_value, + &transformed_ddX); + } + } break; + case 5: { + funcs::PadFunction(dev_ctx, + input_pad, + transformed_X_channel, + pad_value, + &transformed_X); + if (ddX) { + funcs::PadFunction(dev_ctx, + input_pad, + transformed_ddX_channel, + pad_value, + &transformed_ddX); + } + } break; + default: + PADDLE_THROW(common::errors::InvalidArgument( + "ConvOp only support tensors with 4 or 5 dimensions.")); + } + + } else { + transformed_X.ShareDataWith(transformed_X_channel); + if (ddX) { + transformed_ddX.ShareDataWith(transformed_ddX_channel); + } + if (dX) { + transformed_dX.ShareDataWith(transformed_dX_channel); + } + + if (paddings.size() == data_dim) { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[i]; + } + } else { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[2 * i]; + } + } + } + + const T* x = transformed_X.data(); + + int iwo_group = groups; + int c_group = 1; +#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 0, 1) + iwo_group = 1; + c_group = groups; + groups = 1; +#endif + auto dtype = phi::backends::gpu::CudnnDataType::type; + + // auto handle = dev_ctx.cudnn_handle(); + auto handle = GetDnnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); + auto layout = phi::backends::gpu::GetCudnnTensorFormat( + phi::backends::gpu::DataLayout::kNCHW); + + ConvArgs args1{handle, + &transformed_ddX, + W, + &transformed_ddO_channel, + strides, + padding_common, + dilations, + dtype, + groups, + phi::backends::gpu::DataLayout::kNCHW}; + ConvArgs args2{handle, + &transformed_X, + ddW, + &transformed_ddO_channel, + strides, + padding_common, + dilations, + dtype, + groups, + phi::backends::gpu::DataLayout::kNCHW}; + ConvArgs args3{handle, + &transformed_ddX, + dW, + &transformed_dO_channel, + strides, + padding_common, + dilations, + dtype, + groups, + phi::backends::gpu::DataLayout::kNCHW}; + ConvArgs args4{handle, + &transformed_dX, + ddW, + &transformed_dO_channel, + strides, + padding_common, + dilations, + dtype, + groups, + phi::backends::gpu::DataLayout::kNCHW}; + +#ifdef PADDLE_WITH_HIP + SearchResult fwd_result1; + SearchResult fwd_result2; + SearchResult data_result; + SearchResult filter_result; +#else + SearchResult fwd_result1; + SearchResult fwd_result2; + SearchResult data_result; + SearchResult filter_result; +#endif + + // ddo = conv(ddI, W) + conv(I, ddW) + size_t 
workspace_size = 0; + + T* transformed_ddy_channel = nullptr; + if (ddO) { + ddy = ddO->data(); + transformed_ddy_channel = transformed_ddO_channel.data(); + if (ddX) { + args1.idesc.set(transformed_ddX, iwo_group); + args1.wdesc.set(*W, layout, iwo_group); + args1.odesc.set(transformed_ddO_channel, iwo_group); + args1.cdesc.set(dtype, + padding_common, + strides, + dilations, + phi::AllowTF32Cudnn(), + c_group); + +#ifdef PADDLE_WITH_HIP + using search1 = SearchAlgorithm; + workspace_size = search1::GetWorkspaceSize(args1); + fwd_result1.algo = search1::Find( + args1, exhaustive_search, false, workspace_size, dev_ctx); +#else + using search1 = SearchAlgorithm; + fwd_result1 = search1::Find(dev_ctx, args1, exhaustive_search, false); + workspace_size = search1::GetWorkspaceSize(args1, fwd_result1.algo); +#endif + } + + if (ddW) { + ddw = ddW->data(); + args2.idesc.set(transformed_X, iwo_group); + args2.wdesc.set(*ddW, layout, iwo_group); + args2.odesc.set(transformed_ddO_channel, iwo_group); + args2.cdesc.set(dtype, + padding_common, + strides, + dilations, + phi::AllowTF32Cudnn(), + c_group); + +#ifdef PADDLE_WITH_HIP + using search2 = SearchAlgorithm; + workspace_size = + std::max(workspace_size, search2::GetWorkspaceSize(args2)); + fwd_result2.algo = search2::Find( + args2, exhaustive_search, false, workspace_size, dev_ctx); +#else + using search2 = SearchAlgorithm; + fwd_result2 = search2::Find(dev_ctx, args2, exhaustive_search, false); + workspace_size = std::max( + workspace_size, search2::GetWorkspaceSize(args2, fwd_result2.algo)); +#endif + } + } + + if (dW && ddX) { + dw = dW->data(); + args3.idesc.set(transformed_ddX, iwo_group); + args3.wdesc.set(*dW, layout, iwo_group); + args3.odesc.set(transformed_dO_channel, iwo_group); + args3.cdesc.set(dtype, + padding_common, + strides, + dilations, + phi::AllowTF32Cudnn(), + c_group); + +#ifdef PADDLE_WITH_HIP + using search3 = SearchAlgorithm; + workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3)); + filter_result.algo = search3::Find( + args3, exhaustive_search, deterministic, workspace_size, dev_ctx); +#else + using search3 = SearchAlgorithm; + filter_result = + search3::Find(dev_ctx, args3, exhaustive_search, deterministic); + workspace_size = std::max( + workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo)); +#endif + } + + if (ddW && dX) { + transformed_dx = transformed_dX.data(); + + args4.idesc.set(transformed_dX, iwo_group); + args4.wdesc.set(*ddW, layout, iwo_group); + args4.odesc.set(transformed_dO_channel, iwo_group); + args4.cdesc.set(dtype, + padding_common, + strides, + dilations, + phi::AllowTF32Cudnn(), + c_group); + +#ifdef PADDLE_WITH_HIP + using search4 = SearchAlgorithm; + workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4)); + data_result.algo = search4::Find( + args4, exhaustive_search, deterministic, workspace_size, dev_ctx); +#else + using search4 = SearchAlgorithm; + data_result = + search4::Find(dev_ctx, args4, exhaustive_search, deterministic); + workspace_size = std::max( + workspace_size, search4::GetWorkspaceSize(args4, data_result.algo)); +#endif + } + + int i_n, i_c, i_d, i_h, i_w; + GetNCDHW( + transformed_X.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w); + + int o_n, o_c, o_d, o_h, o_w; + GetNCDHW(transformed_dO_channel.dims(), + DataLayout::kNCHW, + &o_n, + &o_c, + &o_d, + &o_h, + &o_w); + + int group_offset_in = i_c / groups * i_h * i_w * i_d; + int group_offset_out = o_c / groups * o_h * o_w * o_d; + int group_offset_filter = 
W->numel() / groups; + + ScalingParamType alpha = 1.0f; + ScalingParamType beta = 0.0f; + + // NOTE(zhiqiu): inplace addto is not supported in double grad yet. + // ScalingParamType beta = dev_ctx.Attr("use_addto") ? 1.0f : + // 0.0f; + // VLOG(4) << "Conv_grad_grad: use_addto = " << + // dev_ctx.Attr("use_addto"); + // auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + auto workspace_handle = GetDnnWorkspace( + const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream()); + + if (ddO) { + if (ddX) { + ddx = transformed_ddX.data(); +#ifdef PADDLE_WITH_HIP + workspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::miopenConvolutionForward(handle, + &alpha, + args1.idesc.desc(), + ddx, + args1.wdesc.desc(), + w, + args1.cdesc.desc(), + fwd_result1.algo, + &beta, + args1.odesc.desc(), + transformed_ddy_channel, + workspace_ptr, + workspace_size)); + }, + workspace_size); +#else + ConvRunner::Apply(dev_ctx, + args1, + fwd_result1, + ddx, + w, + transformed_ddy_channel, + groups, + group_offset_in, + group_offset_filter, + group_offset_out, + workspace_size, + &workspace_handle, + false); +#endif + } + if (ddW) { +#ifdef PADDLE_WITH_HIP + // MIOPEN ONLY support beta to be 0.0f + workspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::miopenConvolutionForward(handle, + &alpha, + args2.idesc.desc(), + x, + args2.wdesc.desc(), + ddw, + args2.cdesc.desc(), + fwd_result2.algo, + &beta, + args2.odesc.desc(), + transformed_ddy_channel, + workspace_ptr, + workspace_size)); + }, + workspace_size); +#else + ConvRunner::Apply(dev_ctx, + args2, + fwd_result2, + x, + ddw, + transformed_ddy_channel, + groups, + group_offset_in, + group_offset_filter, + group_offset_out, + workspace_size, + &workspace_handle, + true); +#endif + } + if (channel_last) { + TransToChannelLast(dev_ctx, &transformed_ddO_channel, ddO); + } + } + T* transformed_dy_channel = transformed_dO_channel.data(); + if (dW && ddX) { + ddx = transformed_ddX.data(); +#ifdef PADDLE_WITH_HIP + workspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::miopenConvolutionBackwardWeights( + handle, + &alpha, + args3.odesc.desc(), + transformed_dy_channel, + args3.idesc.desc(), + ddx, + args3.cdesc.desc(), + filter_result.algo, + &beta, + args3.wdesc.desc(), + dw, + workspace_ptr, + workspace_size)); + }, + workspace_size); +#else + ConvRunner::Apply(dev_ctx, + args3, + filter_result, + transformed_dy_channel, + ddx, + dw, + groups, + group_offset_in, + group_offset_filter, + group_offset_out, + workspace_size, + &workspace_handle, + false); +#endif + } + + if (dX && ddW) { + ddw = ddW->data(); +#ifdef PADDLE_WITH_HIP + workspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::miopenConvolutionBackwardData( + handle, + &alpha, + args4.odesc.desc(), + transformed_dy_channel, + args4.wdesc.desc(), + ddw, + args4.cdesc.desc(), + data_result.algo, + &beta, + args4.idesc.desc(), + transformed_dx, + workspace_ptr, + workspace_size)); + }, + workspace_size); +#else + ConvRunner::Apply(dev_ctx, + args4, + data_result, + transformed_dy_channel, + ddw, + transformed_dx, + groups, + group_offset_in, + group_offset_filter, + group_offset_out, + workspace_size, + &workspace_handle, + false); +#endif + + if (!is_sys_pad) { + // reverse padded input + std::vector starts(X->dims().size(), 0); + std::vector axes(X->dims().size(), 0); + + for (size_t i = 0; i < X->dims().size(); ++i) { + 
starts[i] = input_pad[2 * i]; + axes[i] = i; + } + if (X->dims().size() == 4) { + RemovePaddingSlice( + dev_ctx, &transformed_dX, &transformed_dX_channel, starts, axes); + } else { + RemovePaddingSlice( + dev_ctx, &transformed_dX, &transformed_dX_channel, starts, axes); + } + } + if (channel_last) { + TransToChannelLast(dev_ctx, &transformed_dX_channel, dX); + } + } +} + +template +void DepthwiseConvDoubleGradGPUDNNKernel( + const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& filter, + const DenseTensor& out_grad, + const paddle::optional& input_grad_grad, + const paddle::optional& filter_grad_grad, + const std::vector& strides, + const std::vector& paddings_t, + const std::string& padding_algorithm, + int groups, + const std::vector& dilations_t, + const std::string& data_format, + DenseTensor* input_grad, + DenseTensor* filter_grad, + DenseTensor* out_grad_grad) { + ConvCudnnGradGradKernel(dev_ctx, + input, + filter, + out_grad, + input_grad_grad, + filter_grad_grad, + strides, + paddings_t, + padding_algorithm, + dilations_t, + groups, + data_format, + input_grad, + filter_grad, + out_grad_grad); +} + +template +void Conv3DCudnnDoubleGradKernel( + const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& filter, + const DenseTensor& out_grad, + const paddle::optional& input_grad_grad, + const paddle::optional& filter_grad_grad, + const std::vector& strides, + const std::vector& paddings_t, + const std::string& padding_algorithm, + int groups, + const std::vector& dilations_t, + const std::string& data_format, + DenseTensor* input_grad, + DenseTensor* filter_grad, + DenseTensor* out_grad_grad) { + ConvCudnnGradGradKernel(dev_ctx, + input, + filter, + out_grad, + input_grad_grad, + filter_grad_grad, + strides, + paddings_t, + padding_algorithm, + dilations_t, + groups, + data_format, + input_grad, + filter_grad, + out_grad_grad); +} + +} // namespace phi + +#ifdef PADDLE_WITH_HIP +PD_REGISTER_PLUGIN_KERNEL(conv2d_grad, + metax_gpu, + ALL_LAYOUT, + phi::ConvCudnnGradKernel, + float, + phi::dtype::float16) {} + +PD_REGISTER_PLUGIN_KERNEL(conv3d_grad, + metax_gpu, + ALL_LAYOUT, + phi::Conv3DCudnnGradKernel, + float, + phi::dtype::float16) {} +PD_REGISTER_PLUGIN_KERNEL(conv2d_double_grad, + metax_gpu, + ALL_LAYOUT, + phi::ConvCudnnGradGradKernel, + float, + phi::dtype::float16) {} + +PD_REGISTER_PLUGIN_KERNEL(conv3d_double_grad, + metax_gpu, + ALL_LAYOUT, + phi::Conv3DCudnnDoubleGradKernel, + float, + phi::dtype::float16) {} + +PD_REGISTER_PLUGIN_KERNEL(depthwise_conv2d_double_grad, + GPU, + ALL_LAYOUT, + phi::DepthwiseConvDoubleGradGPUDNNKernel, + float, + phi::dtype::float16) {} +#else +#if CUDNN_VERSION_MIN(8, 1, 0) +PD_REGISTER_PLUGIN_KERNEL(conv2d_grad, + metax_gpu, + ALL_LAYOUT, + phi::ConvCudnnGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} + +PD_REGISTER_PLUGIN_KERNEL(conv3d_grad, + metax_gpu, + ALL_LAYOUT, + phi::Conv3DCudnnGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} +PD_REGISTER_PLUGIN_KERNEL(conv2d_double_grad, + metax_gpu, + ALL_LAYOUT, + phi::ConvCudnnGradGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} + +PD_REGISTER_PLUGIN_KERNEL(conv3d_double_grad, + metax_gpu, + ALL_LAYOUT, + phi::Conv3DCudnnDoubleGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} + +PD_REGISTER_PLUGIN_KERNEL(depthwise_conv2d_double_grad, + metax_gpu, + ALL_LAYOUT, + phi::DepthwiseConvDoubleGradGPUDNNKernel, + float, + double, + 
phi::dtype::float16, + phi::dtype::bfloat16) {} +#else +PD_REGISTER_PLUGIN_KERNEL(conv2d_grad, + metax_gpu, + ALL_LAYOUT, + phi::ConvCudnnGradKernel, + float, + double, + phi::dtype::float16) {} + +PD_REGISTER_PLUGIN_KERNEL(conv3d_grad, + metax_gpu, + ALL_LAYOUT, + phi::Conv3DCudnnGradKernel, + float, + double, + phi::dtype::float16) {} + +PD_REGISTER_PLUGIN_KERNEL(conv2d_double_grad, + metax_gpu, + ALL_LAYOUT, + phi::ConvCudnnGradGradKernel, + float, + double, + phi::dtype::float16) {} + +PD_REGISTER_PLUGIN_KERNEL(conv3d_double_grad, + metax_gpu, + ALL_LAYOUT, + phi::Conv3DCudnnDoubleGradKernel, + float, + double, + phi::dtype::float16) {} + +PD_REGISTER_PLUGIN_KERNEL(depthwise_conv2d_double_grad, + metax_gpu, + ALL_LAYOUT, + phi::DepthwiseConvDoubleGradGPUDNNKernel, + float, + double, + phi::dtype::float16) {} +#endif + +#endif From a0cb0a7c91e9764a46fb7cf698658d3f7d9cd280 Mon Sep 17 00:00:00 2001 From: jiaxinWang-metax <189149612@qq.com> Date: Tue, 16 Sep 2025 17:24:55 +0800 Subject: [PATCH 04/12] modify library to static library --- backends/metax_gpu/cmake/warpctc.cmake | 19 +++++++++---------- backends/metax_gpu/cmake/warprnnt.cmake | 19 +++++++++---------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/backends/metax_gpu/cmake/warpctc.cmake b/backends/metax_gpu/cmake/warpctc.cmake index 9edc92f0a94..0733c0f9ce5 100644 --- a/backends/metax_gpu/cmake/warpctc.cmake +++ b/backends/metax_gpu/cmake/warpctc.cmake @@ -66,11 +66,11 @@ set(WARPCTC_LIB_DIR if(WIN32) set(WARPCTC_LIBRARIES - "${WARPCTC_INSTALL_DIR}/bin/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}" + "${WARPCTC_INSTALL_DIR}/bin/warpctc${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE FILEPATH "Warp-ctc Library" FORCE) else() set(WARPCTC_LIBRARIES - "${WARPCTC_INSTALL_DIR}/lib/libwarpctc${CMAKE_SHARED_LIBRARY_SUFFIX}" + "${WARPCTC_INSTALL_DIR}/lib/libwarpctc${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE FILEPATH "Warp-ctc Library" FORCE) endif() @@ -93,10 +93,10 @@ if(WIN32) set(WARPCTC_CXX_FLAGS_DEBUG $) else() - set(WARPCTC_C_FLAGS ${CMAKE_C_FLAGS}) + set(WARPCTC_C_FLAGS "${CMAKE_C_FLAGS} -fPIC") set(WARPCTC_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG}) set(WARPCTC_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE}) - set(WARPCTC_CXX_FLAGS ${CMAKE_CXX_FLAGS}) + set(WARPCTC_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") set(WARPCTC_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE}) set(WARPCTC_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG}) endif() @@ -127,7 +127,7 @@ ExternalProject_Add( -DNVCC_FLAGS_EXTRA=${NVCC_FLAGS_EXTRA} -DWITH_TORCH=OFF -DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON - -DBUILD_SHARED=ON + -DBUILD_SHARED=OFF -DBUILD_TESTS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} @@ -145,8 +145,7 @@ get_filename_component(WARPCTC_LIBRARY_PATH ${WARPCTC_LIBRARIES} DIRECTORY) include_directories(${WARPCTC_INCLUDE_DIR}) # For warpctc code to include its # headers. 
-add_library(warpctc SHARED IMPORTED GLOBAL) -set_target_properties(warpctc PROPERTIES - IMPORTED_LOCATION ${WARPCTC_LIBRARIES} - INTERFACE_INCLUDE_DIRECTORIES ${WARPCTC_INCLUDE_DIR} -) \ No newline at end of file +add_library(warpctc STATIC IMPORTED GLOBAL) +set_target_properties( + warpctc PROPERTIES IMPORTED_LOCATION ${WARPCTC_LIBRARIES} + INTERFACE_INCLUDE_DIRECTORIES ${WARPCTC_INCLUDE_DIR}) diff --git a/backends/metax_gpu/cmake/warprnnt.cmake b/backends/metax_gpu/cmake/warprnnt.cmake index 527f2e55a1b..a8d6683af2b 100644 --- a/backends/metax_gpu/cmake/warprnnt.cmake +++ b/backends/metax_gpu/cmake/warprnnt.cmake @@ -62,11 +62,11 @@ set(WARPRNNT_LIB_DIR if(WIN32) set(WARPRNNT_LIBRARIES - "${WARPRNNT_INSTALL_DIR}/bin/warprnnt${CMAKE_SHARED_LIBRARY_SUFFIX}" + "${WARPRNNT_INSTALL_DIR}/bin/warprnnt${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE FILEPATH "Warp-rnnt Library" FORCE) else() set(WARPRNNT_LIBRARIES - "${WARPRNNT_INSTALL_DIR}/lib/libwarprnnt${CMAKE_SHARED_LIBRARY_SUFFIX}" + "${WARPRNNT_INSTALL_DIR}/lib/libwarprnnt${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE FILEPATH "Warp-rnnt Library" FORCE) endif() @@ -90,10 +90,10 @@ if(WIN32) set(WARPRNNT_CXX_FLAGS_DEBUG $) else() - set(WARPRNNT_C_FLAGS ${CMAKE_C_FLAGS}) + set(WARPRNNT_C_FLAGS "${CMAKE_C_FLAGS} -fPIC") set(WARPRNNT_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG}) set(WARPRNNT_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE}) - set(WARPRNNT_CXX_FLAGS ${CMAKE_CXX_FLAGS}) + set(WARPRNNT_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") set(WARPRNNT_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE}) set(WARPRNNT_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG}) endif() @@ -120,7 +120,7 @@ ExternalProject_Add( -DWITH_ROCM=${WITH_ROCM} -DWITH_OMP=${USE_OMP} -DNVCC_FLAGS_EXTRA=${NVCC_FLAGS_EXTRA} - -DBUILD_SHARED=ON + -DBUILD_SHARED=OFF -DBUILD_TESTS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} @@ -137,8 +137,7 @@ get_filename_component(WARPRNNT_LIBRARY_PATH ${WARPRNNT_LIBRARIES} DIRECTORY) include_directories(${WARPRNNT_INCLUDE_DIR}) # For warprnnt code to include its # headers. 
-add_library(warprnnt SHARED IMPORTED GLOBAL) -set_target_properties(warprnnt PROPERTIES - IMPORTED_LOCATION ${WARPRNNT_LIBRARIES} - INTERFACE_INCLUDE_DIRECTORIES ${WARPRNNT_INCLUDE_DIR} -) \ No newline at end of file +add_library(warprnnt STATIC IMPORTED GLOBAL) +set_target_properties( + warprnnt PROPERTIES IMPORTED_LOCATION ${WARPRNNT_LIBRARIES} + INTERFACE_INCLUDE_DIRECTORIES ${WARPRNNT_INCLUDE_DIR}) From 7b018df3c526a3febb2387760da0bb4c3823c535 Mon Sep 17 00:00:00 2001 From: jiaxinWang-metax <189149612@qq.com> Date: Wed, 17 Sep 2025 13:54:30 +0800 Subject: [PATCH 05/12] modify kernel --- backends/metax_gpu/patch/paddle.patch | 257 ++++++++++++++------------ 1 file changed, 138 insertions(+), 119 deletions(-) diff --git a/backends/metax_gpu/patch/paddle.patch b/backends/metax_gpu/patch/paddle.patch index e56826c4f3e..667d9f75d1c 100755 --- a/backends/metax_gpu/patch/paddle.patch +++ b/backends/metax_gpu/patch/paddle.patch @@ -16,16 +16,16 @@ index cfada544d4..a690e97d74 100644 - set(EIGEN_PATCH_COMMAND ${EIGEN_PATCH_COMMAND} && git apply ${complex_header}) + # set(EIGEN_PATCH_COMMAND ${EIGEN_PATCH_COMMAND} && git apply ${complex_header}) endif() - + set(EIGEN_INCLUDE_DIR ${SOURCE_DIR}) diff --git a/paddle/fluid/platform/profiler/cupti_data_process.cc b/paddle/fluid/platform/profiler/cupti_data_process.cc index bff0f2bf70..9376b5781f 100644 --- a/paddle/fluid/platform/profiler/cupti_data_process.cc +++ b/paddle/fluid/platform/profiler/cupti_data_process.cc @@ -16,7 +16,7 @@ - + #include - + -#include "paddle/fluid/platform/enforce.h" +// #include "paddle/fluid/platform/enforce.h" #include "paddle/phi/core/os_info.h" @@ -76,7 +76,7 @@ index c0080f0a5e..458ca3e2e8 100644 + __macro(cudnnDestroyActivationDescriptor); \ + __macro(cudnnSetRNNDescriptor_v6); CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) - + #if CUDNN_VERSION >= 7000 && CUDNN_VERSION < 8000 @@ -152,7 +161,12 @@ CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #define CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(__macro) \ @@ -91,11 +91,11 @@ index c0080f0a5e..458ca3e2e8 100644 + __macro(cudnnRNNForwardInferenceEx); CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif - + @@ -195,40 +209,6 @@ CUDNN_DNN_ROUTINE_EACH_R8(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH_FRONTEND(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif - + -#if CUDNN_VERSION < 90000 -#define CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9(__macro) \ - __macro(cudnnGetRNNParamsSize); \ @@ -132,15 +132,15 @@ index c0080f0a5e..458ca3e2e8 100644 -#endif } // namespace dynload } // namespace phi - + diff --git a/paddle/phi/backends/dynload/cufft.h b/paddle/phi/backends/dynload/cufft.h -index 1547909d92..66b2779392 100644 +index 1547909d92..ef20838434 100644 --- a/paddle/phi/backends/dynload/cufft.h +++ b/paddle/phi/backends/dynload/cufft.h @@ -1,3 +1,4 @@ +// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - + Licensed under the Apache License, Version 2.0 (the "License"); @@ -40,7 +41,9 @@ extern void EnforceCUFFTLoaded(const char* fn_name); cufft_dso_handle = phi::dynload::GetCUFFTDsoHandle(); \ @@ -160,23 +160,23 @@ index 59e92955c9..d2f8c2da15 100644 @@ -24,8 +24,8 @@ limitations under the License. 
*/ #include "paddle/phi/backends/dynload/dynamic_loader.h" #include "paddle/phi/common/port.h" - + -namespace phi { -namespace dynload { +// namespace phi { +// namespace dynload { - + extern std::once_flag cupti_dso_flag; extern void *cupti_dso_handle; @@ -71,7 +71,7 @@ extern void *cupti_dso_handle; CUPTI_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUPTI_WRAP); - + #undef DECLARE_DYNAMIC_LOAD_CUPTI_WRAP -} // namespace dynload -} // namespace phi +// } // namespace dynload +// } // namespace phi - + -#endif // PADDLE_WITH_CUPTI +#endif // PADDLE_WITH_CUPTI \ No newline at end of file @@ -226,32 +226,32 @@ index c5309e7e11..3328571380 100644 } \ }; \ diff --git a/paddle/phi/backends/gpu/cuda/cuda_device_function.h b/paddle/phi/backends/gpu/cuda/cuda_device_function.h -index 4ff2e528a9..81421c8ca1 100644 +index 4ff2e528a9..23f7f4b583 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_device_function.h +++ b/paddle/phi/backends/gpu/cuda/cuda_device_function.h @@ -1,3 +1,4 @@ +// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - + Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,7 +26,7 @@ namespace phi { namespace backends { namespace gpu { - + -#define FULL_WARP_MASK 0xFFFFFFFF +#define FULL_WARP_MASK 0xFFFFFFFFFFFFFFFFULL #define CREATE_SHFL_MASK(mask, predicate) \ mask = __ballot_sync(FULL_WARP_MASK, (predicate)) - + @@ -45,12 +46,12 @@ namespace gpu { - + template __forceinline__ __device__ T -CudaShuffleDownSync(unsigned mask, T val, int delta, int width = warpSize) { +CudaShuffleDownSync(unsigned long long mask, T val, int delta, int width = warpSize) { return __shfl_down_sync(mask, val, static_cast(delta), width); } - + template -__forceinline__ __device__ T CudaShuffleXorSync(unsigned mask, +__forceinline__ __device__ T CudaShuffleXorSync(unsigned long long mask, @@ -259,7 +259,7 @@ index 4ff2e528a9..81421c8ca1 100644 int width = warpSize) { return __shfl_xor_sync(mask, val, width); @@ -58,14 +59,14 @@ __forceinline__ __device__ T CudaShuffleXorSync(unsigned mask, - + template <> __forceinline__ __device__ phi::dtype::float16 CudaShuffleDownSync( - unsigned mask, phi::dtype::float16 val, int delta, int width) { @@ -267,7 +267,7 @@ index 4ff2e528a9..81421c8ca1 100644 return phi::dtype::float16(__shfl_down_sync( mask, val.to_half(), static_cast(delta), width)); } - + template <> __forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleDownSync( - unsigned mask, phi::dtype::bfloat16 val, int delta, int width) { @@ -276,7 +276,7 @@ index 4ff2e528a9..81421c8ca1 100644 return phi::dtype::bfloat16(__shfl_down_sync( mask, val.to_nv_bfloat16(), static_cast(delta), width)); @@ -77,7 +78,7 @@ __forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleDownSync( - + template <> __forceinline__ __device__ phi::dtype::complex CudaShuffleDownSync( - unsigned mask, phi::dtype::complex val, int delta, int width) { @@ -285,7 +285,7 @@ index 4ff2e528a9..81421c8ca1 100644 mask, static_cast(val.real), static_cast(delta), width)); float imag = static_cast(__shfl_down_sync( @@ -87,7 +88,7 @@ __forceinline__ __device__ phi::dtype::complex CudaShuffleDownSync( - + template <> __forceinline__ __device__ phi::dtype::complex CudaShuffleDownSync( - unsigned mask, phi::dtype::complex val, int delta, int width) { @@ -294,14 +294,14 @@ index 4ff2e528a9..81421c8ca1 100644 static_cast(__shfl_down_sync(mask, static_cast(val.real), @@ -103,13 +104,13 @@ __forceinline__ __device__ phi::dtype::complex 
CudaShuffleDownSync( - + template <> __forceinline__ __device__ phi::dtype::float16 CudaShuffleXorSync( - unsigned mask, phi::dtype::float16 val, int width) { + unsigned long long mask, phi::dtype::float16 val, int width) { return phi::dtype::float16(__shfl_xor_sync(mask, val.to_half(), width)); } - + template <> __forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleXorSync( - unsigned mask, phi::dtype::bfloat16 val, int width) { @@ -310,7 +310,7 @@ index 4ff2e528a9..81421c8ca1 100644 return phi::dtype::bfloat16( __shfl_xor_sync(mask, val.to_nv_bfloat16(), width)); @@ -121,7 +122,7 @@ __forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleXorSync( - + template <> __forceinline__ __device__ phi::dtype::complex CudaShuffleXorSync( - unsigned mask, phi::dtype::complex val, int width) { @@ -319,7 +319,7 @@ index 4ff2e528a9..81421c8ca1 100644 __shfl_xor_sync(mask, static_cast(val.real), width)); float imag = static_cast( @@ -131,7 +132,7 @@ __forceinline__ __device__ phi::dtype::complex CudaShuffleXorSync( - + template <> __forceinline__ __device__ phi::dtype::complex CudaShuffleXorSync( - unsigned mask, phi::dtype::complex val, int width) { @@ -328,14 +328,14 @@ index 4ff2e528a9..81421c8ca1 100644 __shfl_xor_sync(mask, static_cast(val.real), width)); double imag = static_cast( @@ -141,7 +142,7 @@ __forceinline__ __device__ phi::dtype::complex CudaShuffleXorSync( - + template __forceinline__ __device__ T -CudaShuffleSync(unsigned mask, T val, int src_line, int width = 32) { +CudaShuffleSync(unsigned long long mask, T val, int src_line, int width = 32) { return __shfl_sync(mask, val, src_line, width); } - + @@ -160,7 +161,7 @@ __device__ T reduceSum(T val, int tid, int len) { // but most card's warp size is 32. const int warpSize = 32; @@ -343,7 +343,7 @@ index 4ff2e528a9..81421c8ca1 100644 - unsigned mask = 0u; + unsigned long long mask = 0ull; CREATE_SHFL_MASK(mask, tid < len); - + for (int offset = warpSize / 2; offset > 0; offset /= 2) diff --git a/paddle/phi/core/enforce.h b/paddle/phi/core/enforce.h index 024a7de73e..1e4cdf16be 100644 @@ -351,7 +351,7 @@ index 024a7de73e..1e4cdf16be 100644 +++ b/paddle/phi/core/enforce.h @@ -45,7 +45,9 @@ limitations under the License. 
*/ #endif - + #ifdef PADDLE_WITH_CUDA -#include "paddle/phi/backends/dynload/cublas.h" +// #include "paddle/phi/backends/dynload/../../../../../cublas.h" @@ -361,9 +361,9 @@ index 024a7de73e..1e4cdf16be 100644 #include "paddle/phi/backends/dynload/curand.h" #include "paddle/phi/backends/dynload/cusolver.h" @@ -97,7 +99,7 @@ inline bool is_error(bool stat) { return !stat; } - + void ThrowWarnInternal(const std::string& message); - + -#if defined(__CUDA_ARCH__) +#if defined(__CUDACC__) // For cuda, the assertions can affect performance and it is therefore @@ -379,7 +379,7 @@ index 024a7de73e..1e4cdf16be 100644 } while (0) #elif defined(__HIPCC__) @@ -757,4 +759,4 @@ inline void retry_sleep(unsigned millisecond) { - + } // namespace enforce using namespace enforce; // NOLINT -} // namespace phi @@ -392,7 +392,7 @@ index c646e487d0..325122175c 100644 @@ -25,8 +25,9 @@ #else #include - + -#include "paddle/phi/backends/dynload/cublas.h" -#include "paddle/phi/backends/dynload/cublasLt.h" +// #include "paddle/phi/backends/dynload/cublas.h" @@ -400,16 +400,16 @@ index c646e487d0..325122175c 100644 +// #include "paddle/phi/backends/dynload/cublasLt.h" #include "paddle/phi/backends/dynload/cudnn.h" #endif - + @@ -90,7 +91,7 @@ DECLARE_TYPE_FOR_GPU(gpuStreamCaptureMode, - + // TODO(Ming Huang): Since there is no blasLt handler, // use rocblas_handle for workaround. -DECLARE_TYPE_FOR_GPU(blasLtHandle_t, cublasLtHandle_t, rocblas_handle); +// DECLARE_TYPE_FOR_GPU(blasLtHandle_t, cublasLtHandle_t, rocblas_handle); - + #undef DECLARE_TYPE_FOR_GPU - + diff --git a/paddle/phi/core/platform/device_context.h b/paddle/phi/core/platform/device_context.h index 2d02eb370b..8a7233e34e 100644 --- a/paddle/phi/core/platform/device_context.h @@ -430,58 +430,58 @@ index d69eb67d6f..1d8b6e9375 100644 --- a/paddle/phi/kernels/cpu/index_select_impl.h +++ b/paddle/phi/kernels/cpu/index_select_impl.h @@ -18,7 +18,7 @@ - + #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/tensor_utils.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/math_function.h" - + diff --git a/paddle/phi/kernels/funcs/fc_functor.cu b/paddle/phi/kernels/funcs/fc_functor.cu index cb35feee32..64f5bd24ac 100644 --- a/paddle/phi/kernels/funcs/fc_functor.cu +++ b/paddle/phi/kernels/funcs/fc_functor.cu @@ -16,12 +16,12 @@ limitations under the License. */ - + #include "paddle/phi/backends/all_context.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/fc_functor.h" - + #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" +// #include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" #include "paddle/phi/kernels/funcs/quant_dequant.h" #include "paddle/phi/kernels/matmul_kernel.h" - + diff --git a/paddle/phi/kernels/funcs/gru_compute.cu b/paddle/phi/kernels/funcs/gru_compute.cu index 88663ec880..98b93072a3 100644 --- a/paddle/phi/kernels/funcs/gru_compute.cu +++ b/paddle/phi/kernels/funcs/gru_compute.cu @@ -12,7 +12,7 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/funcs/gru_compute.h" - + #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/detail/gru_gpu_kernel.h" #include "paddle/phi/kernels/funcs/detail/gru_kernel.h" - + diff --git a/paddle/phi/kernels/funcs/math/context_project.h b/paddle/phi/kernels/funcs/math/context_project.h index 15e1a4a3c3..e4780538d7 100644 --- a/paddle/phi/kernels/funcs/math/context_project.h +++ b/paddle/phi/kernels/funcs/math/context_project.h @@ -18,7 +18,7 @@ #include - + #include "paddle/phi/core/tensor_utils.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/im2col.h" - + namespace phi { diff --git a/paddle/phi/kernels/funcs/matrix_inverse.cu b/paddle/phi/kernels/funcs/matrix_inverse.cu index e101224970..a52eb6096f 100644 @@ -489,14 +489,14 @@ index e101224970..a52eb6096f 100644 +++ b/paddle/phi/kernels/funcs/matrix_inverse.cu @@ -15,11 +15,13 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/matrix_inverse.h" - + #include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "kernels/funcs/blas/blas.h" - + namespace phi { namespace funcs { - + + + template @@ -514,28 +514,28 @@ index 558d363b39..05da04b517 100644 +#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" - + diff --git a/paddle/phi/kernels/funcs/multihead_matmul_functor.cu b/paddle/phi/kernels/funcs/multihead_matmul_functor.cu index 8b0baf5f5f..260482f124 100644 --- a/paddle/phi/kernels/funcs/multihead_matmul_functor.cu +++ b/paddle/phi/kernels/funcs/multihead_matmul_functor.cu @@ -27,7 +27,7 @@ namespace cub = hipcub; - + #include "paddle/phi/kernels/funcs/multihead_matmul_functor.h" - + -#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_cuda_utils.h" - + namespace phi { diff --git a/paddle/phi/kernels/funcs/top_k_function_cuda.h b/paddle/phi/kernels/funcs/top_k_function_cuda.h -index e30d440ff3..3c74792690 100644 +index e30d440ff3..108edda7ca 100644 --- a/paddle/phi/kernels/funcs/top_k_function_cuda.h +++ b/paddle/phi/kernels/funcs/top_k_function_cuda.h @@ -30,11 +30,11 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/primitive/functor_primitives.h" - + -#define FINAL_MASK 0xffffffff +#define FINAL_MASK 0xffffffffffffffffull #ifdef PADDLE_WITH_HIP @@ -545,7 +545,7 @@ index e30d440ff3..3c74792690 100644 +#define WARP_SIZE 64 #endif #define MAX_NUM_THREADS 1024 - + @@ -196,21 +196,56 @@ __device__ __forceinline__ void AddTo(Pair topk[], for (int k = beam_size - 2; k >= 0; k--) { if (largest) { @@ -606,7 +606,7 @@ index e30d440ff3..3c74792690 100644 + topk[0 + offset].v = p.v; + topk[0 + offset].id = p.id; } - + template @@ -239,24 +274,24 @@ __device__ __forceinline__ void GetTopK(Pair topk[], template @@ -662,7 +662,7 @@ index e30d440ff3..3c74792690 100644 + // topk + MaxLength - *beam, src, tid, dim, *max, length, largest); } } - + @@ -355,6 +394,8 @@ __device__ __forceinline__ void BlockReduce(Pair shared_max[], shared_max[wid] = input_now; } @@ -697,7 +697,7 @@ index e30d440ff3..3c74792690 100644 - if (--(*k) == 0) break; + // if (--(*k) == 0) break; + unsigned long long mask = 0ull; - + - unsigned mask = 0u; + // unsigned mask = 0u; CREATE_SHFL_MASK(mask, true); @@ -721,7 +721,7 @@ index e30d440ff3..3c74792690 100644 + return ret; } - + static __device__ __forceinline__ unsigned int SetBitfield( unsigned int val, unsigned int to_insert, int pos, int len) { unsigned int ret; @@ -743,7 +743,7 @@ index e30d440ff3..3c74792690 100644 + ret = (static_cast(val) << (64 - pos - len)) >> (64 - len); return ret; } - + @@ -507,9 +556,9 @@ struct Bitfield { int pos, int len) { @@ -771,7 +771,7 @@ index e30d440ff3..3c74792690 100644 + return ::__lane_id(); + // return lane_id; } - + __device__ __forceinline__ unsigned GetLaneMaskLe() { unsigned mask; - asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask)); @@ -780,17 +780,17 @@ index e30d440ff3..3c74792690 100644 + return ((uint64_t(1) << ::__lane_id()) << 1) - 1; + // return mask; } - + template @@ -881,7 +936,8 @@ __global__ void GatherKthValue(const T* input, - + // 1. 
Find the k-th value T kth_value = static_cast(0); - RadixSearch::RadixType, IndexType, false>( + // RadixSearch::RadixType, IndexType, false>( + RadixSearch::RadixType, IndexType, false>( cur_input, k, num_cols, shared_mem, &kth_value); - + __shared__ int64_t block_min_idx; @@ -1314,3 +1370,4 @@ bool SortTopk(const phi::GPUContext& dev_ctx, } @@ -803,12 +803,12 @@ index 32db61532f..0220316bc3 100644 +++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h @@ -15,7 +15,7 @@ #pragma once - + #if defined(PADDLE_WITH_CUDA) -#include "paddle/phi/backends/dynload/cublasLt.h" +// #include "paddle/phi/backends/dynload/cublasLt.h" #endif - + #include "glog/logging.h" diff --git a/paddle/phi/kernels/fusion/gpu/fused_layernorm_residual_dropout_bias.h b/paddle/phi/kernels/fusion/gpu/fused_layernorm_residual_dropout_bias.h index 9d4bb18d55..ea42cc10a9 100644 @@ -830,12 +830,12 @@ index b8cfdbf3ce..fa14b94a77 100644 --- a/paddle/phi/kernels/fusion/gpu/masked_multihead_attention_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/masked_multihead_attention_kernel.cu @@ -14,7 +14,7 @@ - + #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" -#include "paddle/phi/kernels/fusion/gpu/mmha_util.cu.h" +#include "kernels/metax_kernel/mmha_util.cu.h" - + namespace phi { namespace fusion { diff --git a/paddle/phi/kernels/fusion/gpu/qkv_unpack_mha_kernel.cu b/paddle/phi/kernels/fusion/gpu/qkv_unpack_mha_kernel.cu @@ -843,14 +843,27 @@ index e838778952..83e805e75a 100644 --- a/paddle/phi/kernels/fusion/gpu/qkv_unpack_mha_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/qkv_unpack_mha_kernel.cu @@ -14,7 +14,7 @@ - + #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" -#include "paddle/phi/kernels/fusion/gpu/mmha_util.cu.h" +#include "kernels/metax_kernel/mmha_util.cu.h" - + namespace phi { namespace fusion { +diff --git a/paddle/phi/kernels/gpu/correlation_kernel.cu b/paddle/phi/kernels/gpu/correlation_kernel.cu +index 4c93778bde..c7bdf8a2cc 100644 +--- a/paddle/phi/kernels/gpu/correlation_kernel.cu ++++ b/paddle/phi/kernels/gpu/correlation_kernel.cu +@@ -103,7 +103,7 @@ void CorrelationCUDAKernel(const Context &dev_ctx, + int stride2, + int corr_type_multiply, + DenseTensor *out) { +- bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU; ++ bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM; + PADDLE_ENFORCE_EQ( + is_gpu_place, + true, diff --git a/paddle/phi/kernels/gpu/depthwise_conv.h b/paddle/phi/kernels/gpu/depthwise_conv.h index f0cca0f701..02ea957240 100644 --- a/paddle/phi/kernels/gpu/depthwise_conv.h @@ -863,9 +876,22 @@ index f0cca0f701..02ea957240 100644 -#include "paddle/phi/kernels/impl/conv_cudnn_impl.h" +#include "kernels/gpudnn/conv_gpudnn.h" +#include "kernels/impl/conv_cudnn_impl.h" - + namespace phi { // To determine use cudnn or not. 
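Several hunks in this patch widen GPU-only place checks so they also accept the custom-device allocation type: correlation_kernel.cu above, and dgc_kernel.cu and shuffle_batch_kernel.cu in the hunks that follow. A minimal sketch of the shared pattern, assuming a phi-style context that exposes GetPlace(); the helper name is illustrative and not part of the patch:

```cpp
#include "paddle/phi/common/place.h"

// Kernels built for a plug-in backend such as metax_gpu see
// AllocationType::CUSTOM rather than AllocationType::GPU, so GPU-only
// asserts must accept both place types.
template <typename Context>
bool IsGpuOrCustomPlace(const Context& dev_ctx) {
  const auto place_type = dev_ctx.GetPlace().GetType();
  return place_type == phi::AllocationType::GPU ||
         place_type == phi::AllocationType::CUSTOM;
}
```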
+diff --git a/paddle/phi/kernels/gpu/dgc_kernel.cu b/paddle/phi/kernels/gpu/dgc_kernel.cu +index c2ddfa1347..c6adf5a6de 100644 +--- a/paddle/phi/kernels/gpu/dgc_kernel.cu ++++ b/paddle/phi/kernels/gpu/dgc_kernel.cu +@@ -188,7 +188,7 @@ void DGCKernel(const Context& dev_ctx, + int buf_size = paddle::communication::dgc::get_buffer_size(k); + phi::Allocator::AllocationPtr tmp_ious_data; + #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +- if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) { ++ if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) { + tmp_ious_data = phi::memory_utils::Alloc( + dev_ctx.GetPlace(), + buf_size, diff --git a/paddle/phi/kernels/gpu/gelu_funcs.h b/paddle/phi/kernels/gpu/gelu_funcs.h index 29fa252e96..4ae72b0935 100644 --- a/paddle/phi/kernels/gpu/gelu_funcs.h @@ -890,7 +916,7 @@ index 29fa252e96..4ae72b0935 100644 +// #endif return tanhf(x); } - + diff --git a/paddle/phi/kernels/gpu/log_softmax_grad_kernel.cu b/paddle/phi/kernels/gpu/log_softmax_grad_kernel.cu index 11efd87965..679db14c24 100644 --- a/paddle/phi/kernels/gpu/log_softmax_grad_kernel.cu @@ -901,9 +927,9 @@ index 11efd87965..679db14c24 100644 #include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" +#include "kernels/gpudnn/softmax_gpudnn.h" - + namespace phi { - + diff --git a/paddle/phi/kernels/gpu/log_softmax_kernel.cu b/paddle/phi/kernels/gpu/log_softmax_kernel.cu index 63c35dd4ee..15da9aea45 100644 --- a/paddle/phi/kernels/gpu/log_softmax_kernel.cu @@ -914,9 +940,9 @@ index 63c35dd4ee..15da9aea45 100644 #include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" +#include "kernels/gpudnn/softmax_gpudnn.h" - + namespace phi { - + diff --git a/paddle/phi/kernels/gpu/lstsq_kernel.cu b/paddle/phi/kernels/gpu/lstsq_kernel.cu index 1bdbe1564c..f753b54bc6 100644 --- a/paddle/phi/kernels/gpu/lstsq_kernel.cu @@ -930,6 +956,19 @@ index 1bdbe1564c..f753b54bc6 100644 #include "paddle/phi/kernels/impl/qr_kernel_impl.h" #include "paddle/phi/kernels/impl/tril_triu_kernel_impl.h" #include "paddle/phi/kernels/lstsq_kernel.h" +diff --git a/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu b/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu +index 05a977828f..5136608c41 100644 +--- a/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu ++++ b/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu +@@ -58,7 +58,7 @@ void ShuffleBatchKernel(const Context& dev_ctx, + int64_t seed_int = 0; + if (seed.initialized()) { + const auto& seed_place = seed.place().GetType(); +- bool is_gpu_place = seed_place == phi::AllocationType::GPU; ++ bool is_gpu_place = seed_place == phi::AllocationType::GPU || seed_place == phi::AllocationType::CUSTOM; + if (is_gpu_place) { + // NOTE: We have overwritten GetKernelTypeForVar, so seed_place would + // not be CUDAPlace in practice. This case would only happen in Python diff --git a/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h b/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h index 9bc5326c90..79b57a8203 100644 --- a/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h @@ -948,7 +987,7 @@ index cf80666b4e..ca76e055fb 100644 --- a/paddle/phi/kernels/impl/baddbmm_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/baddbmm_grad_kernel_impl.h @@ -19,7 +19,7 @@ limitations under the License. 
*/ - + #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/kernels/baddbmm_grad_kernel.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" @@ -961,14 +1000,14 @@ index 2789cb59a2..b91b076f7f 100644 --- a/paddle/phi/kernels/impl/baddbmm_kernel_impl.h +++ b/paddle/phi/kernels/impl/baddbmm_kernel_impl.h @@ -20,7 +20,7 @@ limitations under the License. */ - + #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/kernels/baddbmm_kernel.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" - + diff --git a/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h b/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h index 9a21c23666..86413d1577 100644 --- a/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h @@ -993,7 +1032,7 @@ index 4459a931da..837c8682b8 100644 -#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/deformable_conv_functor.h" - + namespace phi { diff --git a/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h b/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h index ad9e9197dd..5478d9817d 100644 @@ -1013,31 +1052,31 @@ index e6b3960f6d..564125f1f6 100644 --- a/paddle/phi/kernels/impl/gammaincc_kernel_impl.h +++ b/paddle/phi/kernels/impl/gammaincc_kernel_impl.h @@ -56,8 +56,8 @@ HOSTDEVICE T igam(const T a, const T x) { - + template HOSTDEVICE T igamc(const T a, const T x) { - static T big = 4.503599627370496e15; - static T biginv = 2.22044604925031308085e-16; + const static T big = 4.503599627370496e15; + const static T biginv = 2.22044604925031308085e-16; - + if ((x <= T{0}) || (a <= T{0})) return (T{1.0}); - + diff --git a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h index 410fb3c560..009ce03440 100644 --- a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h @@ -54,7 +54,7 @@ HOSTDEVICE T digamma_positive_domain(T x) { - + template HOSTDEVICE T digamma(T x) { - static T pi = T{3.14159265358979323846}; + const static T pi = T{3.14159265358979323846}; - + if (x == T{0.0}) { T inf = std::numeric_limits::infinity(); diff --git a/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h b/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h -index 5ebbc8d2db..48acf8d0cd 100644 +index 5ebbc8d2db..c7b6c338e2 100644 --- a/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h @@ -15,8 +15,9 @@ limitations under the License. */ @@ -1049,11 +1088,11 @@ index 5ebbc8d2db..48acf8d0cd 100644 +#include "kernels/funcs/blas/cublaslt.h" +#include "kernels/funcs/quant_dequant.h" +#include "kernels/metax_kernel/metax_context.h" - + #pragma once - + @@ -668,7 +669,7 @@ void LLMGemm(const phi::GPUContext& dev_ctx, - + { auto helper = - std::make_unique(m, k, n, dev_ctx.cublaslt_handle()); @@ -1067,12 +1106,12 @@ index 1f319c4ae3..9186eb6906 100644 +++ b/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h @@ -15,7 +15,7 @@ limitations under the License. 
*/ #pragma once - + #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/matrix_inverse.h" - + namespace phi { diff --git a/paddle/phi/kernels/impl/matrix_power_kernel_impl.h b/paddle/phi/kernels/impl/matrix_power_kernel_impl.h index 6f03f76eeb..5fe2c3e7dc 100644 @@ -1080,13 +1119,13 @@ index 6f03f76eeb..5fe2c3e7dc 100644 +++ b/paddle/phi/kernels/impl/matrix_power_kernel_impl.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once - + #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/matrix_inverse.h" - + diff --git a/paddle/phi/kernels/impl/merged_momentum_impl.h b/paddle/phi/kernels/impl/merged_momentum_impl.h index 7b85903776..3f4b298807 100644 --- a/paddle/phi/kernels/impl/merged_momentum_impl.h @@ -1118,31 +1157,11 @@ index 4099d8b506..baef2cd643 100644 --- a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h +++ b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h @@ -14,7 +14,7 @@ - + #pragma once - + -#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/math_function.h" - -diff --git a/third_party/flagcx b/third_party/flagcx -index 7c469f4af9..7e6c4cc3ca 160000 ---- a/third_party/flagcx -+++ b/third_party/flagcx -@@ -1 +1 @@ --Subproject commit 7c469f4af991bf0f64b8f76d66f8e307a5eaea3f -+Subproject commit 7e6c4cc3cad3fce9b3dedfe46a9d195d616e8ffa -diff --git a/third_party/flashattn b/third_party/flashattn -index 581e48aa69..749aca3807 160000 ---- a/third_party/flashattn -+++ b/third_party/flashattn -@@ -1 +1 @@ --Subproject commit 581e48aa693a17ec3676ec2715d46130310d318d -+Subproject commit 749aca380794b472096d4e7ea01dd252ab0887c9 -diff --git a/third_party/yaml-cpp b/third_party/yaml-cpp ---- a/third_party/yaml-cpp -+++ b/third_party/yaml-cpp -@@ -1 +1 @@ --Subproject commit 1d8ca1f35eb3a9c9142462b28282a848e5d29a91 -+Subproject commit 1d8ca1f35eb3a9c9142462b28282a848e5d29a91-dirty + From e61cf0d1fd68493ba100b6b12963d7a405437cf1 Mon Sep 17 00:00:00 2001 From: jiaxinWang-metax <189149612@qq.com> Date: Wed, 17 Sep 2025 17:22:59 +0800 Subject: [PATCH 06/12] modify fused_bias_dropout_residual_layer_norm --- backends/metax_gpu/patch/paddle.patch | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/backends/metax_gpu/patch/paddle.patch b/backends/metax_gpu/patch/paddle.patch index 667d9f75d1c..beefb730bf7 100755 --- a/backends/metax_gpu/patch/paddle.patch +++ b/backends/metax_gpu/patch/paddle.patch @@ -470,6 +470,24 @@ index 88663ec880..98b93072a3 100644 #include "paddle/phi/kernels/funcs/detail/gru_gpu_kernel.h" #include "paddle/phi/kernels/funcs/detail/gru_kernel.h" +diff --git a/paddle/phi/kernels/funcs/layer_norm_impl.cu.h b/paddle/phi/kernels/funcs/layer_norm_impl.cu.h +index 4eae698648..5c047723ea 100644 +--- a/paddle/phi/kernels/funcs/layer_norm_impl.cu.h ++++ b/paddle/phi/kernels/funcs/layer_norm_impl.cu.h +@@ -43,11 +43,11 @@ template + using LayerNormParamType = typename CudnnDataType::BatchNormParamType; + + inline static int GetDesiredBlockDim(int64_t block_dim) { +- const int kMaxBlockDim = 512; ++ const int kMaxBlockDim = 256; + #ifdef __HIPCC__ + const int lwarpSize = 64; + #else +- const int lwarpSize = 32; ++ const int lwarpSize = 64; + #endif + return 
block_dim >= kMaxBlockDim ? kMaxBlockDim : lwarpSize; + } diff --git a/paddle/phi/kernels/funcs/math/context_project.h b/paddle/phi/kernels/funcs/math/context_project.h index 15e1a4a3c3..e4780538d7 100644 --- a/paddle/phi/kernels/funcs/math/context_project.h From 2757fb7e20ec84bad839d4dd88bf865c695f2126 Mon Sep 17 00:00:00 2001 From: jiaxinWang-metax <189149612@qq.com> Date: Thu, 18 Sep 2025 16:42:26 +0800 Subject: [PATCH 07/12] modify compile --- backends/metax_gpu/CMakeLists.txt | 40 +++---------------- backends/metax_gpu/build.sh | 2 +- backends/metax_gpu/compile.sh | 2 +- .../fused_adam_kernel_register.cu | 0 ...esidual_layer_norm_grad_kernel_register.cu | 0 ...out_residual_layer_norm_kernel_register.cu | 0 ...dding_eltwise_layernorm_kernel_register.cu | 0 .../fused_layernorm_kernel_register.cu | 0 .../fused_seqpool_cvm_grad_kernel_register.cu | 0 .../fused_seqpool_cvm_kernel_register.cu | 0 ...fused_softmax_mask_grad_kernel_register.cu | 0 .../fused_softmax_mask_kernel_register.cu | 0 ...max_mask_upper_triangle_kernel_register.cu | 0 ...d_stack_transpose_quant_kernel_register.cu | 0 ...sed_swiglu_weighted_bwd_kernel_register.cu | 30 ++++++++++++++ .../fused_token_prune_kernel_register.cu | 0 ...d_transpose_split_quant_kernel_register.cu | 0 ...nspose_wlch_split_quant_kernel_register.cu | 0 18 files changed, 37 insertions(+), 37 deletions(-) rename backends/metax_gpu/kernels/{cuda_kernels => fusion}/fused_adam_kernel_register.cu (100%) rename backends/metax_gpu/kernels/{cuda_kernels => fusion}/fused_bias_dropout_residual_layer_norm_grad_kernel_register.cu (100%) rename backends/metax_gpu/kernels/{cuda_kernels => fusion}/fused_bias_dropout_residual_layer_norm_kernel_register.cu (100%) rename backends/metax_gpu/kernels/{cuda_kernels => fusion}/fused_embedding_eltwise_layernorm_kernel_register.cu (100%) rename backends/metax_gpu/kernels/{cuda_kernels => fusion}/fused_layernorm_kernel_register.cu (100%) rename backends/metax_gpu/kernels/{cuda_kernels => fusion}/fused_seqpool_cvm_grad_kernel_register.cu (100%) rename backends/metax_gpu/kernels/{cuda_kernels => fusion}/fused_seqpool_cvm_kernel_register.cu (100%) rename backends/metax_gpu/kernels/{cuda_kernels => fusion}/fused_softmax_mask_grad_kernel_register.cu (100%) rename backends/metax_gpu/kernels/{cuda_kernels => fusion}/fused_softmax_mask_kernel_register.cu (100%) rename backends/metax_gpu/kernels/{cuda_kernels => fusion}/fused_softmax_mask_upper_triangle_kernel_register.cu (100%) rename backends/metax_gpu/kernels/{cuda_kernels => fusion}/fused_stack_transpose_quant_kernel_register.cu (100%) create mode 100644 backends/metax_gpu/kernels/fusion/fused_swiglu_weighted_bwd_kernel_register.cu rename backends/metax_gpu/kernels/{cuda_kernels => fusion}/fused_token_prune_kernel_register.cu (100%) rename backends/metax_gpu/kernels/{cuda_kernels => fusion}/fused_transpose_split_quant_kernel_register.cu (100%) rename backends/metax_gpu/kernels/{cuda_kernels => fusion}/fused_transpose_wlch_split_quant_kernel_register.cu (100%) diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt index f282a9fbf7c..7b8c52f1f31 100755 --- a/backends/metax_gpu/CMakeLists.txt +++ b/backends/metax_gpu/CMakeLists.txt @@ -70,7 +70,6 @@ include(eigen) include(xxhash) include(zlib) include(protobuf) -include(generate_pb) set(PROTO_FILE "${PADDLE_SOURCE_DIR}/paddle/phi/core/external_error.proto") get_filename_component(PROTO_WE "${PROTO_FILE}" NAME_WE) @@ -614,12 +613,9 @@ file( ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/math_function.cc 
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/log_softmax_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/log_softmax_grad_kernel.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/backends/context_pool.cc ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/repeat_tensor2index_tensor.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/binomial_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/bernoulli_kernel.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/bmm_grad_kernel_impl.h - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/bmm_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cufft.cc ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/box_coder_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu @@ -642,29 +638,11 @@ file( ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/gather_tree_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/graph_reindex_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/group_norm_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/gumbel_softmax_grad_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_act_dequant_kernel.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_weighted_swiglu_act_quant_kernel.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_elemwise_activation_kernel.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_grad_kernel.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/fp8_gemm/fp8_gemm_with_cublasLt/fp8_fp8_half_gemm.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_grad_kernel.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/fused_conv2d_add_act_kernel.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/variable_length_memory_efficient_attention_kernel.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_kernel.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/gemm_epilogue_kernel.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_elemwise_activation_grad_kernel.cu - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/stride/as_real_kernel.cc - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/stride/as_complex_kernel.cc - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/stride/complex_grad_kernel.cc - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/stride/complex_kernel.cc - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/shape_kernel.cc - # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/conv_kernel_igemm.cu + ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu + ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/rms_norm_kernel.cu # ############################################################################ ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu # kernels/kps @@ -697,7 +675,6 @@ file( ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/legacy/gpu/cal_aux_loss_grad_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/legacy/gpu/expand_modality_expert_id_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/legacy/gpu/int_bincount_kernel.cu - ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu 
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu) file( @@ -707,6 +684,8 @@ file( passes/*.cc kernels/*.cc kernels/*.cu + kernels/fusion/*.cc + kernels/fusion/*.cu kernels/gpudnn/*.cc kernels/gpudnn/*.cu kernels/cuda_kernels/*.cc @@ -721,13 +700,7 @@ set_source_files_properties(${CUSTOM_DEVICE_SRCS} PROPERTIES LANGUAGE CUDA) set(CMAKE_CUCC_COMPILER "cucc") set(CMAKE_CUCC_FLAGS "-I /opt/maca/tools/cu-bridge/include/") -set_source_files_properties( - ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/rms_norm_kernel.cu - PROPERTIES LANGUAGE CUDA) -add_library( - ${TARGET_NAME} SHARED - ${CUSTOM_DEVICE_SRCS} - ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/rms_norm_kernel.cu) +add_library(${TARGET_NAME} SHARED ${CUSTOM_DEVICE_SRCS}) target_include_directories( ${TARGET_NAME} @@ -753,9 +726,6 @@ target_link_libraries( ${WARPCTC_LIBRARIES} ${WARPRNNT_LIBRARIES} ${PADDLE_CORE_LIB}) -target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmccl.so) -target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmcFlashAttn.so) -target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmcpti.so) include_directories(BEFORE ${PADDLE_SOURCE_DIR}) target_compile_definitions( diff --git a/backends/metax_gpu/build.sh b/backends/metax_gpu/build.sh index e3c4304e5f8..6bd36d6cfba 100755 --- a/backends/metax_gpu/build.sh +++ b/backends/metax_gpu/build.sh @@ -52,7 +52,7 @@ fi echo "make_maca" cd build -cmake_maca .. -DPython3_EXECUTABLE=$(which python3) -DWITH_GPU=ON +cmake_maca .. -DCMAKE_BUILD_TYPE=Release -DPython3_EXECUTABLE=$(which python3) -DWITH_GPU=ON make_maca -j60 echo "install whl" diff --git a/backends/metax_gpu/compile.sh b/backends/metax_gpu/compile.sh index e9860ccb7d0..463c2ca3402 100644 --- a/backends/metax_gpu/compile.sh +++ b/backends/metax_gpu/compile.sh @@ -30,7 +30,7 @@ fi echo "make_maca" cd build -cmake_maca .. -DPython3_EXECUTABLE=$(which python3) -DWITH_GPU=ON +cmake_maca .. 
-DCMAKE_BUILD_TYPE=Release -DPython3_EXECUTABLE=$(which python3) -DWITH_GPU=ON make_maca -j10 diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_adam_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_adam_kernel_register.cu similarity index 100% rename from backends/metax_gpu/kernels/cuda_kernels/fused_adam_kernel_register.cu rename to backends/metax_gpu/kernels/fusion/fused_adam_kernel_register.cu diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_bias_dropout_residual_layer_norm_grad_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_bias_dropout_residual_layer_norm_grad_kernel_register.cu similarity index 100% rename from backends/metax_gpu/kernels/cuda_kernels/fused_bias_dropout_residual_layer_norm_grad_kernel_register.cu rename to backends/metax_gpu/kernels/fusion/fused_bias_dropout_residual_layer_norm_grad_kernel_register.cu diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_bias_dropout_residual_layer_norm_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_bias_dropout_residual_layer_norm_kernel_register.cu similarity index 100% rename from backends/metax_gpu/kernels/cuda_kernels/fused_bias_dropout_residual_layer_norm_kernel_register.cu rename to backends/metax_gpu/kernels/fusion/fused_bias_dropout_residual_layer_norm_kernel_register.cu diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_embedding_eltwise_layernorm_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_embedding_eltwise_layernorm_kernel_register.cu similarity index 100% rename from backends/metax_gpu/kernels/cuda_kernels/fused_embedding_eltwise_layernorm_kernel_register.cu rename to backends/metax_gpu/kernels/fusion/fused_embedding_eltwise_layernorm_kernel_register.cu diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_layernorm_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_layernorm_kernel_register.cu similarity index 100% rename from backends/metax_gpu/kernels/cuda_kernels/fused_layernorm_kernel_register.cu rename to backends/metax_gpu/kernels/fusion/fused_layernorm_kernel_register.cu diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_seqpool_cvm_grad_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_seqpool_cvm_grad_kernel_register.cu similarity index 100% rename from backends/metax_gpu/kernels/cuda_kernels/fused_seqpool_cvm_grad_kernel_register.cu rename to backends/metax_gpu/kernels/fusion/fused_seqpool_cvm_grad_kernel_register.cu diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_seqpool_cvm_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_seqpool_cvm_kernel_register.cu similarity index 100% rename from backends/metax_gpu/kernels/cuda_kernels/fused_seqpool_cvm_kernel_register.cu rename to backends/metax_gpu/kernels/fusion/fused_seqpool_cvm_kernel_register.cu diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_softmax_mask_grad_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_softmax_mask_grad_kernel_register.cu similarity index 100% rename from backends/metax_gpu/kernels/cuda_kernels/fused_softmax_mask_grad_kernel_register.cu rename to backends/metax_gpu/kernels/fusion/fused_softmax_mask_grad_kernel_register.cu diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_softmax_mask_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_softmax_mask_kernel_register.cu similarity index 100% rename from backends/metax_gpu/kernels/cuda_kernels/fused_softmax_mask_kernel_register.cu rename to backends/metax_gpu/kernels/fusion/fused_softmax_mask_kernel_register.cu 
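The fused-kernel register files moved into kernels/fusion/ above, and the new fused_swiglu_weighted_bwd register file that follows, all share one pattern: include the upstream GPU kernel source and re-register it under the metax_gpu backend, relying on the kernels/fusion/*.cc and kernels/fusion/*.cu globs added to CMakeLists.txt earlier in this commit to pick the files up. A schematic sketch of that pattern; the file name, kernel name, and dtype list are placeholders, not part of the patch:

```cpp
// fused_example_kernel_register.cu -- hypothetical file, for illustration only.
#include "paddle/phi/core/kernel_registry.h"
// Pull in the upstream CUDA implementation so its symbols are compiled
// into this backend's shared library.
#include "paddle/phi/kernels/fusion/gpu/fused_example_kernel.cu"  // NOLINT

// Re-register the same kernel functor under the metax_gpu backend key;
// the trailing brace block may stay empty or override output dtypes.
PD_CUSTOM_KERNEL_REGISTER(fused_example,
                          metax_gpu,
                          ALL_LAYOUT,
                          phi::FusedExampleKernel,
                          float,
                          double) {}
```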
diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_softmax_mask_upper_triangle_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_softmax_mask_upper_triangle_kernel_register.cu similarity index 100% rename from backends/metax_gpu/kernels/cuda_kernels/fused_softmax_mask_upper_triangle_kernel_register.cu rename to backends/metax_gpu/kernels/fusion/fused_softmax_mask_upper_triangle_kernel_register.cu diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_stack_transpose_quant_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_stack_transpose_quant_kernel_register.cu similarity index 100% rename from backends/metax_gpu/kernels/cuda_kernels/fused_stack_transpose_quant_kernel_register.cu rename to backends/metax_gpu/kernels/fusion/fused_stack_transpose_quant_kernel_register.cu diff --git a/backends/metax_gpu/kernels/fusion/fused_swiglu_weighted_bwd_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_swiglu_weighted_bwd_kernel_register.cu new file mode 100644 index 00000000000..08876233bfb --- /dev/null +++ b/backends/metax_gpu/kernels/fusion/fused_swiglu_weighted_bwd_kernel_register.cu @@ -0,0 +1,30 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/fusion/gpu/fused_swiglu_weighted_bwd_kernel.cu" //NOLINT + +PD_CUSTOM_KERNEL_REGISTER(fused_swiglu_weighted_bwd, + metax_gpu, + ALL_LAYOUT, + phi::FusedSwigluWeightedBwdKernel, + float, + double, + int, + int64_t, + phi::bfloat16) { + kernel->OutputAt(0).SetDataType(phi::DataType::BFLOAT16); + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::BFLOAT16); +} diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_token_prune_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_token_prune_kernel_register.cu similarity index 100% rename from backends/metax_gpu/kernels/cuda_kernels/fused_token_prune_kernel_register.cu rename to backends/metax_gpu/kernels/fusion/fused_token_prune_kernel_register.cu diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_transpose_split_quant_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_transpose_split_quant_kernel_register.cu similarity index 100% rename from backends/metax_gpu/kernels/cuda_kernels/fused_transpose_split_quant_kernel_register.cu rename to backends/metax_gpu/kernels/fusion/fused_transpose_split_quant_kernel_register.cu diff --git a/backends/metax_gpu/kernels/cuda_kernels/fused_transpose_wlch_split_quant_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_transpose_wlch_split_quant_kernel_register.cu similarity index 100% rename from backends/metax_gpu/kernels/cuda_kernels/fused_transpose_wlch_split_quant_kernel_register.cu rename to backends/metax_gpu/kernels/fusion/fused_transpose_wlch_split_quant_kernel_register.cu From b2b41c269cbdc294b477bff791b499956808c795 Mon Sep 17 00:00:00 2001 From: jiaxinWang-metax <189149612@qq.com> Date: Fri, 19 Sep 2025 18:25:32 +0800 Subject: [PATCH 08/12] modify blas --- backends/metax_gpu/compile.sh | 2 +- .../kernels/funcs/blas/blas_impl.cu.h | 1270 ++++++++--------- .../kernels/metax_kernel/metax_context.cc | 35 - .../kernels/metax_kernel/metax_context.h | 2 - 4 files changed, 562 insertions(+), 747 deletions(-) mode change 100755 => 100644 backends/metax_gpu/kernels/funcs/blas/blas_impl.cu.h diff --git a/backends/metax_gpu/compile.sh b/backends/metax_gpu/compile.sh index 463c2ca3402..eba45a9ced2 100644 --- a/backends/metax_gpu/compile.sh +++ b/backends/metax_gpu/compile.sh @@ -30,7 +30,7 @@ fi echo "make_maca" cd build -cmake_maca .. -DCMAKE_BUILD_TYPE=Release -DPython3_EXECUTABLE=$(which python3) -DWITH_GPU=ON +cmake_maca .. 
-DCMAKE_BUILD_TYPE=Release -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DPython3_EXECUTABLE=$(which python3) -DWITH_GPU=ON make_maca -j10 diff --git a/backends/metax_gpu/kernels/funcs/blas/blas_impl.cu.h b/backends/metax_gpu/kernels/funcs/blas/blas_impl.cu.h old mode 100755 new mode 100644 index 419387cc9c4..ae4baa52613 --- a/backends/metax_gpu/kernels/funcs/blas/blas_impl.cu.h +++ b/backends/metax_gpu/kernels/funcs/blas/blas_impl.cu.h @@ -34,70 +34,6 @@ PHI_DECLARE_bool(gemm_use_half_precision_compute_type); namespace phi { namespace funcs { - -inline static cublasHandle_t blas_handle_ = nullptr; -inline static cublasHandle_t blas_tensor_core_handle_ = nullptr; -inline static cublasHandle_t blas_tf32_tensor_core_handle_ = nullptr; - -inline std::once_flag flag_sparse_; -inline std::once_flag flag_blas_; -inline std::once_flag flag_blaslt_; -inline std::once_flag flag_dnn_; -inline std::once_flag flag_solver_; -inline std::once_flag flag_cublas_; -inline std::once_flag flag_tensorcore_cublas_; -inline std::once_flag flag_eigen_device_; - -inline std::mutex blas_mtx_; -inline std::mutex blas_tensor_core_mtx_; -inline std::mutex blas_tf32_mtx_; -inline std::mutex sparse_mtx_; -inline std::mutex stream_call_back_mtx_; - -inline void InitBlasHandle(cublasHandle_t *blas_handle, gpuStream_t stream) { - PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasCreate(blas_handle)); - PADDLE_RETRY_CUDA_SUCCESS( - phi::dynload::cublasSetStream(*blas_handle, stream)); -} - -inline void CublasCall(const std::function &callback, - phi::stream::stream_t stream) { - std::call_once(flag_cublas_, [&]() { - if (!blas_handle_) InitBlasHandle(&blas_handle_, stream); - if (!blas_tensor_core_handle_) { - InitBlasHandle(&blas_tensor_core_handle_, stream); - PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( - blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH)); - } - }); - std::lock_guard guard(blas_mtx_); - callback(blas_handle_); -} - -inline bool MetaxTensorCoreAvailable() { - return blas_tensor_core_handle_ != nullptr; -} - -inline void TensorCoreCublasCallIfAvailable( - const std::function &callback, - phi::stream::stream_t stream) { - std::call_once(flag_tensorcore_cublas_, [&]() { - if (!blas_handle_) InitBlasHandle(&blas_handle_, stream); - if (!blas_tensor_core_handle_) { - InitBlasHandle(&blas_tensor_core_handle_, stream); - PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( - blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH)); - } - }); - if (blas_tensor_core_handle_ != nullptr) { - std::lock_guard guard(blas_tensor_core_mtx_); - callback(blas_tensor_core_handle_); - } else { - std::lock_guard guard(blas_mtx_); - callback(blas_handle_); - } -} - template struct CUBlas; @@ -174,28 +110,26 @@ struct CUBlas { // here. #if CUDA_VERSION >= 8000 VLOG(5) << "use_tensor_op_math: " - << (MetaxTensorCoreAvailable() ? "True" : "False"); - TensorCoreCublasCallIfAvailable( - [&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemmEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A, - Atype, - lda, - B, - Btype, - ldb, - beta, - C, - Ctype, - ldc)); - }, - dev_ctx->stream()); + << (dev_ctx->tensor_core_available() ? 
"True" : "False"); + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemmEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc)); + }); #else PADDLE_THROW(phi::errors::Unimplemented( "cublasSgemmEx is not supported on cuda <= 7.5")); @@ -376,7 +310,7 @@ struct CUBlas { #if CUDA_VERSION >= 8000 cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; #if CUDA_VERSION >= 9000 - bool use_tensor_op_math = MetaxTensorCoreAvailable(); + bool use_tensor_op_math = dev_ctx->tensor_core_available(); if (use_tensor_op_math) { algo = CUBLAS_GEMM_DFALT_TENSOR_OP; } @@ -386,31 +320,29 @@ struct CUBlas { thrust::device_vector A_ptr(A, A + batchCount); thrust::device_vector B_ptr(B, B + batchCount); thrust::device_vector C_ptr(C, C + batchCount); - TensorCoreCublasCallIfAvailable( - [&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmBatchedEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A_ptr.data().get(), - Atype, - lda, - B_ptr.data().get(), - Btype, - ldb, - beta, - C_ptr.data().get(), - Ctype, - ldc, - batchCount, - computeType, - algo)); - }, - dev_ctx->stream()); + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cublasGemmBatchedEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A_ptr.data().get(), + Atype, + lda, + B_ptr.data().get(), + Btype, + ldb, + beta, + C_ptr.data().get(), + Ctype, + ldc, + batchCount, + computeType, + algo)); + }); #else PADDLE_THROW(phi::errors::Unimplemented( "cublasGemmBatchedEx is not supported on cuda <= 7.5")); @@ -486,7 +418,7 @@ struct CUBlas { #if CUDA_VERSION >= 8000 cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; #if CUDA_VERSION >= 9000 - bool use_tensor_op_math = MetaxTensorCoreAvailable(); + bool use_tensor_op_math = dev_ctx->tensor_core_available(); if (use_tensor_op_math) { algo = CUBLAS_GEMM_DFALT_TENSOR_OP; } @@ -494,29 +426,27 @@ struct CUBlas { << (use_tensor_op_math ? "True" : "False"); #endif // CUDA_VERSION >= 9000 - TensorCoreCublasCallIfAvailable( - [&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A, - Atype, - lda, - B, - Btype, - ldb, - beta, - C, - Ctype, - ldc, - computeType, - algo)); - }, - dev_ctx->stream()); + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + computeType, + algo)); + }); #else PADDLE_THROW(phi::errors::Unimplemented( "cublasGemmEx is not supported on cuda <= 7.5")); @@ -696,7 +626,7 @@ struct CUBlas> { #if CUDA_VERSION >= 8000 cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; #if CUDA_VERSION >= 9000 - bool use_tensor_op_math = MetaxTensorCoreAvailable(); + bool use_tensor_op_math = dev_ctx->tensor_core_available(); if (use_tensor_op_math) { algo = CUBLAS_GEMM_DFALT_TENSOR_OP; } @@ -704,29 +634,27 @@ struct CUBlas> { << (use_tensor_op_math ? 
"True" : "False"); #endif // CUDA_VERSION >= 9000 - TensorCoreCublasCallIfAvailable( - [&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A, - Atype, - lda, - B, - Btype, - ldb, - beta, - C, - Ctype, - ldc, - computeType, - algo)); - }, - dev_ctx->stream()); + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + computeType, + algo)); + }); #else PADDLE_THROW(phi::errors::Unimplemented( "cublasGemmEx is not supported on cuda <= 7.5")); @@ -1024,7 +952,7 @@ struct CUBlas> { #if CUDA_VERSION >= 8000 cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; #if CUDA_VERSION >= 9000 - bool use_tensor_op_math = MetaxTensorCoreAvailable(); + bool use_tensor_op_math = dev_ctx->tensor_core_available(); if (use_tensor_op_math) { algo = CUBLAS_GEMM_DFALT_TENSOR_OP; } @@ -1032,29 +960,27 @@ struct CUBlas> { << (use_tensor_op_math ? "True" : "False"); #endif // CUDA_VERSION >= 9000 - TensorCoreCublasCallIfAvailable( - [&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A, - Atype, - lda, - B, - Btype, - ldb, - beta, - C, - Ctype, - ldc, - computeType, - algo)); - }, - dev_ctx->stream()); + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + computeType, + algo)); + }); #else PADDLE_THROW(phi::errors::Unimplemented( "cublasGemmEx is not supported on cuda <= 7.5")); @@ -1186,24 +1112,22 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, PADDLE_THROW(common::errors::Unimplemented( "GEMM_EX_64 is not supported on cuda < 12.3")); } else { - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - ldb, - A, - lda, - &beta, - C, - N); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + N); + }); } #if CUDA_VERSION >= 8000 @@ -1271,24 +1195,22 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, #else // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - N, - M, - K, - &h_alpha, - h_B, - ldb, - h_A, - lda, - &h_beta, - h_C, - N); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + h_B, + ldb, + h_A, + lda, + &h_beta, + h_C, + N); + }); #endif // CUDA_VERSION >= 8000 } @@ -1352,24 +1274,22 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, PADDLE_THROW(common::errors::Unimplemented( "GEMM_EX_64 is not supported on cuda < 12.3")); } else { - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &t_alpha, - B, - static_cast(ldb), - A, - static_cast(lda), - &t_beta, - C, - static_cast(N)); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + static_cast(N), + static_cast(M), + static_cast(K), + &t_alpha, + B, + 
static_cast(ldb), + A, + static_cast(lda), + &t_beta, + C, + static_cast(N)); + }); } #if CUDA_VERSION >= 8000 @@ -1447,24 +1367,22 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, CUBLAS_COMPUTE_32F); #else // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - h_B, - static_cast(ldb), - h_A, - static_cast(lda), - &h_beta, - h_C, - static_cast(N)); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + static_cast(N), + static_cast(M), + static_cast(K), + &h_alpha, + h_B, + static_cast(ldb), + h_A, + static_cast(lda), + &h_beta, + h_C, + static_cast(N)); + }); #endif // CUDA_VERSION >= 8000 } } @@ -1503,7 +1421,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, float h_beta = static_cast(beta); cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; - bool use_tensor_op_math = MetaxTensorCoreAvailable(); + bool use_tensor_op_math = dev_ctx_.tensor_core_available(); if (use_tensor_op_math) { algo = CUBLAS_GEMM_DFALT_TENSOR_OP; } @@ -1519,30 +1437,27 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, #endif // CUDA_VERSION >= 12030 } else { CheckGEMMNSize(N); - TensorCoreCublasCallIfAvailable( - [&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmEx(handle, - cuTransB, - cuTransA, - N, - M, - K, - &h_alpha, - B, - CUDA_R_16BF, - ldb, - A, - CUDA_R_16BF, - lda, - &h_beta, - C, - CUDA_R_16BF, - N, - CUBLAS_COMPUTE_32F, - algo)); - }, - dev_ctx_.stream()); + dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + CUDA_R_16BF, + ldb, + A, + CUDA_R_16BF, + lda, + &h_beta, + C, + CUDA_R_16BF, + N, + CUBLAS_COMPUTE_32F, + algo)); + }); } #else // raise error @@ -1621,24 +1536,22 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, #else // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - CublasCall( - [&](cublasHandle_t handle) { - CUBlas>::GEMM(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &c_alpha, - h_B, - static_cast(ldb), - h_A, - static_cast(lda), - &c_beta, - h_C, - static_cast(N)); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas>::GEMM(handle, + cuTransB, + cuTransA, + static_cast(N), + static_cast(M), + static_cast(K), + &c_alpha, + h_B, + static_cast(ldb), + h_A, + static_cast(lda), + &c_beta, + h_C, + static_cast(N)); + }); #endif // CUDA_VERSION >= 8000 } } @@ -1713,24 +1626,22 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, #else // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - CublasCall( - [&](cublasHandle_t handle) { - CUBlas>::GEMM(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &c_alpha, - h_B, - static_cast(ldb), - h_A, - static_cast(lda), - &c_beta, - h_C, - static_cast(N)); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas>::GEMM(handle, + cuTransB, + cuTransA, + static_cast(N), + static_cast(M), + static_cast(K), + &c_alpha, + h_B, + static_cast(ldb), + h_A, + static_cast(lda), + &c_beta, + h_C, + static_cast(N)); + }); #endif // CUDA_VERSION >= 8000 } } @@ -1769,7 +1680,7 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, float h_beta = beta; cublasGemmAlgo_t algo = 
CUBLAS_GEMM_DEFAULT; - bool use_tensor_op_math = MetaxTensorCoreAvailable(); + bool use_tensor_op_math = dev_ctx_.tensor_core_available(); if (use_tensor_op_math) { algo = CUBLAS_GEMM_DFALT_TENSOR_OP; } @@ -1784,30 +1695,28 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, #endif // CUDA_VERSION >= 12030 } else { CheckGEMMNSize(N); - TensorCoreCublasCallIfAvailable( - [&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmEx(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - B, - CUDA_R_16BF, - static_cast(ldb), - A, - CUDA_R_16BF, - static_cast(lda), - &h_beta, - C, - CUDA_R_16BF, - static_cast(N), - CUDA_R_32F, - algo)); - }, - dev_ctx_.stream()); + dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cublasGemmEx(handle, + cuTransB, + cuTransA, + static_cast(N), + static_cast(M), + static_cast(K), + &h_alpha, + B, + CUDA_R_16BF, + static_cast(ldb), + A, + CUDA_R_16BF, + static_cast(lda), + &h_beta, + C, + CUDA_R_16BF, + static_cast(N), + CUDA_R_32F, + algo)); + }); } #else // raise error @@ -1860,24 +1769,22 @@ void Blas::GEMM(bool transA, } else { #endif // CUDA_VERSION >= 8000 - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - ldb, - A, - lda, - &beta, - C, - ldc); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + ldc); + }); #if CUDA_VERSION >= 8000 } @@ -1904,24 +1811,22 @@ inline void Blas::GEMM(bool transA, cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::GEMM(handle, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - ldb, - A, - lda, - &beta, - C, - ldc); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + ldc); + }); } template <> @@ -1957,36 +1862,33 @@ inline void Blas::GEMM(bool transA, float h_beta = static_cast(beta); cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; - bool use_tensor_op_math = MetaxTensorCoreAvailable(); + bool use_tensor_op_math = dev_ctx_.tensor_core_available(); if (use_tensor_op_math) { algo = CUBLAS_GEMM_DFALT_TENSOR_OP; } VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? 
"True" : "False"); - TensorCoreCublasCallIfAvailable( - [&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmEx(handle, - cuTransB, - cuTransA, - N, - M, - K, - &h_alpha, - B, - CUDA_R_16BF, - ldb, - A, - CUDA_R_16BF, - lda, - &h_beta, - C, - CUDA_R_16BF, - ldc, - CUBLAS_COMPUTE_32F, - algo)); - }, - dev_ctx_.stream()); + dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + CUDA_R_16BF, + ldb, + A, + CUDA_R_16BF, + lda, + &h_beta, + C, + CUDA_R_16BF, + ldc, + CUBLAS_COMPUTE_32F, + algo)); + }); #else // raise error PADDLE_THROW(phi::errors::Unimplemented( @@ -1998,27 +1900,23 @@ inline void Blas::GEMM(bool transA, template <> template void Blas::AXPY(int n, T alpha, const T *x, T *y) const { - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::AXPY(handle, n, &alpha, x, 1, y, 1); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::AXPY(handle, n, &alpha, x, 1, y, 1); + }); } template <> template void Blas::SCAL(int n, const T alpha, T *x) const { - CublasCall( - [&](cublasHandle_t handle) { CUBlas::SCAL(handle, n, &alpha, x, 1); }, - dev_ctx_.stream()); + dev_ctx_.CublasCall( + [&](cublasHandle_t handle) { CUBlas::SCAL(handle, n, &alpha, x, 1); }); } template <> template void Blas::VCOPY(int n, const T *x, T *y) const { - CublasCall( - [&](cublasHandle_t handle) { CUBlas::VCOPY(handle, n, x, 1, y, 1); }, - dev_ctx_.stream()); + dev_ctx_.CublasCall( + [&](cublasHandle_t handle) { CUBlas::VCOPY(handle, n, x, 1, y, 1); }); } template <> @@ -2033,12 +1931,9 @@ void Blas::GEMV(bool trans_a, T *C) const { cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::GEMV( - handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1); + }); } template <> @@ -2112,7 +2007,7 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same::value)) || std::is_same::value) { cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = MetaxTensorCoreAvailable(); + bool use_tensor_op_math = dev_ctx_.tensor_core_available(); if (use_tensor_op_math) { algo = CUBLAS_GEMM_DFALT_TENSOR_OP; } @@ -2153,60 +2048,56 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, "cublasGemmStridedBatchedEx_64 is not supported on cuda < 12.3")); #endif // CUDA_VERSION >= 12030 } else { - TensorCoreCublasCallIfAvailable( - [&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmStridedBatchedEx(handle, - cuTransB, - cuTransA, - N, - M, - K, - a, - B, - fp, - ldb, - strideB, - A, - fp, - lda, - strideA, - b, - C, - fp, - ldc, - strideC, - batchCount, - compute_type, - algo)); - }, - dev_ctx_.stream()); + dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cublasGemmStridedBatchedEx(handle, + cuTransB, + cuTransA, + N, + M, + K, + a, + B, + fp, + ldb, + strideB, + A, + fp, + lda, + strideA, + b, + C, + fp, + ldc, + strideC, + batchCount, + compute_type, + algo)); + }); } } else { #endif // CUDA_VERSION >= 9010 - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::GEMM_STRIDED_BATCH(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &alpha, 
- B, - static_cast(ldb), - strideB, - A, - static_cast(lda), - strideA, - &beta, - C, - ldc, - strideC, - static_cast(batchCount)); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM_STRIDED_BATCH(handle, + cuTransB, + cuTransA, + static_cast(N), + static_cast(M), + static_cast(K), + &alpha, + B, + static_cast(ldb), + strideB, + A, + static_cast(lda), + strideA, + &beta, + C, + ldc, + strideC, + static_cast(batchCount)); + }); #if CUDA_VERSION >= 9010 } @@ -2242,7 +2133,7 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same::value)) || std::is_same::value) { cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = MetaxTensorCoreAvailable(); + bool use_tensor_op_math = dev_ctx_.tensor_core_available(); if (use_tensor_op_math) { algo = CUBLAS_GEMM_DFALT_TENSOR_OP; } @@ -2284,61 +2175,57 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, "cublasGemmStridedBatchedEx_64 is not supported on cuda < 12.3")); #endif // CUDA_VERSION >= 12030 } else { - TensorCoreCublasCallIfAvailable( - [&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmStridedBatchedEx( - handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - a, - B, - fp, - static_cast(ldb), - strideB, - A, - fp, - static_cast(lda), - strideA, - b, - C, - fp, - static_cast(ldc), - strideC, - static_cast(batchCount), - compute_type, - algo)); - }, - dev_ctx_.stream()); + dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmStridedBatchedEx( + handle, + cuTransB, + cuTransA, + static_cast(N), + static_cast(M), + static_cast(K), + a, + B, + fp, + static_cast(ldb), + strideB, + A, + fp, + static_cast(lda), + strideA, + b, + C, + fp, + static_cast(ldc), + strideC, + static_cast(batchCount), + compute_type, + algo)); + }); } } else { #endif // CUDA_VERSION >= 9010 T h_alpha = static_cast(alpha); T h_beta = static_cast(beta); - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::GEMM_STRIDED_BATCH(handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - B, - static_cast(ldb), - strideB, - A, - static_cast(lda), - strideA, - &h_beta, - C, - static_cast(ldc), - strideC, - static_cast(batchCount)); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM_STRIDED_BATCH(handle, + cuTransB, + cuTransA, + static_cast(N), + static_cast(M), + static_cast(K), + &h_alpha, + B, + static_cast(ldb), + strideB, + A, + static_cast(lda), + strideA, + &h_beta, + C, + static_cast(ldc), + strideC, + static_cast(batchCount)); + }); #if CUDA_VERSION >= 9010 } @@ -2377,7 +2264,7 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, float h_beta = static_cast(beta); cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = MetaxTensorCoreAvailable(); + bool use_tensor_op_math = dev_ctx_.tensor_core_available(); if (use_tensor_op_math) { algo = CUBLAS_GEMM_DFALT_TENSOR_OP; } @@ -2392,34 +2279,32 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, "cublasGemmStridedBatchedEx_64 is not supported on cuda < 12.3")); #endif // CUDA_VERSION >= 12030 } else { - TensorCoreCublasCallIfAvailable( - [&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmStridedBatchedEx( - handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - B, - CUDA_R_16BF, - static_cast(ldb), - strideB, - A, 
- CUDA_R_16BF, - static_cast(lda), - strideA, - &h_beta, - C, - CUDA_R_16BF, - static_cast(ldc), - strideC, - static_cast(batchCount), - CUBLAS_COMPUTE_32F, - algo)); - }, - dev_ctx_.stream()); + dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cublasGemmStridedBatchedEx(handle, + cuTransB, + cuTransA, + static_cast(N), + static_cast(M), + static_cast(K), + &h_alpha, + B, + CUDA_R_16BF, + static_cast(ldb), + strideB, + A, + CUDA_R_16BF, + static_cast(lda), + strideA, + &h_beta, + C, + CUDA_R_16BF, + static_cast(ldc), + strideC, + static_cast(batchCount), + CUBLAS_COMPUTE_32F, + algo)); + }); } #else // raise error @@ -2460,7 +2345,7 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, float h_beta = beta; cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = MetaxTensorCoreAvailable(); + bool use_tensor_op_math = dev_ctx_.tensor_core_available(); if (use_tensor_op_math) { algo = CUBLAS_GEMM_DFALT_TENSOR_OP; } @@ -2475,34 +2360,32 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, "cublasGemmStridedBatchedEx_64 is not supported on cuda < 12.3")); #endif // CUDA_VERSION >= 12030 } else { - TensorCoreCublasCallIfAvailable( - [&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmStridedBatchedEx( - handle, - cuTransB, - cuTransA, - static_cast(N), - static_cast(M), - static_cast(K), - &h_alpha, - B, - CUDA_R_16BF, - static_cast(ldb), - strideB, - A, - CUDA_R_16BF, - static_cast(lda), - strideA, - &h_beta, - C, - CUDA_R_16BF, - static_cast(ldc), - strideC, - static_cast(batchCount), - CUBLAS_COMPUTE_32F, - algo)); - }, - dev_ctx_.stream()); + dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cublasGemmStridedBatchedEx(handle, + cuTransB, + cuTransA, + static_cast(N), + static_cast(M), + static_cast(K), + &h_alpha, + B, + CUDA_R_16BF, + static_cast(ldb), + strideB, + A, + CUDA_R_16BF, + static_cast(lda), + strideA, + &h_beta, + C, + CUDA_R_16BF, + static_cast(ldc), + strideC, + static_cast(batchCount), + CUBLAS_COMPUTE_32F, + algo)); + }); } #else // raise error @@ -2547,7 +2430,7 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, // (std::is_same::value)) || // std::is_same::value) { // cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -// bool use_tensor_op_math = MetaxTensorCoreAvailable(); +// bool use_tensor_op_math = dev_ctx_.tensor_core_available(); // if (use_tensor_op_math) { // algo = CUBLAS_GEMM_DFALT_TENSOR_OP; // } @@ -2579,7 +2462,7 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, // #endif // } -// TensorCoreCublasCallIfAvailable( +// dev_ctx_.TensorCoreCublasCallIfAvailable( // [&](cublasHandle_t handle) { // PADDLE_ENFORCE_GPU_SUCCESS( // phi::dynload::cublasGemmStridedBatchedEx(handle, @@ -2605,12 +2488,11 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, // batchCount, // compute_type, // algo)); -// }, -// dev_ctx_.stream()); +// }); // } else { // #endif // CUDA_VERSION >= 9010 -// CublasCall( +// dev_ctx_.CublasCall( // [&](cublasHandle_t handle) { // CUBlas::GEMM_STRIDED_BATCH(handle, // cuTransB, @@ -2667,7 +2549,7 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, // cublasOperation_t cuTransB = // (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; // const int64_t strideC = M * N; -// CublasCall( +// dev_ctx_.CublasCall( // [&](cublasHandle_t handle) { // PADDLE_ENFORCE_GPU_SUCCESS( // phi::dynload::cublasDgemmStridedBatched(handle, @@ -2723,14 +2605,14 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, // float h_beta = static_cast(beta); // cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -// bool use_tensor_op_math = MetaxTensorCoreAvailable(); +// bool use_tensor_op_math = dev_ctx->tensor_core_available(); // if (use_tensor_op_math) { // algo = CUBLAS_GEMM_DFALT_TENSOR_OP; // } // VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : // "False"); -// TensorCoreCublasCallIfAvailable( +// dev_ctx_.TensorCoreCublasCallIfAvailable( // [&](cublasHandle_t handle) { // PADDLE_ENFORCE_GPU_SUCCESS( // phi::dynload::cublasGemmStridedBatchedEx(handle, @@ -2756,8 +2638,7 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, // batchCount, // CUBLAS_COMPUTE_32F, // algo)); -// }, -// dev_ctx_.stream()); +// }); // #else // // raise error // PADDLE_THROW(phi::errors::Unimplemented( @@ -2812,25 +2693,23 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, thrust::device_vector B_ptr(B, B + batchCount); thrust::device_vector C_ptr(C, C + batchCount); - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::GEMM_BATCH(handle, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B_ptr.data().get(), - ldb, - A_ptr.data().get(), - lda, - &beta, - C_ptr.data().get(), - ldc, - batchCount); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM_BATCH(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B_ptr.data().get(), + ldb, + A_ptr.data().get(), + lda, + &beta, + C_ptr.data().get(), + ldc, + batchCount); + }); } template <> @@ -2859,25 +2738,23 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, thrust::device_vector B_ptr(B, B + batchCount); thrust::device_vector C_ptr(C, C + batchCount); - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::GEMM_BATCH(handle, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B_ptr.data().get(), - ldb, - A_ptr.data().get(), - lda, - &beta, - C_ptr.data().get(), - ldc, - batchCount); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM_BATCH(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B_ptr.data().get(), + ldb, + A_ptr.data().get(), + lda, + &beta, + C_ptr.data().get(), + ldc, + batchCount); + }); } template <> @@ -2970,7 +2847,7 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, float f_beta = static_cast(beta); cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; - bool use_tensor_op_math = MetaxTensorCoreAvailable(); + bool use_tensor_op_math = dev_ctx_.tensor_core_available(); if (use_tensor_op_math) { algo = CUBLAS_GEMM_DFALT_TENSOR_OP; } @@ -2979,31 +2856,29 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, thrust::device_vector A_ptr(A, A + batchCount); thrust::device_vector B_ptr(B, B + batchCount); thrust::device_vector C_ptr(C, C + batchCount); - TensorCoreCublasCallIfAvailable( - [&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cublasGemmBatchedEx(handle, - cuTransB, - cuTransA, - N, - M, - K, - &f_alpha, - B_ptr.data().get(), - CUDA_R_16BF, - ldb, - A_ptr.data().get(), - CUDA_R_16BF, - lda, - &f_beta, - C_ptr.data().get(), - CUDA_R_16BF, - ldc, - batchCount, - CUBLAS_COMPUTE_32F, - algo)); - }, - dev_ctx_.stream()); + dev_ctx_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + 
phi::dynload::cublasGemmBatchedEx(handle, + cuTransB, + cuTransA, + N, + M, + K, + &f_alpha, + B_ptr.data().get(), + CUDA_R_16BF, + ldb, + A_ptr.data().get(), + CUDA_R_16BF, + lda, + &f_beta, + C_ptr.data().get(), + CUDA_R_16BF, + ldc, + batchCount, + CUBLAS_COMPUTE_32F, + algo)); + }); #else // raise error PADDLE_THROW(phi::errors::Unimplemented( @@ -3038,33 +2913,19 @@ void Blas::TRSM(CBLAS_SIDE side, cublasDiagType_t cuDiag = (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::TRSM(handle, - cuSide, - cuUplo, - cuTransA, - cuDiag, - N, - M, - &alpha, - A, - lda, - B, - ldb); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::TRSM( + handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A, lda, B, ldb); + }); } template <> template void Blas::BatchedGETRF( int n, T **a, int *ipiv, int *info, int batch_size) const { - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size); + }); } template <> @@ -3084,23 +2945,18 @@ void Blas::BatchedGETRI(int n, "overlap memory space of input matrix (address: %p).", a_inv, a)); - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::GETRI_BATCH( - handle, n, a, n, ipiv, a_inv, n, info, batch_size); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size); + }); } template <> template void Blas::BatchedMatInv( int n, const T **a, T **a_inv, int *info, int batch_size) const { - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size); + }); } template <> @@ -3118,12 +2974,10 @@ void Blas::BatchedGETRS(CBLAS_TRANSPOSE trans, // use CUBLAS_OP_C (conjugate transpose) for complex cublasOperation_t cuTrans = (trans == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::GETRS_BATCH( - handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, batch_size); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GETRS_BATCH( + handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, batch_size); + }); } template <> @@ -3152,23 +3006,21 @@ void Blas::BatchedTRSM(CBLAS_SIDE side, cublasDiagType_t cuDiag = (diag == CblasUnit) ? 
CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; - CublasCall( - [&](cublasHandle_t handle) { - CUBlas::TRSM_BATCH(handle, - cuSide, - cuUplo, - cuTransA, - cuDiag, - N, - M, - &alpha, - A, - lda, - B, - ldb, - batch_size); - }, - dev_ctx_.stream()); + dev_ctx_.CublasCall([&](cublasHandle_t handle) { + CUBlas::TRSM_BATCH(handle, + cuSide, + cuUplo, + cuTransA, + cuDiag, + N, + M, + &alpha, + A, + lda, + B, + ldb, + batch_size); + }); } } // namespace funcs diff --git a/backends/metax_gpu/kernels/metax_kernel/metax_context.cc b/backends/metax_gpu/kernels/metax_kernel/metax_context.cc index 62aaa5fb2de..a388387de45 100644 --- a/backends/metax_gpu/kernels/metax_kernel/metax_context.cc +++ b/backends/metax_gpu/kernels/metax_kernel/metax_context.cc @@ -15,25 +15,6 @@ #include "kernels/metax_kernel/metax_context.h" namespace phi { -const bool allow_tf32_cublas = []() -> bool { - const char* v = std::getenv("ALLOW_TF32_CUBLAS"); - if (v) { - return std::atoi(v); - } - return false; -}(); - -const bool allow_tf32_cudnn = []() -> bool { - const char* v = std::getenv("ALLOW_TF32_CUDNN"); - if (v) { - return std::atoi(v); - } - return false; -}(); - -bool AllowTF32Cublas() { return allow_tf32_cublas; } -bool AllowTF32Cudnn() { return allow_tf32_cudnn; } - void DnnWorkspaceHandle::RunFuncSync( const std::function& cudnn_func, size_t required_workspace_bytes, @@ -87,20 +68,4 @@ static void InitBlasLtHandle(blasLtHandle_t* blaslt_handle) { phi::dynload::hipblasLtCreate(blaslt_handle); #endif } - -blasLtHandle_t GetBlasLtHandle() { - std::call_once(flag_blaslt_, [&]() { - if (!blaslt_handle_) { - if (!blaslt_handle_creator_) - InitBlasLtHandle(&blaslt_handle_); - else - blaslt_handle_ = blaslt_handle_creator_(); - } - }); - PADDLE_ENFORCE_NOT_NULL( - blaslt_handle_, - common::errors::InvalidArgument( - "The GPU blasLt handle is nullptr. 
It must not be null.")); - return blaslt_handle_; -} } // namespace phi diff --git a/backends/metax_gpu/kernels/metax_kernel/metax_context.h b/backends/metax_gpu/kernels/metax_kernel/metax_context.h index a6610c1dab2..2339e18a4a6 100644 --- a/backends/metax_gpu/kernels/metax_kernel/metax_context.h +++ b/backends/metax_gpu/kernels/metax_kernel/metax_context.h @@ -128,8 +128,6 @@ inline void InitCusolverDnHandle(cusolverDnHandle_t* handle, } } -bool AllowTF32Cublas(); -bool AllowTF32Cudnn(); inline cusolverDnHandle_t GetCusolverDnHandle(gpuStream_t stream, Place place) { std::call_once(flag_cusolver_dn_, [&]() { if (!cusolver_dn_handle_) { From 6556cce0f0df0f394cfdaa34ab9e0161683027b3 Mon Sep 17 00:00:00 2001 From: jiaxinWang-metax <189149612@qq.com> Date: Mon, 22 Sep 2025 15:39:14 +0800 Subject: [PATCH 09/12] modify blas --- backends/metax_gpu/CMakeLists.txt | 2 ++ .../metax_gpu/kernels/metax_kernel/metax_context.cc | 12 ------------ .../metax_gpu/kernels/metax_kernel/metax_context.h | 4 +--- 3 files changed, 3 insertions(+), 15 deletions(-) diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt index 7b8c52f1f31..b98f2bcc919 100755 --- a/backends/metax_gpu/CMakeLists.txt +++ b/backends/metax_gpu/CMakeLists.txt @@ -70,6 +70,7 @@ include(eigen) include(xxhash) include(zlib) include(protobuf) +include(generate_pb) set(PROTO_FILE "${PADDLE_SOURCE_DIR}/paddle/phi/core/external_error.proto") get_filename_component(PROTO_WE "${PROTO_FILE}" NAME_WE) @@ -732,6 +733,7 @@ target_compile_definitions( ${TARGET_NAME} PUBLIC PADDLE_WITH_CUDA=1 PADDLE_WITH_CUSTOM_DEVICE=1 + mcblasContext=cublasContext GPUContext=CustomContext KPSContext=CustomContext STREAM_TYPE=cudaStream_t diff --git a/backends/metax_gpu/kernels/metax_kernel/metax_context.cc b/backends/metax_gpu/kernels/metax_kernel/metax_context.cc index a388387de45..6d86c81041f 100644 --- a/backends/metax_gpu/kernels/metax_kernel/metax_context.cc +++ b/backends/metax_gpu/kernels/metax_kernel/metax_context.cc @@ -56,16 +56,4 @@ void DnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) { allocation_.reset(); allocation_ = allocator_->Allocate(required_workspace_bytes); } - -static std::function blaslt_handle_creator_{nullptr}; -static blasLtHandle_t blaslt_handle_{nullptr}; -static std::once_flag flag_blaslt_; - -static void InitBlasLtHandle(blasLtHandle_t* blaslt_handle) { -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 - mcblasLtCreate(blaslt_handle); -#elif defined(PADDLE_WITH_HIP) - phi::dynload::hipblasLtCreate(blaslt_handle); -#endif -} } // namespace phi diff --git a/backends/metax_gpu/kernels/metax_kernel/metax_context.h b/backends/metax_gpu/kernels/metax_kernel/metax_context.h index 2339e18a4a6..376981f27a4 100644 --- a/backends/metax_gpu/kernels/metax_kernel/metax_context.h +++ b/backends/metax_gpu/kernels/metax_kernel/metax_context.h @@ -27,9 +27,7 @@ #include "paddle/phi/core/attribute.h" #include "paddle/phi/core/device_context.h" -using blasLtHandle_t = struct mcblasLtContext*; - -blasLtHandle_t GetBlasLtHandle(); +cublasLtHandle_t GetBlasLtHandle(); namespace phi { class DnnWorkspaceHandle { From 1cbe0d83214c41aef8cb96cb612662b8b5c95bf2 Mon Sep 17 00:00:00 2001 From: jiaxinWang-metax <189149612@qq.com> Date: Mon, 22 Sep 2025 15:51:51 +0800 Subject: [PATCH 10/12] modify blas --- .../kernels/metax_kernel/metax_context.h | 2 -- backends/metax_gpu/patch/paddle.patch | 15 ++------------- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git 
a/backends/metax_gpu/kernels/metax_kernel/metax_context.h b/backends/metax_gpu/kernels/metax_kernel/metax_context.h
index 376981f27a4..9af19bd7464 100644
--- a/backends/metax_gpu/kernels/metax_kernel/metax_context.h
+++ b/backends/metax_gpu/kernels/metax_kernel/metax_context.h
@@ -27,8 +27,6 @@
 #include "paddle/phi/core/attribute.h"
 #include "paddle/phi/core/device_context.h"
 
-cublasLtHandle_t GetBlasLtHandle();
-
 namespace phi {
 class DnnWorkspaceHandle {
  public:
diff --git a/backends/metax_gpu/patch/paddle.patch b/backends/metax_gpu/patch/paddle.patch
index b7bdb953077..0c776c72cc7 100755
--- a/backends/metax_gpu/patch/paddle.patch
+++ b/backends/metax_gpu/patch/paddle.patch
@@ -488,7 +488,6 @@ index 4eae698648..5c047723ea 100644
 #endif
 return block_dim >= kMaxBlockDim ? kMaxBlockDim : lwarpSize;
 }
-
 diff --git a/paddle/phi/kernels/funcs/math/context_project.h b/paddle/phi/kernels/funcs/math/context_project.h
 index 15e1a4a3c3..e4780538d7 100644
 --- a/paddle/phi/kernels/funcs/math/context_project.h
 +++ b/paddle/phi/kernels/funcs/math/context_project.h
@@ -1095,10 +1094,10 @@ index 410fb3c560..009ce03440 100644
 if (x == T{0.0}) {
 T inf = std::numeric_limits<T>::infinity();
 diff --git a/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h b/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h
-index 5ebbc8d2db..c7b6c338e2 100644
+index 5ebbc8d2db..162fb3bffb 100644
 --- a/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h
 +++ b/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h
-@@ -15,8 +15,9 @@ limitations under the License. */
+@@ -15,8 +15,8 @@ limitations under the License. */
 #include 
 #include 
 #include "paddle/phi/common/datatype_traits.h"
@@ -1106,19 +1105,9 @@ index 5ebbc8d2db..c7b6c338e2 100644
 -#include "paddle/phi/kernels/funcs/quant_dequant.h"
 +#include "kernels/funcs/blas/cublaslt.h"
 +#include "kernels/funcs/quant_dequant.h"
-+#include "kernels/metax_kernel/metax_context.h"
 
 #pragma once
 
-@@ -668,7 +669,7 @@ void LLMGemm(const phi::GPUContext& dev_ctx,
-
- {
- auto helper =
-- std::make_unique<CublasLtHelper>(m, k, n, dev_ctx.cublaslt_handle());
-+ std::make_unique<CublasLtHelper>(m, k, n, GetBlasLtHandle());
- helper->GEMM(quant_input.data(),
- weight->data(),
- int_out.data(),
 diff --git a/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h b/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h
 index 1f319c4ae3..9186eb6906 100644
 --- a/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h
 +++ b/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h

From 554b3cb7cbc7806f037871a1323c270efeafc479 Mon Sep 17 00:00:00 2001
From: jiaxinWang-metax <189149612@qq.com>
Date: Mon, 22 Sep 2025 16:38:13 +0800
Subject: [PATCH 11/12] modify blas

---
 .../metax_gpu/kernels/metax_kernel/metax_context.h |  2 ++
 backends/metax_gpu/patch/paddle.patch              | 14 ++++++++++++--
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/backends/metax_gpu/kernels/metax_kernel/metax_context.h b/backends/metax_gpu/kernels/metax_kernel/metax_context.h
index 9af19bd7464..376981f27a4 100644
--- a/backends/metax_gpu/kernels/metax_kernel/metax_context.h
+++ b/backends/metax_gpu/kernels/metax_kernel/metax_context.h
@@ -27,6 +27,8 @@
 #include "paddle/phi/core/attribute.h"
 #include "paddle/phi/core/device_context.h"
 
+cublasLtHandle_t GetBlasLtHandle();
+
 namespace phi {
 class DnnWorkspaceHandle {
  public:
diff --git a/backends/metax_gpu/patch/paddle.patch b/backends/metax_gpu/patch/paddle.patch
index 0c776c72cc7..beefb730bf7 100755
--- a/backends/metax_gpu/patch/paddle.patch
+++ b/backends/metax_gpu/patch/paddle.patch
@@ -1094,10 +1094,10 @@ index 410fb3c560..009ce03440 100644
 if (x == T{0.0}) {
 T inf = std::numeric_limits<T>::infinity();
 diff 
--git a/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h b/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h
-index 5ebbc8d2db..162fb3bffb 100644
+index 5ebbc8d2db..c7b6c338e2 100644
 --- a/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h
 +++ b/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h
-@@ -15,8 +15,8 @@ limitations under the License. */
+@@ -15,8 +15,9 @@ limitations under the License. */
 #include 
 #include 
 #include "paddle/phi/common/datatype_traits.h"
@@ -1105,9 +1105,19 @@ index 5ebbc8d2db..162fb3bffb 100644
 -#include "paddle/phi/kernels/funcs/quant_dequant.h"
 +#include "kernels/funcs/blas/cublaslt.h"
 +#include "kernels/funcs/quant_dequant.h"
++#include "kernels/metax_kernel/metax_context.h"
 
 #pragma once
 
+@@ -668,7 +669,7 @@ void LLMGemm(const phi::GPUContext& dev_ctx,
+
+ {
+ auto helper =
+- std::make_unique<CublasLtHelper>(m, k, n, dev_ctx.cublaslt_handle());
++ std::make_unique<CublasLtHelper>(m, k, n, GetBlasLtHandle());
+ helper->GEMM(quant_input.data(),
+ weight->data(),
+ int_out.data(),
 diff --git a/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h b/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h
 index 1f319c4ae3..9186eb6906 100644
 --- a/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h
 +++ b/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h

From dfac884216c8d70a6cd112229b319fcd7eb84620 Mon Sep 17 00:00:00 2001
From: jiaxinWang-metax <189149612@qq.com>
Date: Mon, 22 Sep 2025 17:24:02 +0800
Subject: [PATCH 12/12] modify context

---
 .../kernels/metax_kernel/metax_context.cc          | 18 ++++++++++++++++++
 .../kernels/metax_kernel/metax_context.h           |  2 ++
 2 files changed, 20 insertions(+)

diff --git a/backends/metax_gpu/kernels/metax_kernel/metax_context.cc b/backends/metax_gpu/kernels/metax_kernel/metax_context.cc
index 6d86c81041f..efddba5f00b 100644
--- a/backends/metax_gpu/kernels/metax_kernel/metax_context.cc
+++ b/backends/metax_gpu/kernels/metax_kernel/metax_context.cc
@@ -15,6 +15,24 @@
 #include "kernels/metax_kernel/metax_context.h"
 namespace phi {
+const bool allow_tf32_cublas = []() -> bool {
+  const char* v = std::getenv("ALLOW_TF32_CUBLAS");
+  if (v) {
+    return std::atoi(v);
+  }
+  return true;
+}();
+
+const bool allow_tf32_cudnn = []() -> bool {
+  const char* v = std::getenv("ALLOW_TF32_CUDNN");
+  if (v) {
+    return std::atoi(v);
+  }
+  return false;
+}();
+
+bool AllowTF32Cublas() { return allow_tf32_cublas; }
+bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
 void DnnWorkspaceHandle::RunFuncSync(
     const std::function<void(void*)>& cudnn_func,
     size_t required_workspace_bytes,
diff --git a/backends/metax_gpu/kernels/metax_kernel/metax_context.h b/backends/metax_gpu/kernels/metax_kernel/metax_context.h
index 376981f27a4..2d761439089 100644
--- a/backends/metax_gpu/kernels/metax_kernel/metax_context.h
+++ b/backends/metax_gpu/kernels/metax_kernel/metax_context.h
@@ -30,6 +30,8 @@ cublasLtHandle_t GetBlasLtHandle();
 
 namespace phi {
+bool AllowTF32Cublas();
+bool AllowTF32Cudnn();
 class DnnWorkspaceHandle {
  public:
  inline DnnWorkspaceHandle(Allocator* allocator, gpuStream_t stream)