diff --git a/backends/metax_gpu/kernels/gpudnn/conv_grad_kernel_register.cu b/backends/metax_gpu/kernels/gpudnn/conv_grad_kernel_register.cu index e4acb2f95b6..2da42c7ff8c 100644 --- a/backends/metax_gpu/kernels/gpudnn/conv_grad_kernel_register.cu +++ b/backends/metax_gpu/kernels/gpudnn/conv_grad_kernel_register.cu @@ -437,26 +437,22 @@ void ConvCudnnGradKernel(const Context& dev_ctx, dev_ctx.template Alloc(filter_grad); } - // bool has_use_addto = dev_ctx.HasDnnAttr("use_addto"); - bool has_use_addto = "true"; + bool has_use_addto = dev_ctx.HasDnnAttr("use_addto"); VLOG(4) << "GPUContext contains `use_addto`: " << has_use_addto; - // bool use_addto = has_use_addto - // ? PADDLE_GET_CONST(bool, "true") - // : false; - bool use_addto = "true"; + bool use_addto = has_use_addto + ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("use_addto")) + : false; std::vector dilations = dilations_t; std::vector strides = strides_t; std::vector paddings = paddings_t; - // bool has_exhaustive_search = dev_ctx.HasDnnAttr("exhaustive_search"); - bool has_exhaustive_search = "true"; + bool has_exhaustive_search = dev_ctx.HasDnnAttr("exhaustive_search"); VLOG(4) << "GPUContext contains `exhaustive_search`: " << has_exhaustive_search; - // bool exhaustive_search_attr = - // has_exhaustive_search - // ? PADDLE_GET_CONST(bool, "true") - // : false; - bool exhaustive_search_attr = "true"; + bool exhaustive_search_attr = + has_exhaustive_search + ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("exhaustive_search")) + : false; bool exhaustive_search = FLAGS_cudnn_exhaustive_search || exhaustive_search_attr; bool deterministic = FLAGS_cudnn_deterministic; @@ -835,14 +831,13 @@ void ConvCudnnGradGradKernel( T* transformed_dx = nullptr; std::vector dilations = dilations_t; - // bool has_exhaustive_search = dev_ctx.HasDnnAttr("exhaustive_search"); - // VLOG(4) << "GPUContext contains `exhaustive_search`: " - // << has_exhaustive_search; - // bool exhaustive_search_attr = - // has_exhaustive_search - // ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("exhaustive_search")) - // : false; - bool exhaustive_search_attr = "true"; + bool has_exhaustive_search = dev_ctx.HasDnnAttr("exhaustive_search"); + VLOG(4) << "GPUContext contains `exhaustive_search`: " + << has_exhaustive_search; + bool exhaustive_search_attr = + has_exhaustive_search + ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("exhaustive_search")) + : false; bool exhaustive_search = FLAGS_cudnn_exhaustive_search || exhaustive_search_attr; bool deterministic = FLAGS_cudnn_deterministic; diff --git a/backends/metax_gpu/kernels/gpudnn/conv_kernel_register.cu b/backends/metax_gpu/kernels/gpudnn/conv_kernel_register.cu index 0a83b504c76..d6b243c956c 100644 --- a/backends/metax_gpu/kernels/gpudnn/conv_kernel_register.cu +++ b/backends/metax_gpu/kernels/gpudnn/conv_kernel_register.cu @@ -228,15 +228,16 @@ void ConvCudnnKernel(const Context& dev_ctx, std::vector paddings = paddings_t; std::vector dilations = dilations_t; - // bool has_exhaustive_search = dev_ctx.HasDnnAttr("exhaustive_search"); - // VLOG(4) << "GPUContext contains `exhaustive_search`: " - // << has_exhaustive_search; - // bool exhaustive_search_attr = - // has_exhaustive_search - // ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("exhaustive_search")) - // : false; - - bool exhaustive_search = FLAGS_cudnn_exhaustive_search; + bool has_exhaustive_search = dev_ctx.HasDnnAttr("exhaustive_search"); + VLOG(4) << "GPUContext contains `exhaustive_search`: " + << has_exhaustive_search; + bool exhaustive_search_attr = + has_exhaustive_search + ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("exhaustive_search")) + : false; + + bool exhaustive_search = + FLAGS_cudnn_exhaustive_search || exhaustive_search_attr; bool deterministic = FLAGS_cudnn_deterministic; PADDLE_ENFORCE_EQ(exhaustive_search && deterministic, diff --git a/backends/metax_gpu/kernels/gpudnn/conv_transpose_kernel.cu b/backends/metax_gpu/kernels/gpudnn/conv_transpose_kernel.cu index 532b7af0db4..4049d2f3130 100644 --- a/backends/metax_gpu/kernels/gpudnn/conv_transpose_kernel.cu +++ b/backends/metax_gpu/kernels/gpudnn/conv_transpose_kernel.cu @@ -260,14 +260,13 @@ void ConvTransposeRawGPUDNNKernel(const Context& dev_ctx, return; } - // bool has_exhaustive_search = dev_ctx.HasDnnAttr("exhaustive_search"); - // bool exhaustive_search_attr = - // has_exhaustive_search - // ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("exhaustive_search")) - // : false; - // bool exhaustive_search = - // FLAGS_cudnn_exhaustive_search || exhaustive_search_attr; - bool exhaustive_search = FLAGS_cudnn_exhaustive_search; + bool has_exhaustive_search = dev_ctx.HasDnnAttr("exhaustive_search"); + bool exhaustive_search_attr = + has_exhaustive_search + ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("exhaustive_search")) + : false; + bool exhaustive_search = + FLAGS_cudnn_exhaustive_search || exhaustive_search_attr; bool deterministic = FLAGS_cudnn_deterministic; PADDLE_ENFORCE_EQ(exhaustive_search && deterministic, diff --git a/backends/metax_gpu/kernels/metax_kernel/depthwise_conv_grad_kernel.cu b/backends/metax_gpu/kernels/metax_kernel/depthwise_conv_grad_kernel.cu index f2475298963..4e5f881385a 100644 --- a/backends/metax_gpu/kernels/metax_kernel/depthwise_conv_grad_kernel.cu +++ b/backends/metax_gpu/kernels/metax_kernel/depthwise_conv_grad_kernel.cu @@ -54,14 +54,12 @@ void DepthwiseConvGradKernel(const Context& dev_ctx, return; } - // bool has_fuse_relu = dev_ctx.HasDnnAttr("fuse_relu_before_depthwise_conv"); - // bool fuse_relu = - // has_fuse_relu - // ? PADDLE_GET_CONST( - // bool, dev_ctx.GetDnnAttr("fuse_relu_before_depthwise_conv")) - // : false; - bool has_fuse_relu = false; - bool fuse_relu = false; + bool has_fuse_relu = dev_ctx.HasDnnAttr("fuse_relu_before_depthwise_conv"); + bool fuse_relu = + has_fuse_relu + ? PADDLE_GET_CONST( + bool, dev_ctx.GetDnnAttr("fuse_relu_before_depthwise_conv")) + : false; std::vector strides = strides_t; std::vector paddings = paddings_t; diff --git a/backends/metax_gpu/kernels/metax_kernel/depthwise_conv_kernel.cu b/backends/metax_gpu/kernels/metax_kernel/depthwise_conv_kernel.cu index 517f26b1c02..d3d6c4a4edd 100644 --- a/backends/metax_gpu/kernels/metax_kernel/depthwise_conv_kernel.cu +++ b/backends/metax_gpu/kernels/metax_kernel/depthwise_conv_kernel.cu @@ -48,14 +48,12 @@ void DepthwiseConvKernel(const Context& dev_ctx, const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); - // bool has_fuse_relu = dev_ctx.HasDnnAttr("fuse_relu_before_depthwise_conv"); - // bool fuse_relu = - // has_fuse_relu - // ? PADDLE_GET_CONST( - // bool, dev_ctx.GetDnnAttr("fuse_relu_before_depthwise_conv")) - // : false; - bool has_fuse_relu = false; - bool fuse_relu = false; + bool has_fuse_relu = dev_ctx.HasDnnAttr("fuse_relu_before_depthwise_conv"); + bool fuse_relu = + has_fuse_relu + ? PADDLE_GET_CONST( + bool, dev_ctx.GetDnnAttr("fuse_relu_before_depthwise_conv")) + : false; if (channel_last) { PADDLE_ENFORCE_EQ(