From 61294b929e2fe4df2a2d4610e88043a40829dd1d Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Sun, 27 Oct 2019 10:21:14 -0700 Subject: [PATCH] RNNOp only call cuda/cudnn if GPU ctx is requested (#16632) --- src/operator/rnn-inl.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index ead7501a48b0..b448261f215d 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -422,6 +422,8 @@ class RNNOp { init_mem_ = false; reserve_mem_size_ = 0; #endif + + if (ctx_.dev_type == kGPU) { #if MXNET_USE_CUDNN == 1 init_cudnn_ = false; dtype_ = mshadow::DataType::kCudnnFlag; @@ -505,6 +507,7 @@ class RNNOp { LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment."; } #endif // MXNET_USE_CUDNN == 1 + } if (ctx_.dev_type == kCPU) { this->init_space_ = false; @@ -523,6 +526,7 @@ class RNNOp { } ~RNNOp() { + if (ctx_.dev_type == kGPU) { #if MXNET_USE_CUDNN == 1 CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_)); CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_)); @@ -557,6 +561,7 @@ class RNNOp { CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dy_data_desc_)); #endif // MXNET_USE_CUDNN_GE_7200 #endif // MXNET_USE_CUDNN + } } void Forward(const OpContext &ctx, const std::vector &in_data,