diff --git a/backends/metax_gpu/kernels/metax_kernel/metax_context.cc b/backends/metax_gpu/kernels/metax_kernel/metax_context.cc
index 6d86c81041f..efddba5f00b 100644
--- a/backends/metax_gpu/kernels/metax_kernel/metax_context.cc
+++ b/backends/metax_gpu/kernels/metax_kernel/metax_context.cc
@@ -15,6 +15,34 @@
 #include "kernels/metax_kernel/metax_context.h"
+
+#include <cstdlib>
 
 namespace phi {
+namespace {
+
+// Reads environment variable `name` as a decimal integer switch.
+//
+// Returns `default_value` when the variable is unset. Otherwise returns true
+// iff the parsed number is non-zero; text without a leading decimal number
+// parses as 0 and therefore yields false.
+bool ReadBoolEnv(const char* name, bool default_value) {
+  const char* raw = std::getenv(name);
+  if (raw == nullptr) {
+    return default_value;
+  }
+  // strtol rather than atoi: same result for values that fit in an int, but
+  // out-of-range input saturates instead of being undefined behavior.
+  return std::strtol(raw, nullptr, 10) != 0;
+}
+
+}  // namespace
+
+// Switches behind AllowTF32Cublas()/AllowTF32Cudnn(); each is read from the
+// environment once, when this translation unit's globals are initialized.
+const bool allow_tf32_cublas = ReadBoolEnv("ALLOW_TF32_CUBLAS", true);
+const bool allow_tf32_cudnn = ReadBoolEnv("ALLOW_TF32_CUDNN", false);
+
+bool AllowTF32Cublas() { return allow_tf32_cublas; }
+bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
 void DnnWorkspaceHandle::RunFuncSync(
     const std::function& cudnn_func,
     size_t required_workspace_bytes,
diff --git a/backends/metax_gpu/kernels/metax_kernel/metax_context.h b/backends/metax_gpu/kernels/metax_kernel/metax_context.h
index 376981f27a4..2d761439089 100644
--- a/backends/metax_gpu/kernels/metax_kernel/metax_context.h
+++ b/backends/metax_gpu/kernels/metax_kernel/metax_context.h
@@ -30,6 +30,10 @@
 cublasLtHandle_t GetBlasLtHandle();
 
 namespace phi {
+// Returns the ALLOW_TF32_CUBLAS setting captured at startup (default: true).
+bool AllowTF32Cublas();
+// Returns the ALLOW_TF32_CUDNN setting captured at startup (default: false).
+bool AllowTF32Cudnn();
 class DnnWorkspaceHandle {
  public:
   inline DnnWorkspaceHandle(Allocator* allocator, gpuStream_t stream)