
fix cuDNN RNN dtype_with_fallback_ bug
stu1130 committed Oct 29, 2019
1 parent 60d74bc commit 32c713f
Showing 1 changed file with 4 additions and 13 deletions.
17 changes: 4 additions & 13 deletions src/operator/rnn-inl.h
@@ -1376,21 +1376,12 @@ class RNNOp {
                                  seed_));
 
       // RNN descriptors
-      cudnnDataType_t dtype_with_fallback_;
+      // adopt pseudo-fp16 for all architectures
+      cudnnDataType_t dtype_with_fallback_ =
+        (CUDNN_VERSION >= 7500 && dtype_ == CUDNN_DATA_HALF) ? CUDNN_DATA_FLOAT
+                                                             : dtype_;
       cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
       dgrad_sync_needed_ = (rnn_algo == CUDNN_RNN_ALGO_STANDARD) && param_.bidirectional;
-      // On arch's 50 and 52(Maxwell), the gpu doesn't support native fp16 compute.
-      // Before cuDNN 7.5.0, when running fp16, cuDNN fallback to fp32 under the hood on Maxwell.
-      // That's not the case begining from 7.5.0. Thereby adding fallback explicitly here.
-#if __CUDA_ARCH__ < 530 && CUDNN_VERSION >= 7500
-      if (dtype_ == CUDNN_DATA_HALF) {
-        dtype_with_fallback_ = CUDNN_DATA_FLOAT;
-      } else {
-        dtype_with_fallback_ = dtype_;
-      }
-#else
-      dtype_with_fallback_ = dtype_;
-#endif
       CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_,
                                           rnn_desc_,
                                           param_.state_size,
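
For context, the new code picks the RNN math type with a single ternary instead of the arch-gated #if block: on cuDNN 7.5.0 or newer, an fp16 dtype_ is mapped to fp32 compute ("pseudo-fp16") on every architecture. Below is a minimal, self-contained sketch of just that selection logic; FakeCudnnDataType, kFakeCudnnVersion, and DtypeWithFallback are hypothetical stand-ins introduced here so the snippet builds without <cudnn.h>, and are not part of the commit.

// Minimal sketch (not from the commit): the post-fix fallback selection in isolation.
// FakeCudnnDataType, kFakeCudnnVersion, and DtypeWithFallback are made-up stand-ins
// for cudnnDataType_t, CUDNN_VERSION, and the member initialization in rnn-inl.h.
#include <cstdio>

enum FakeCudnnDataType { FAKE_CUDNN_DATA_FLOAT = 0, FAKE_CUDNN_DATA_HALF = 2 };
constexpr int kFakeCudnnVersion = 7600;  // pretend we are on cuDNN >= 7.5.0

FakeCudnnDataType DtypeWithFallback(FakeCudnnDataType dtype) {
  // Mirrors the committed ternary: with cuDNN >= 7.5.0, fp16 RNN math falls back to
  // fp32 ("pseudo-fp16") on every architecture, not only when __CUDA_ARCH__ < 530.
  return (kFakeCudnnVersion >= 7500 && dtype == FAKE_CUDNN_DATA_HALF)
             ? FAKE_CUDNN_DATA_FLOAT
             : dtype;
}

int main() {
  std::printf("half  -> %d (falls back to float = 0)\n", DtypeWithFallback(FAKE_CUDNN_DATA_HALF));
  std::printf("float -> %d (unchanged)\n", DtypeWithFallback(FAKE_CUDNN_DATA_FLOAT));
  return 0;
}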

0 comments on commit 32c713f
