diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h index e11f7cc81d25..66df82e4395e 100644 --- a/src/operator/nn/cudnn/cudnn_convolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h @@ -750,7 +750,7 @@ class CuDNNConvolutionOp { i = 0; while (i < nalgo && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == conv::kLimited + || (param_.cudnn_tune.value() != conv::kFastest && fwd_algo[i].memory > workspace_byte))) { ++i; min_memory_needs = @@ -787,7 +787,7 @@ class CuDNNConvolutionOp { i = 0; while (i < nalgo && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == conv::kLimited + || (param_.cudnn_tune.value() != conv::kFastest && bwd_filter_algo[i].memory > workspace_byte))) { ++i; min_memory_needs = (i == 0) ? @@ -825,7 +825,7 @@ class CuDNNConvolutionOp { i = 0; while (i < nalgo && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == conv::kLimited + || (param_.cudnn_tune.value() != conv::kFastest && bwd_data_algo[i].memory > workspace_byte))) { ++i; min_memory_needs = (i == 0) ? @@ -924,7 +924,7 @@ class CuDNNConvolutionOp { #if CUDNN_MAJOR >= 7 (!enforce_determinism || result.determinism == cudnnDeterminism_t::CUDNN_DETERMINISTIC) && #endif - (param_.cudnn_tune.value() != conv::kLimited || result.memory <= workspace_byte)) { + (param_.cudnn_tune.value() == conv::kFastest || result.memory <= workspace_byte)) { algo->Set(result.algo, algo_is_tensor_core); return; }