This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[mkldnn-v1.0] Adopt autograd.record() context to RNNOp #16657

Merged · 1 commit · Oct 29, 2019
26 changes: 16 additions & 10 deletions src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -618,6 +618,9 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
using memory = mkldnn::memory;
using format_tag = mkldnn::memory::format_tag;

+ // In the `autograd.record()` context, RNNOp is required to run into
+ // `forward_training` mode.
+ const bool is_training = (ctx.is_train || ctx.need_grad);
const size_t num_fusion = full_param_.layer_params.size();
if (fwd_inf_vec_.size() < num_fusion) {
size_t buffer_size = 0; // Element number, instead of bytes, in the buffer
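
For context, `ctx.is_train` reflects the global training mode while `ctx.need_grad` is set when autograd is recording gradients for this call, so the new `is_training` flag covers both cases. At the Python frontend these roughly correspond to `autograd.is_training()` and `autograd.is_recording()`. A minimal sketch of those flags (an illustration, not part of this PR):

```python
from mxnet import autograd

# Illustration only: the Python-level flags that roughly mirror the
# ctx.is_train / ctx.need_grad pair checked above.
with autograd.record(train_mode=False):   # record the graph, keep predict mode
    print(autograd.is_recording())        # True  -> roughly ctx.need_grad
    print(autograd.is_training())         # False -> roughly ctx.is_train
```
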
@@ -635,7 +638,7 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
mgr_.Init(buffer_size, ctx.run_ctx.ctx, inputs[rnn_enum::kParams].dtype());
}

- if (ctx.is_train && fwd_trn_vec_.size() < num_fusion) {
+ if (is_training && fwd_trn_vec_.size() < num_fusion) {
for (auto& layer_param : full_param_.layer_params) {
fwd_trn_vec_.emplace_back(layer_param,
true, inputs[rnn_enum::kData], inputs[rnn_enum::kParams]);
@@ -659,13 +662,13 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
size_t layer_weights_bytes = single_w_bytes * directions;
size_t layer_bias_bytes = single_b_bytes * directions; // Naive MXNet has double bias

- if (!fwd_layer.IsInitialized() || ctx.is_train)
+ if (!fwd_layer.IsInitialized() || is_training)
fwd_layer.SetWeightsMem(&(this->mgr_), weights_ptr, bias_ptr, dtype);
weights_ptr += layer_weights_bytes;
bias_ptr += layer_bias_bytes;
}

- if (ctx.is_train) {
+ if (is_training) {
CHECK_EQ(fwd_trn_vec_.size(), fwd_inf_vec_.size()) <<
"Layers' configurations of forward inference and forward training are disparate.";
for (size_t lyr = 0; lyr < fwd_inf_vec_.size(); ++lyr)
@@ -898,6 +901,9 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
+ // In the `autograd.record()` context, RNNOp is required to run into
+ // forward_training mode.
+ const bool is_training = (ctx.is_train || ctx.need_grad);
// check output requests
if (kAddTo == req[rnn_enum::kOut])
LOG(FATAL) << "Currently, `add` operation is not supported by RNNs.";
@@ -922,7 +928,7 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
weights_version_ = inputs[rnn_enum::kParams].version();
}

- if (!initialized_ || ctx.is_train || fwd_trn_vec_.size() == 0) {
+ if (!initialized_ || is_training || fwd_trn_vec_.size() == 0) {
Init(ctx, inputs, req, outputs);
}

@@ -952,7 +958,7 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
if (fwd_inf_vec_.size() == 1) {
fwd_inf_vec_.front().SetNewDataMem(src, src_state, src_state_cell,
dst, dst_state, dst_state_cell, data_dtype);
- if (ctx.is_train) {
+ if (is_training) {
fwd_trn_vec_.front().FetchData(fwd_inf_vec_.front());
}
} else {
@@ -964,7 +970,7 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
// results in this->xxx, used as the source input of the next layer.
fwd_inf_vec_.front().SetNewDataMem(src, src_state, src_state_cell,
this->dst_.front()->get_data_handle(), dst_state, dst_state_cell, data_dtype);
- if (ctx.is_train) {
+ if (is_training) {
fwd_trn_vec_.front().FetchData(fwd_inf_vec_.front());
}
// 1st_lyr -> dst_handle -> next_lyr -> dst_handle -> next_lyr -> ...
@@ -976,7 +982,7 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
fwd_inf_vec_.at(lyr).SetNewDataMem(this->dst_.at(lyr - 1)->get_data_handle(),
src_state, src_state_cell,
this->dst_.at(lyr)->get_data_handle(), dst_state, dst_state_cell, data_dtype);
- if (ctx.is_train) {
+ if (is_training) {
fwd_trn_vec_.at(lyr).FetchData(fwd_inf_vec_.at(lyr));
}
}
@@ -987,11 +993,11 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
if (dst_state_cell) dst_state_cell += cell_bytes;
fwd_inf_vec_.back().SetNewDataMem(this->dst_.back()->get_data_handle(),
src_state, src_state_cell, dst, dst_state, dst_state_cell, data_dtype);
- if (ctx.is_train) {
+ if (is_training) {
fwd_trn_vec_.back().FetchData(fwd_inf_vec_.back());
}
}
- if (ctx.is_train) {
+ if (is_training) {
for (auto& trn_lyr : fwd_trn_vec_) RegisterMKLDNNRnn(trn_lyr);
} else {
for (auto& inf_lyr : fwd_inf_vec_) RegisterMKLDNNRnn(inf_lyr);
@@ -1092,7 +1098,7 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx,
dy, dhy, dcy, data_dtype);

for (std::vector<MKLDNNRnnBackward>::const_reverse_iterator bwd = bwd_vec_.rbegin();
- bwd < bwd_vec_.rend(); ++bwd) {
+ bwd != bwd_vec_.rend(); ++bwd) {
RegisterMKLDNNRnn(*bwd);
}
}
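Taken together, the mkldnn_rnn.cc changes make the fused RNN create and register its forward-training primitives whenever gradients may be requested, not only when the global training mode is on. Below is a hedged usage sketch of the case the new condition covers, recording under `autograd.record(train_mode=False)` so that `ctx.need_grad` is set while `ctx.is_train` is not (it assumes a CPU build with MKL-DNN enabled; the layer and shapes are arbitrary and not taken from the PR):

```python
import mxnet as mx
from mxnet import autograd, gluon

# Sketch only: a fused LSTM on CPU, run with gradient recording enabled
# while the global training mode stays off.
lstm = gluon.rnn.LSTM(hidden_size=16, num_layers=2)
lstm.initialize()

x = mx.nd.random.uniform(shape=(10, 4, 8))   # (seq_len, batch, feature)
x.attach_grad()

with autograd.record(train_mode=False):      # ctx.need_grad set, ctx.is_train not
    y = lstm(x)
y.backward()                                 # needs the forward_training workspace
print(x.grad.shape)                          # (10, 4, 8)
```
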
144 changes: 72 additions & 72 deletions src/operator/rnn-inl.h
@@ -422,87 +422,87 @@ class RNNOp {

if (ctx_.dev_type == kGPU) {
#if MXNET_USE_CUDNN == 1
init_cudnn_ = false;
dtype_ = mshadow::DataType<DType>::kCudnnFlag;
// TensorCore algos only allowed on fp16-I/O convolutions if permitted by the global policy.
// No tests in place for fp16 RNNs, so leave TensorCore disabled for now.
cudnn_tensor_core_ = false;
// When fp16 RNN tests are introduced, we can enable TensorCore as follows:
// cudnn_tensor_core =
// mshadow::DataType<DType>::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore();
// Defaults
input_mode_ = CUDNN_LINEAR_INPUT; // Don't support this yet
// RNN Mode
switch (param_.mode) {
case rnn_enum::kRnnRelu:
mode_ = CUDNN_RNN_RELU;
break;
case rnn_enum::kRnnTanh:
mode_ = CUDNN_RNN_TANH;
break;
case rnn_enum::kLstm:
mode_ = CUDNN_LSTM;
break;
case rnn_enum::kGru:
mode_ = CUDNN_GRU;
break;
default:
LOG(FATAL) << "Not implmented";
}
#if MXNET_USE_CUDNN_GE_7200
if (param_.projection_size.has_value()) {
CHECK_EQ(param_.mode, rnn_enum::kLstm)
<< "Projection is only supported for LSTM.";
CHECK_GE(param_.state_size, param_.projection_size.value())
<< "State size must be larger than projection size.";
}
#else
CHECK(!param_.projection_size.has_value())
<< "Projection is only supported for LSTM with CuDNN version later than 7.1.1.";
#endif // MXNET_USE_CUDNN_GE_7200
#if MXNET_USE_CUDNN_GE_7200
if (param_.lstm_state_clip_min.has_value()
|| param_.lstm_state_clip_max.has_value()) {
CHECK_EQ(param_.mode, rnn_enum::kLstm)
<< "State clipping is only supported for LSTM.";
CHECK(param_.lstm_state_clip_min.has_value() && param_.lstm_state_clip_max.has_value())
<< "lstm_state_clip_min and lstm_state_clip_max must be specified together.";
CHECK_GE(param_.lstm_state_clip_max.value(), param_.lstm_state_clip_min.value())
<< "lstm_state_clip_max must be greater or equal to lstm_state_clip_min";
}
#else
CHECK(!param_.lstm_state_clip_min.has_value()
&& !param_.lstm_state_clip_max.has_value())
<< "State clipping is only supported for LSTM with CuDNN version later than 7.2.1.";
#endif // MXNET_USE_CUDNN_GE_7200
// RNN Direction
direction_ = param_.bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
// Create descriptors
CUDNN_CALL(cudnnCreateTensorDescriptor(&hx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&cx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&hy_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&cy_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dhx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dcx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dhy_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dcy_desc_));

CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc_));
CUDNN_CALL(cudnnCreateFilterDescriptor(&dw_desc_));

CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));
CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));

#if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_));
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_));
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_));
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dy_data_desc_));
#endif // MXNET_USE_CUDNN_GE_7200
#else
if (ctx_.dev_type == kGPU) {
LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment.";
}
#endif // MXNET_USE_CUDNN == 1
}

@@ -854,7 +854,7 @@ class RNNOp {
}
DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);

- if (ctx.is_train) {
+ if (ctx.is_train || ctx.need_grad) {
// allocate reserve space

const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
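
The rnn-inl.h hunk applies the same condition to the generic `RNNOp`: the reserve space consumed by the backward pass is only allocated when either training mode or gradient recording is active, and the cheaper inference behaviour is kept otherwise. A short hedged sketch of that inference side, where neither flag is set (again illustrative, not from the PR):

```python
import mxnet as mx
from mxnet import autograd, gluon

# Sketch only: outside autograd.record() (or while recording is paused),
# neither training mode nor gradient recording is active, so the reserve
# space for backward is not needed and the inference path can be taken.
lstm = gluon.rnn.LSTM(hidden_size=16)
lstm.initialize()
x = mx.nd.random.uniform(shape=(5, 2, 8))

y_plain = lstm(x)              # plain inference call, no recording
with autograd.record():
    with autograd.pause():     # recording suspended inside pause()
        y_paused = lstm(x)     # behaves like inference again
```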