diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index d4971c1e12bb..db2360313aef 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -1332,21 +1332,12 @@ class RNNOp { seed_)); // RNN descriptors - cudnnDataType_t dtype_with_fallback_; + // adopt pseudo-fp16 for all architectures + cudnnDataType_t dtype_with_fallback_ = + (cudnnGetVersion() >= 7500 && dtype_ == CUDNN_DATA_HALF) ? CUDNN_DATA_FLOAT + : dtype_; cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD; dgrad_sync_needed_ = (rnn_algo == CUDNN_RNN_ALGO_STANDARD) && param_.bidirectional; - // On arch's 50 and 52(Maxwell), the gpu doesn't support native fp16 compute. - // Before cuDNN 7.5.0, when running fp16, cuDNN fallback to fp32 under the hood on Maxwell. - // That's not the case begining from 7.5.0. Thereby adding fallback explicitly here. -#if __CUDA_ARCH__ < 530 && CUDNN_VERSION >= 7500 - if (dtype_ == CUDNN_DATA_HALF) { - dtype_with_fallback_ = CUDNN_DATA_FLOAT; - } else { - dtype_with_fallback_ = dtype_; - } -#else - dtype_with_fallback_ = dtype_; -#endif CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_, rnn_desc_, param_.state_size,