From 32c713fc242314896d2e07f8c26cbae4448538ec Mon Sep 17 00:00:00 2001
From: stu1130
Date: Tue, 29 Oct 2019 13:56:50 -0700
Subject: [PATCH] fix cuDNN RNN dtype_with_fallback_ bug

---
 src/operator/rnn-inl.h | 17 ++++-------------
 1 file changed, 4 insertions(+), 13 deletions(-)

diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index b448261f215d..5eb1a1d2cb57 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -1376,21 +1376,12 @@ class RNNOp {
                                            seed_));
     // RNN descriptors
-    cudnnDataType_t dtype_with_fallback_;
+    // adopt pseudo-fp16 for all architectures
+    cudnnDataType_t dtype_with_fallback_ =
+        (CUDNN_VERSION >= 7500 && dtype_ == CUDNN_DATA_HALF) ? CUDNN_DATA_FLOAT
+                                                             : dtype_;
     cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
     dgrad_sync_needed_ = (rnn_algo == CUDNN_RNN_ALGO_STANDARD) && param_.bidirectional;
-    // On arch's 50 and 52(Maxwell), the gpu doesn't support native fp16 compute.
-    // Before cuDNN 7.5.0, when running fp16, cuDNN fallback to fp32 under the hood on Maxwell.
-    // That's not the case begining from 7.5.0. Thereby adding fallback explicitly here.
-#if __CUDA_ARCH__ < 530 && CUDNN_VERSION >= 7500
-    if (dtype_ == CUDNN_DATA_HALF) {
-      dtype_with_fallback_ = CUDNN_DATA_FLOAT;
-    } else {
-      dtype_with_fallback_ = dtype_;
-    }
-#else
-    dtype_with_fallback_ = dtype_;
-#endif
     CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_,
                                         rnn_desc_,
                                         param_.state_size,
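
Note (not part of the patch): the sketch below only illustrates the selection
rule the patch introduces, i.e. once CUDNN_VERSION >= 7500, fp16 input/output
is paired with fp32 math ("pseudo-fp16") on every GPU architecture instead of
only below __CUDA_ARCH__ 530. It assumes nothing beyond <cudnn.h> being on the
include path; the helper select_math_dtype and the small main() are
hypothetical and exist purely for illustration.

    #include <cudnn.h>   // cudnnDataType_t and the CUDNN_VERSION macro
    #include <cstdio>

    // Mirrors the patched selection: fp16 I/O -> fp32 math (pseudo-fp16) when
    // built against cuDNN >= 7.5.0; every other dtype passes through unchanged.
    static cudnnDataType_t select_math_dtype(cudnnDataType_t io_dtype) {
      return (CUDNN_VERSION >= 7500 && io_dtype == CUDNN_DATA_HALF)
                 ? CUDNN_DATA_FLOAT
                 : io_dtype;
    }

    int main() {
      std::printf("built against cuDNN %d\n", CUDNN_VERSION);
      std::printf("CUDNN_DATA_HALF  -> %s\n",
                  select_math_dtype(CUDNN_DATA_HALF) == CUDNN_DATA_FLOAT
                      ? "CUDNN_DATA_FLOAT (pseudo-fp16)"
                      : "CUDNN_DATA_HALF");
      std::printf("CUDNN_DATA_FLOAT -> %s\n",
                  select_math_dtype(CUDNN_DATA_FLOAT) == CUDNN_DATA_FLOAT
                      ? "CUDNN_DATA_FLOAT"
                      : "unexpected");
      return 0;
    }

Only the header's macros and enum values are used, so the snippet compiles
with, for example, g++ -I/usr/local/cuda/include sketch.cc and does not need
to link against libcudnn.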