This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Adopt autograd.record() context to RNNOp (#16657)
zixuanweeei authored and pengzhao-intel committed Oct 29, 2019
1 parent 5eb89fc commit c97ca15
Showing 2 changed files with 88 additions and 82 deletions.
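
The whole change hinges on one predicate: the fused RNN operator must select its forward_training primitives not only when the executor is in training mode (ctx.is_train) but also whenever gradients are being recorded, e.g. inside an autograd.record() scope (ctx.need_grad). The standalone sketch below is not part of the diff; FakeOpContext is a hypothetical stand-in for mxnet::OpContext used only to illustrate the selection logic.

#include <iostream>

// Hypothetical stand-in for the two mxnet::OpContext flags this commit combines.
struct FakeOpContext {
  bool is_train;   // executor runs in training mode
  bool need_grad;  // gradients are being recorded (e.g. autograd.record())
};

// Mirrors the predicate introduced in MKLDNNRnnOp::Init()/Forward():
// take the forward_training path whenever a backward pass may follow.
bool UseForwardTraining(const FakeOpContext& ctx) {
  return ctx.is_train || ctx.need_grad;
}

int main() {
  std::cout << UseForwardTraining({true, false}) << "\n";   // training executor -> 1
  std::cout << UseForwardTraining({false, true}) << "\n";   // recording only    -> 1
  std::cout << UseForwardTraining({false, false}) << "\n";  // pure inference    -> 0
  return 0;
}
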
26 changes: 16 additions & 10 deletions src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -618,6 +618,9 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
using memory = mkldnn::memory;
using format_tag = mkldnn::memory::format_tag;

+  // In the `autograd.record()` context, RNNOp is required to run into
+  // `forward_training` mode.
+  const bool is_training = (ctx.is_train || ctx.need_grad);
const size_t num_fusion = full_param_.layer_params.size();
if (fwd_inf_vec_.size() < num_fusion) {
size_t buffer_size = 0; // Element number, instead of bytes, in the buffer
@@ -635,7 +638,7 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
mgr_.Init(buffer_size, ctx.run_ctx.ctx, inputs[rnn_enum::kParams].dtype());
}

-  if (ctx.is_train && fwd_trn_vec_.size() < num_fusion) {
+  if (is_training && fwd_trn_vec_.size() < num_fusion) {
for (auto& layer_param : full_param_.layer_params) {
fwd_trn_vec_.emplace_back(layer_param,
true, inputs[rnn_enum::kData], inputs[rnn_enum::kParams]);
@@ -659,13 +662,13 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
size_t layer_weights_bytes = single_w_bytes * directions;
size_t layer_bias_bytes = single_b_bytes * directions; // Naive MXNet has double bias

-    if (!fwd_layer.IsInitialized() || ctx.is_train)
+    if (!fwd_layer.IsInitialized() || is_training)
fwd_layer.SetWeightsMem(&(this->mgr_), weights_ptr, bias_ptr, dtype);
weights_ptr += layer_weights_bytes;
bias_ptr += layer_bias_bytes;
}

-  if (ctx.is_train) {
+  if (is_training) {
CHECK_EQ(fwd_trn_vec_.size(), fwd_inf_vec_.size()) <<
"Layers' configurations of forward inference and forward training are disparate.";
for (size_t lyr = 0; lyr < fwd_inf_vec_.size(); ++lyr)
@@ -898,6 +901,9 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
+  // In the `autograd.record()` context, RNNOp is required to run into
+  // forward_training mode.
+  const bool is_training = (ctx.is_train || ctx.need_grad);
// check output requests
if (kAddTo == req[rnn_enum::kOut])
LOG(FATAL) << "Currently, `add` operation is not supported by RNNs.";
@@ -922,7 +928,7 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
weights_version_ = inputs[rnn_enum::kParams].version();
}

-  if (!initialized_ || ctx.is_train || fwd_trn_vec_.size() == 0) {
+  if (!initialized_ || is_training || fwd_trn_vec_.size() == 0) {
Init(ctx, inputs, req, outputs);
}

@@ -952,7 +958,7 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
if (fwd_inf_vec_.size() == 1) {
fwd_inf_vec_.front().SetNewDataMem(src, src_state, src_state_cell,
dst, dst_state, dst_state_cell, data_dtype);
-    if (ctx.is_train) {
+    if (is_training) {
fwd_trn_vec_.front().FetchData(fwd_inf_vec_.front());
}
} else {
@@ -964,7 +970,7 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
// results in this->xxx, used as the source input of the next layer.
fwd_inf_vec_.front().SetNewDataMem(src, src_state, src_state_cell,
this->dst_.front()->get_data_handle(), dst_state, dst_state_cell, data_dtype);
-    if (ctx.is_train) {
+    if (is_training) {
fwd_trn_vec_.front().FetchData(fwd_inf_vec_.front());
}
// 1st_lyr -> dst_handle -> next_lyr -> dst_handle -> next_lyr -> ...
@@ -976,7 +982,7 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
fwd_inf_vec_.at(lyr).SetNewDataMem(this->dst_.at(lyr - 1)->get_data_handle(),
src_state, src_state_cell,
this->dst_.at(lyr)->get_data_handle(), dst_state, dst_state_cell, data_dtype);
-      if (ctx.is_train) {
+      if (is_training) {
fwd_trn_vec_.at(lyr).FetchData(fwd_inf_vec_.at(lyr));
}
}
@@ -987,11 +993,11 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
if (dst_state_cell) dst_state_cell += cell_bytes;
fwd_inf_vec_.back().SetNewDataMem(this->dst_.back()->get_data_handle(),
src_state, src_state_cell, dst, dst_state, dst_state_cell, data_dtype);
-    if (ctx.is_train) {
+    if (is_training) {
fwd_trn_vec_.back().FetchData(fwd_inf_vec_.back());
}
}
-  if (ctx.is_train) {
+  if (is_training) {
for (auto& trn_lyr : fwd_trn_vec_) RegisterMKLDNNRnn(trn_lyr);
} else {
for (auto& inf_lyr : fwd_inf_vec_) RegisterMKLDNNRnn(inf_lyr);
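
In the multi-layer branch above, the first fused layer writes into an internal destination buffer, each middle layer reads the previous layer's buffer, and the last layer writes to the user-visible output; under is_training each training primitive also fetches the data handles of its inference counterpart. The standalone sketch below (not MXNet code) shows only the buffer chaining; Layer and Run() are hypothetical stand-ins for MKLDNNRnnForward and its primitives.

#include <cstddef>
#include <vector>

// Hypothetical stand-in for one fused RNN layer.
struct Layer {
  void Run(const float* src, float* dst, std::size_t n) const {
    for (std::size_t i = 0; i < n; ++i) dst[i] = src[i] + 1.0f;  // placeholder math
  }
};

// Chain the layers: src -> tmp[0] -> tmp[1] -> ... -> dst, mirroring how
// Forward() threads intermediate dst_ buffers between fused layers.
void RunStack(const std::vector<Layer>& layers, const float* src, float* dst,
              std::size_t n) {
  if (layers.size() == 1) {
    layers.front().Run(src, dst, n);
    return;
  }
  std::vector<std::vector<float>> tmp(layers.size() - 1, std::vector<float>(n));
  layers.front().Run(src, tmp.front().data(), n);
  for (std::size_t lyr = 1; lyr + 1 < layers.size(); ++lyr)
    layers[lyr].Run(tmp[lyr - 1].data(), tmp[lyr].data(), n);
  layers.back().Run(tmp[layers.size() - 2].data(), dst, n);
}

int main() {
  std::vector<Layer> layers(3);
  std::vector<float> x(4, 0.0f), y(4, 0.0f);
  RunStack(layers, x.data(), y.data(), x.size());
  return 0;  // y now holds x processed by three chained layers
}
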
@@ -1092,7 +1098,7 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx,
dy, dhy, dcy, data_dtype);

for (std::vector<MKLDNNRnnBackward>::const_reverse_iterator bwd = bwd_vec_.rbegin();
-      bwd < bwd_vec_.rend(); ++bwd) {
+      bwd != bwd_vec_.rend(); ++bwd) {
RegisterMKLDNNRnn(*bwd);
}
}
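
The Backward() hunk above also switches the loop condition over the backward primitives from `bwd < bwd_vec_.rend()` to `bwd != bwd_vec_.rend()`, the idiomatic termination test that every iterator category supports. A standalone sketch of the same reverse-iteration pattern over a plain std::vector, not part of the diff:

#include <iostream>
#include <vector>

int main() {
  std::vector<int> layers = {1, 2, 3};
  // Walk the layers from last to first, as Backward() does with bwd_vec_.
  // Termination uses `!=`, which works for every iterator category.
  for (std::vector<int>::const_reverse_iterator it = layers.rbegin();
       it != layers.rend(); ++it) {
    std::cout << "processing layer " << *it << "\n";
  }
  return 0;
}
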
144 changes: 72 additions & 72 deletions src/operator/rnn-inl.h
@@ -422,87 +422,87 @@ class RNNOp {

if (ctx_.dev_type == kGPU) {
#if MXNET_USE_CUDNN == 1
init_cudnn_ = false;
dtype_ = mshadow::DataType<DType>::kCudnnFlag;
// TensorCore algos only allowed on fp16-I/O convolutions if permitted by the global policy.
// No tests in place for fp16 RNNs, so leave TensorCore disabled for now.
cudnn_tensor_core_ = false;
// When fp16 RNN tests are introduced, we can enable TensorCore as follows:
// cudnn_tensor_core =
// mshadow::DataType<DType>::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore();
// Defaults
input_mode_ = CUDNN_LINEAR_INPUT; // Don't support this yet
// RNN Mode
switch (param_.mode) {
case rnn_enum::kRnnRelu:
mode_ = CUDNN_RNN_RELU;
break;
case rnn_enum::kRnnTanh:
mode_ = CUDNN_RNN_TANH;
break;
case rnn_enum::kLstm:
mode_ = CUDNN_LSTM;
break;
case rnn_enum::kGru:
mode_ = CUDNN_GRU;
break;
default:
LOG(FATAL) << "Not implmented";
}
#if MXNET_USE_CUDNN_GE_7200
if (param_.projection_size.has_value()) {
CHECK_EQ(param_.mode, rnn_enum::kLstm)
<< "Projection is only supported for LSTM.";
CHECK_GE(param_.state_size, param_.projection_size.value())
<< "State size must be larger than projection size.";
}
#else
CHECK(!param_.projection_size.has_value())
<< "Projection is only supported for LSTM with CuDNN version later than 7.1.1.";
#endif // MXNET_USE_CUDNN_GE_7200
#if MXNET_USE_CUDNN_GE_7200
if (param_.lstm_state_clip_min.has_value()
|| param_.lstm_state_clip_max.has_value()) {
CHECK_EQ(param_.mode, rnn_enum::kLstm)
<< "State clipping is only supported for LSTM.";
CHECK(param_.lstm_state_clip_min.has_value() && param_.lstm_state_clip_max.has_value())
<< "lstm_state_clip_min and lstm_state_clip_max must be specified together.";
CHECK_GE(param_.lstm_state_clip_max.value(), param_.lstm_state_clip_min.value())
<< "lstm_state_clip_max must be greater or equal to lstm_state_clip_min";
}
#else
CHECK(!param_.lstm_state_clip_min.has_value()
&& !param_.lstm_state_clip_max.has_value())
<< "State clipping is only supported for LSTM with CuDNN version later than 7.2.1.";
#endif // MXNET_USE_CUDNN_GE_7200
// RNN Direction
direction_ = param_.bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
// Create descriptors
CUDNN_CALL(cudnnCreateTensorDescriptor(&hx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&cx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&hy_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&cy_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dhx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dcx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dhy_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dcy_desc_));

CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc_));
CUDNN_CALL(cudnnCreateFilterDescriptor(&dw_desc_));

CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));
CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));

#if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_));
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_));
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_));
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dy_data_desc_));
#endif // MXNET_USE_CUDNN_GE_7200
#else
if (ctx_.dev_type == kGPU) {
LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment.";
}
#endif // MXNET_USE_CUDNN == 1
}
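
As background on the CUDNN_CALL(...) lines above: they follow the common status-checking macro pattern of evaluating the call once and failing loudly if it does not return success. The cuDNN-free sketch below imitates that pattern with a plain int status code and a hypothetical fake_create_descriptor(); MXNet's real macro checks cudnnStatus_t values instead.

#include <cstdio>
#include <cstdlib>

// Minimal sketch of a CUDNN_CALL-like status-checking macro, assuming an
// int status code where 0 means success (cuDNN itself uses cudnnStatus_t).
#define STATUS_CALL(expr)                                               \
  do {                                                                  \
    int status_call_e_ = (expr);                                        \
    if (status_call_e_ != 0) {                                          \
      std::fprintf(stderr, "%s failed with status %d\n", #expr,         \
                   status_call_e_);                                     \
      std::abort();                                                     \
    }                                                                   \
  } while (0)

// Hypothetical stand-in for a cudnnCreate*Descriptor-style call.
int fake_create_descriptor() { return 0; }

int main() {
  STATUS_CALL(fake_create_descriptor());
  std::puts("descriptor created");
  return 0;
}
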

@@ -854,7 +854,7 @@ class RNNOp {
}
DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);

-    if (ctx.is_train) {
+    if (ctx.is_train || ctx.need_grad) {
// allocate reserve space

const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
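
The final hunk applies the same rule on the plain rnn-inl.h path: the reserve space that a later Backward() call consumes is now allocated whenever ctx.is_train || ctx.need_grad holds, not only under is_train. A hedged standalone sketch of that idea follows, with a made-up element count in place of the real GetRNNReserveSpaceSize():

#include <cstddef>
#include <vector>

// Allocate the backward reserve buffer only when a gradient pass may follow.
// The element count below is an assumed upper bound for illustration, not
// MXNet's actual GetRNNReserveSpaceSize() formula.
std::vector<float> MaybeAllocReserve(bool is_train, bool need_grad,
                                     std::size_t num_layers, std::size_t seq_len,
                                     std::size_t batch, std::size_t hidden) {
  if (!(is_train || need_grad)) {
    return {};  // pure inference: no reserve space is kept
  }
  const std::size_t elems = num_layers * seq_len * batch * hidden * 6;
  return std::vector<float>(elems, 0.0f);
}

int main() {
  auto reserve = MaybeAllocReserve(false, true, 2, 35, 32, 512);
  return reserve.empty() ? 1 : 0;  // recording mode keeps a buffer
}
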
