Skip to content

Commit

Permalink
Use cudnnGet instead of cudnnFind when determinism is required
Browse files Browse the repository at this point in the history
  • Loading branch information
apeforest committed Nov 3, 2018
1 parent 3d6ef7b commit d1bdf0f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
8 changes: 4 additions & 4 deletions src/operator/nn/cudnn/cudnn_convolution-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -635,8 +635,8 @@ class CuDNNConvolutionOp {
std::vector<cudnnConvolutionFwdAlgoPerf_t> fwd_results(MaxForwardAlgos(s->dnn_handle_));
int actual_fwd_algos = 0;
auto fwd_algo_discoverer =
param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7
: cudnnFindConvolutionForwardAlgorithm;
(param_.cudnn_tune.value() == conv::kOff || dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0)) ?
cudnnGetConvolutionForwardAlgorithm_v7 : cudnnFindConvolutionForwardAlgorithm;
CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_,
in_desc_,
filter_desc_,
Expand All @@ -657,8 +657,8 @@ class CuDNNConvolutionOp {
// In cuDNN v7.1.4, Find() returned wgrad algorithms that could fail for large c when we
// were summing into the output (i.e. beta != 0), whereas Get() returned algorithms that worked.
auto bwd_filter_algo_discoverer =
param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7
: cudnnFindConvolutionBackwardFilterAlgorithm;
(param_.cudnn_tune.value() == conv::kOff || dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0)) ?
cudnnGetConvolutionBackwardFilterAlgorithm_v7 : cudnnFindConvolutionBackwardFilterAlgorithm;
CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_,
in_desc_,
out_desc_,
Expand Down
12 changes: 6 additions & 6 deletions src/operator/nn/cudnn/cudnn_deconvolution-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,8 @@ class CuDNNDeconvolutionOp {
std::vector<cudnnConvolutionFwdAlgoPerf_t> fwd_results(MaxForwardAlgos(s->dnn_handle_));
int actual_fwd_algos = 0;
auto fwd_algo_discoverer =
param_.cudnn_tune.value() == deconv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7
: cudnnFindConvolutionForwardAlgorithm;
(param_.cudnn_tune.value() == conv::kOff || dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0)) ?
cudnnGetConvolutionForwardAlgorithm_v7 : cudnnFindConvolutionForwardAlgorithm;
CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_,
out_desc_,
filter_desc_,
Expand All @@ -584,8 +584,8 @@ class CuDNNDeconvolutionOp {
// In cuDNN v7.1.4, Find() returned wgrad algorithms that could fail for large c when we
// were summing into the output (i.e. beta != 0), whereas Get() returned algorithms that worked.
auto bwd_filter_algo_discoverer =
param_.cudnn_tune.value() == deconv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7
: cudnnFindConvolutionBackwardFilterAlgorithm;
(param_.cudnn_tune.value() == conv::kOff || dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0)) ?
cudnnGetConvolutionBackwardFilterAlgorithm_v7 : cudnnFindConvolutionBackwardFilterAlgorithm;
CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_,
out_desc_,
in_desc_,
Expand All @@ -603,8 +603,8 @@ class CuDNNDeconvolutionOp {
std::vector<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_results(max_bwd_data_algos);
int actual_bwd_data_algos = 0;
auto bwd_data_algo_discoverer =
param_.cudnn_tune.value() == deconv::kOff ? cudnnGetConvolutionBackwardDataAlgorithm_v7
: cudnnFindConvolutionBackwardDataAlgorithm;
(param_.cudnn_tune.value() == conv::kOff || dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0)) ?
cudnnGetConvolutionBackwardDataAlgorithm_v7 : cudnnFindConvolutionBackwardDataAlgorithm;
CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_,
filter_desc_,
in_desc_,
Expand Down

0 comments on commit d1bdf0f

Please sign in to comment.