Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix cuDNN RNN dtype_with_fallback_ bug #16671

Merged
merged 1 commit into from
Oct 30, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions src/operator/rnn-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1376,21 +1376,12 @@ class RNNOp {
seed_));

// RNN descriptors
cudnnDataType_t dtype_with_fallback_;
// adopt pseudo-fp16 for all architectures
cudnnDataType_t dtype_with_fallback_ =
(cudnnGetVersion() >= 7500 && dtype_ == CUDNN_DATA_HALF) ? CUDNN_DATA_FLOAT
: dtype_;
cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
dgrad_sync_needed_ = (rnn_algo == CUDNN_RNN_ALGO_STANDARD) && param_.bidirectional;
    // On archs 50 and 52 (Maxwell), the GPU doesn't support native fp16 compute.
    // Before cuDNN 7.5.0, when running fp16, cuDNN fell back to fp32 under the hood on Maxwell.
    // That is no longer the case beginning with 7.5.0, so we add the fallback explicitly here.
#if __CUDA_ARCH__ < 530 && CUDNN_VERSION >= 7500
if (dtype_ == CUDNN_DATA_HALF) {
dtype_with_fallback_ = CUDNN_DATA_FLOAT;
} else {
dtype_with_fallback_ = dtype_;
}
#else
dtype_with_fallback_ = dtype_;
#endif
CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_,
rnn_desc_,
param_.state_size,
Expand Down