diff --git a/backends/metax_gpu/kernels/metax_kernel/fused_conv2d_add_act_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_conv2d_add_act_kernel_register.cu
similarity index 100%
rename from backends/metax_gpu/kernels/metax_kernel/fused_conv2d_add_act_kernel_register.cu
rename to backends/metax_gpu/kernels/fusion/fused_conv2d_add_act_kernel_register.cu
diff --git a/backends/metax_gpu/kernels/metax_kernel/fused_rope_grad_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_rope_grad_kernel_register.cu
similarity index 100%
rename from backends/metax_gpu/kernels/metax_kernel/fused_rope_grad_kernel_register.cu
rename to backends/metax_gpu/kernels/fusion/fused_rope_grad_kernel_register.cu
diff --git a/backends/metax_gpu/kernels/metax_kernel/fused_rope_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_rope_kernel_register.cu
similarity index 100%
rename from backends/metax_gpu/kernels/metax_kernel/fused_rope_kernel_register.cu
rename to backends/metax_gpu/kernels/fusion/fused_rope_kernel_register.cu
diff --git a/backends/metax_gpu/kernels/metax_kernel/metax_context.cc b/backends/metax_gpu/kernels/metax_kernel/metax_context.cc
index efddba5f00b..0712fb75bbe 100644
--- a/backends/metax_gpu/kernels/metax_kernel/metax_context.cc
+++ b/backends/metax_gpu/kernels/metax_kernel/metax_context.cc
@@ -15,24 +15,6 @@
 #include "kernels/metax_kernel/metax_context.h"
 
 namespace phi {
-const bool allow_tf32_cublas = []() -> bool {
-  const char* v = std::getenv("ALLOW_TF32_CUBLAS");
-  if (v) {
-    return std::atoi(v);
-  }
-  return true;
-}();
-
-const bool allow_tf32_cudnn = []() -> bool {
-  const char* v = std::getenv("ALLOW_TF32_CUDNN");
-  if (v) {
-    return std::atoi(v);
-  }
-  return false;
-}();
-
-bool AllowTF32Cublas() { return allow_tf32_cublas; }
-bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
 void DnnWorkspaceHandle::RunFuncSync(
     const std::function<void(void*)>& cudnn_func,
     size_t required_workspace_bytes,
@@ -42,19 +24,11 @@ void DnnWorkspaceHandle::RunFuncSync(
     void* workspace_ptr = nullptr;
     size_t size = ((required_workspace_bytes + 255) >> 8) << 8;
     std::lock_guard<std::mutex> guard(*mtx_);
-#ifdef PADDLE_WITH_HIP
-    auto status = hipMalloc(&workspace_ptr, size);
-#else
     auto status = cudaMalloc(&workspace_ptr, size);
-#endif
     if (status == gpuSuccess) {
       cudnn_func(workspace_ptr);
       phi::backends::gpu::GpuStreamSync(stream_);
-#ifdef PADDLE_WITH_HIP
-      PADDLE_ENFORCE_GPU_SUCCESS(hipFree(workspace_ptr));
-#else
       PADDLE_ENFORCE_GPU_SUCCESS(cudaFree(workspace_ptr));
-#endif
       return;
     }
   }
diff --git a/backends/metax_gpu/kernels/metax_kernel/metax_context.h b/backends/metax_gpu/kernels/metax_kernel/metax_context.h
index 2d761439089..7386811a236 100644
--- a/backends/metax_gpu/kernels/metax_kernel/metax_context.h
+++ b/backends/metax_gpu/kernels/metax_kernel/metax_context.h
@@ -18,6 +18,7 @@
 #include
 
 #include "kernels/funcs/blas/cublasLt.h"
+#include "paddle/phi/backends/context_pool.h"
 #include "paddle/phi/backends/custom/custom_context.h"
 #include "paddle/phi/backends/gpu/forwards.h"
 #include "paddle/phi/backends/gpu/gpu_decls.h"
@@ -30,8 +31,6 @@ cublasLtHandle_t GetBlasLtHandle();
 
 namespace phi {
-bool AllowTF32Cublas();
-bool AllowTF32Cudnn();
 class DnnWorkspaceHandle {
  public:
   inline DnnWorkspaceHandle(Allocator* allocator, gpuStream_t stream)
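
Note on the deleted TF32 toggles: the removed AllowTF32Cublas()/AllowTF32Cudnn() helpers simply cached an environment-variable lookup at static-initialization time. Below is a minimal standalone sketch of that pattern; the EnvFlag helper and the main function are illustrative only and are not part of this patch.

// env_flag_sketch.cc -- standalone illustration of the removed TF32 toggles.
#include <cstdlib>
#include <iostream>

// Read an environment variable once, falling back to a default when unset,
// mirroring the lambdas deleted from metax_context.cc (any nonzero value
// counts as "enabled", as with std::atoi in the original).
static bool EnvFlag(const char* name, bool default_value) {
  const char* v = std::getenv(name);
  return v ? std::atoi(v) != 0 : default_value;
}

// In the removed code, TF32 for cuBLAS defaulted to on and TF32 for cuDNN to off.
static const bool allow_tf32_cublas = EnvFlag("ALLOW_TF32_CUBLAS", true);
static const bool allow_tf32_cudnn = EnvFlag("ALLOW_TF32_CUDNN", false);

int main() {
  std::cout << "ALLOW_TF32_CUBLAS=" << allow_tf32_cublas
            << " ALLOW_TF32_CUDNN=" << allow_tf32_cudnn << "\n";
  return 0;
}

With these helpers gone, the backend presumably defers to the TF32 switches of the core phi device context (note the newly included paddle/phi/backends/context_pool.h); that is an inference from the patch rather than something it states.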