Skip to content

Commit

Permalink
RNNOp only call cuda/cudnn if GPU ctx is requested (apache#16632)
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu authored and yajiedesign committed Nov 6, 2019
1 parent 498388e commit 61294b9
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/operator/rnn-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ class RNNOp {
init_mem_ = false;
reserve_mem_size_ = 0;
#endif

if (ctx_.dev_type == kGPU) {
#if MXNET_USE_CUDNN == 1
init_cudnn_ = false;
dtype_ = mshadow::DataType<DType>::kCudnnFlag;
Expand Down Expand Up @@ -505,6 +507,7 @@ class RNNOp {
LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment.";
}
#endif // MXNET_USE_CUDNN == 1
}

if (ctx_.dev_type == kCPU) {
this->init_space_ = false;
Expand All @@ -523,6 +526,7 @@ class RNNOp {
}

~RNNOp() {
if (ctx_.dev_type == kGPU) {
#if MXNET_USE_CUDNN == 1
CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_));
Expand Down Expand Up @@ -557,6 +561,7 @@ class RNNOp {
CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dy_data_desc_));
#endif // MXNET_USE_CUDNN_GE_7200
#endif // MXNET_USE_CUDNN
}
}

void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
Expand Down

0 comments on commit 61294b9

Please sign in to comment.