From c97ca15a9ce0c9a58b2334afece7b612a7cbd729 Mon Sep 17 00:00:00 2001
From: Zixuan Wei
Date: Tue, 29 Oct 2019 14:58:44 +0800
Subject: [PATCH] Adopt autograd.record() context to RNNOp (#16657)

---
 src/operator/nn/mkldnn/mkldnn_rnn.cc |  26 +++--
 src/operator/rnn-inl.h               | 144 +++++++++++++--------------
 2 files changed, 88 insertions(+), 82 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc b/src/operator/nn/mkldnn/mkldnn_rnn.cc
index 216aac22390c..f713c497a077 100644
--- a/src/operator/nn/mkldnn/mkldnn_rnn.cc
+++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -618,6 +618,9 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
   using memory = mkldnn::memory;
   using format_tag = mkldnn::memory::format_tag;
 
+  // In the `autograd.record()` context, RNNOp is required to run into
+  // `forward_training` mode.
+  const bool is_training = (ctx.is_train || ctx.need_grad);
   const size_t num_fusion = full_param_.layer_params.size();
   if (fwd_inf_vec_.size() < num_fusion) {
     size_t buffer_size = 0;  // Element number, instead of bytes, in the buffer
@@ -635,7 +638,7 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
     mgr_.Init(buffer_size, ctx.run_ctx.ctx, inputs[rnn_enum::kParams].dtype());
   }
 
-  if (ctx.is_train && fwd_trn_vec_.size() < num_fusion) {
+  if (is_training && fwd_trn_vec_.size() < num_fusion) {
     for (auto& layer_param : full_param_.layer_params) {
       fwd_trn_vec_.emplace_back(layer_param, true,
           inputs[rnn_enum::kData], inputs[rnn_enum::kParams]);
@@ -659,13 +662,13 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
     size_t layer_weights_bytes = single_w_bytes * directions;
     size_t layer_bias_bytes = single_b_bytes * directions;  // Naive MXNet has double bias
 
-    if (!fwd_layer.IsInitialized() || ctx.is_train)
+    if (!fwd_layer.IsInitialized() || is_training)
       fwd_layer.SetWeightsMem(&(this->mgr_), weights_ptr, bias_ptr, dtype);
     weights_ptr += layer_weights_bytes;
     bias_ptr += layer_bias_bytes;
   }
 
-  if (ctx.is_train) {
+  if (is_training) {
     CHECK_EQ(fwd_trn_vec_.size(), fwd_inf_vec_.size()) <<
         "Layers' configurations of forward inference and forward training are disparate.";
     for (size_t lyr = 0; lyr < fwd_inf_vec_.size(); ++lyr)
@@ -898,6 +901,9 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
                           const std::vector<NDArray> &inputs,
                           const std::vector<OpReqType> &req,
                           const std::vector<NDArray> &outputs) {
+  // In the `autograd.record()` context, RNNOp is required to run into
+  // forward_training mode.
+  const bool is_training = (ctx.is_train || ctx.need_grad);
   // check output requests
   if (kAddTo == req[rnn_enum::kOut])
     LOG(FATAL) << "Currently, `add` operation is not supported by RNNs.";
@@ -922,7 +928,7 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
     weights_version_ = inputs[rnn_enum::kParams].version();
   }
 
-  if (!initialized_ || ctx.is_train || fwd_trn_vec_.size() == 0) {
+  if (!initialized_ || is_training || fwd_trn_vec_.size() == 0) {
     Init(ctx, inputs, req, outputs);
   }
 
@@ -952,7 +958,7 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
   if (fwd_inf_vec_.size() == 1) {
     fwd_inf_vec_.front().SetNewDataMem(src, src_state, src_state_cell,
         dst, dst_state, dst_state_cell, data_dtype);
-    if (ctx.is_train) {
+    if (is_training) {
       fwd_trn_vec_.front().FetchData(fwd_inf_vec_.front());
     }
   } else {
@@ -964,7 +970,7 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
     // results in this->xxx, used as the source input of the next layer.
     fwd_inf_vec_.front().SetNewDataMem(src, src_state, src_state_cell,
         this->dst_.front()->get_data_handle(), dst_state, dst_state_cell, data_dtype);
-    if (ctx.is_train) {
+    if (is_training) {
       fwd_trn_vec_.front().FetchData(fwd_inf_vec_.front());
     }
     // 1st_lyr -> dst_handle -> next_lyr -> dst_handle -> next_lyr -> ...
@@ -976,7 +982,7 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
       fwd_inf_vec_.at(lyr).SetNewDataMem(this->dst_.at(lyr - 1)->get_data_handle(),
           src_state, src_state_cell,
          this->dst_.at(lyr)->get_data_handle(), dst_state, dst_state_cell, data_dtype);
-      if (ctx.is_train) {
+      if (is_training) {
        fwd_trn_vec_.at(lyr).FetchData(fwd_inf_vec_.at(lyr));
      }
    }
@@ -987,11 +993,11 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
     if (dst_state_cell) dst_state_cell += cell_bytes;
     fwd_inf_vec_.back().SetNewDataMem(this->dst_.back()->get_data_handle(),
         src_state, src_state_cell, dst, dst_state, dst_state_cell, data_dtype);
-    if (ctx.is_train) {
+    if (is_training) {
       fwd_trn_vec_.back().FetchData(fwd_inf_vec_.back());
     }
   }
-  if (ctx.is_train) {
+  if (is_training) {
     for (auto& trn_lyr : fwd_trn_vec_) RegisterMKLDNNRnn(trn_lyr);
   } else {
     for (auto& inf_lyr : fwd_inf_vec_) RegisterMKLDNNRnn(inf_lyr);
@@ -1092,7 +1098,7 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx,
       dy, dhy, dcy, data_dtype);
 
   for (std::vector<MKLDNNRnnBackward>::const_reverse_iterator bwd = bwd_vec_.rbegin();
-      bwd < bwd_vec_.rend(); ++bwd) {
+      bwd != bwd_vec_.rend(); ++bwd) {
     RegisterMKLDNNRnn(*bwd);
   }
 }
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 9019726ebcc4..d4971c1e12bb 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -422,87 +422,87 @@ class RNNOp {
     if (ctx_.dev_type == kGPU) {
 #if MXNET_USE_CUDNN == 1
-    init_cudnn_ = false;
-    dtype_ = mshadow::DataType<DType>::kCudnnFlag;
-    // TensorCore algos only allowed on fp16-I/O convolutions if permitted by the global policy.
-    // No tests in place for fp16 RNNs, so leave TensorCore disabled for now.
-    cudnn_tensor_core_ = false;
-    // When fp16 RNN tests are introduced, we can enable TensorCore as follows:
-    // cudnn_tensor_core =
-    //     mshadow::DataType<DType>::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore();
-    // Defaults
-    input_mode_ = CUDNN_LINEAR_INPUT;  // Don't support this yet
-    // RNN Mode
-    switch (param_.mode) {
-      case rnn_enum::kRnnRelu:
-        mode_ = CUDNN_RNN_RELU;
-        break;
-      case rnn_enum::kRnnTanh:
-        mode_ = CUDNN_RNN_TANH;
-        break;
-      case rnn_enum::kLstm:
-        mode_ = CUDNN_LSTM;
-        break;
-      case rnn_enum::kGru:
-        mode_ = CUDNN_GRU;
-        break;
-      default:
-        LOG(FATAL) << "Not implmented";
-    }
+      init_cudnn_ = false;
+      dtype_ = mshadow::DataType<DType>::kCudnnFlag;
+      // TensorCore algos only allowed on fp16-I/O convolutions if permitted by the global policy.
+      // No tests in place for fp16 RNNs, so leave TensorCore disabled for now.
+      cudnn_tensor_core_ = false;
+      // When fp16 RNN tests are introduced, we can enable TensorCore as follows:
+      // cudnn_tensor_core =
+      //     mshadow::DataType<DType>::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore();
+      // Defaults
+      input_mode_ = CUDNN_LINEAR_INPUT;  // Don't support this yet
+      // RNN Mode
+      switch (param_.mode) {
+        case rnn_enum::kRnnRelu:
+          mode_ = CUDNN_RNN_RELU;
+          break;
+        case rnn_enum::kRnnTanh:
+          mode_ = CUDNN_RNN_TANH;
+          break;
+        case rnn_enum::kLstm:
+          mode_ = CUDNN_LSTM;
+          break;
+        case rnn_enum::kGru:
+          mode_ = CUDNN_GRU;
+          break;
+        default:
+          LOG(FATAL) << "Not implmented";
+      }
 #if MXNET_USE_CUDNN_GE_7200
-    if (param_.projection_size.has_value()) {
-      CHECK_EQ(param_.mode, rnn_enum::kLstm)
-          << "Projection is only supported for LSTM.";
-      CHECK_GE(param_.state_size, param_.projection_size.value())
-          << "State size must be larger than projection size.";
-    }
+      if (param_.projection_size.has_value()) {
+        CHECK_EQ(param_.mode, rnn_enum::kLstm)
+            << "Projection is only supported for LSTM.";
+        CHECK_GE(param_.state_size, param_.projection_size.value())
+            << "State size must be larger than projection size.";
+      }
 #else
-    CHECK(!param_.projection_size.has_value())
-        << "Projection is only supported for LSTM with CuDNN version later than 7.1.1.";
+      CHECK(!param_.projection_size.has_value())
+          << "Projection is only supported for LSTM with CuDNN version later than 7.1.1.";
 #endif // MXNET_USE_CUDNN_GE_7200
 #if MXNET_USE_CUDNN_GE_7200
-    if (param_.lstm_state_clip_min.has_value()
-        || param_.lstm_state_clip_max.has_value()) {
-      CHECK_EQ(param_.mode, rnn_enum::kLstm)
-          << "State clipping is only supported for LSTM.";
-      CHECK(param_.lstm_state_clip_min.has_value() && param_.lstm_state_clip_max.has_value())
-          << "lstm_state_clip_min and lstm_state_clip_max must be specified together.";
-      CHECK_GE(param_.lstm_state_clip_max.value(), param_.lstm_state_clip_min.value())
-          << "lstm_state_clip_max must be greater or equal to lstm_state_clip_min";
-    }
+      if (param_.lstm_state_clip_min.has_value()
+          || param_.lstm_state_clip_max.has_value()) {
+        CHECK_EQ(param_.mode, rnn_enum::kLstm)
+            << "State clipping is only supported for LSTM.";
+        CHECK(param_.lstm_state_clip_min.has_value() && param_.lstm_state_clip_max.has_value())
+            << "lstm_state_clip_min and lstm_state_clip_max must be specified together.";
+        CHECK_GE(param_.lstm_state_clip_max.value(), param_.lstm_state_clip_min.value())
+            << "lstm_state_clip_max must be greater or equal to lstm_state_clip_min";
+      }
 #else
-    CHECK(!param_.lstm_state_clip_min.has_value()
-        && !param_.lstm_state_clip_max.has_value())
-        << "State clipping is only supported for LSTM with CuDNN version later than 7.2.1.";
+      CHECK(!param_.lstm_state_clip_min.has_value()
+          && !param_.lstm_state_clip_max.has_value())
+          << "State clipping is only supported for LSTM with CuDNN version later than 7.2.1.";
 #endif // MXNET_USE_CUDNN_GE_7200
-    // RNN Direction
-    direction_ = param_.bidirectional ?
-        CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
-    // Create descriptors
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&hx_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&cx_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&hy_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&cy_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&dhx_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&dcx_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&dhy_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&dcy_desc_));
-
-    CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc_));
-    CUDNN_CALL(cudnnCreateFilterDescriptor(&dw_desc_));
-
-    CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));
-    CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
+      // RNN Direction
+      direction_ = param_.bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
+      // Create descriptors
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&hx_desc_));
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&cx_desc_));
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&hy_desc_));
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&cy_desc_));
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&dhx_desc_));
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&dcx_desc_));
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&dhy_desc_));
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&dcy_desc_));
+
+      CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc_));
+      CUDNN_CALL(cudnnCreateFilterDescriptor(&dw_desc_));
+
+      CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));
+      CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
 #if MXNET_USE_CUDNN_GE_7200
-    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_));
-    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_));
-    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_));
-    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dy_data_desc_));
+      CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_));
+      CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_));
+      CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_));
+      CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dy_data_desc_));
 #endif // MXNET_USE_CUDNN_GE_7200
 #else
-    if (ctx_.dev_type == kGPU) {
-      LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment.";
-    }
+      if (ctx_.dev_type == kGPU) {
+        LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment.";
+      }
 #endif // MXNET_USE_CUDNN == 1
     }
@@ -854,7 +854,7 @@ class RNNOp {
     }
     DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);
-    if (ctx.is_train) {
+    if (ctx.is_train || ctx.need_grad) {
      // allocate reserve space

      const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
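
Not part of the patch above: a minimal Python-side sketch of the case this change targets, i.e. gradients requested inside an `autograd.record()` scope without the global training mode, so `ctx.need_grad` is set while `ctx.is_train` is not, and RNNOp must still take the forward_training path. The Gluon LSTM layer, shapes, and `train_mode=False` usage below are illustrative assumptions, not code from this PR.

# Illustrative sketch only. The record(train_mode=False) scope sets ctx.need_grad
# without ctx.is_train, which is exactly what the new
# `is_training = (ctx.is_train || ctx.need_grad)` check covers.
import mxnet as mx
from mxnet import autograd, gluon

lstm = gluon.rnn.LSTM(hidden_size=64, num_layers=2)   # fused RNN layer, dispatches to RNNOp
lstm.initialize()

x = mx.nd.random.uniform(shape=(10, 32, 16))          # (seq_len, batch, input_size)
x.attach_grad()

with autograd.record(train_mode=False):               # gradients requested, training mode off
    y = lstm(x)                                       # forward must still run as forward_training
    loss = y.sum()                                    # so the reserve space for backward exists
loss.backward(train_mode=False)
print(x.grad.shape)                                   # (10, 32, 16)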