From c91c9e1afd68cec1a812548f92714eb775c6aa30 Mon Sep 17 00:00:00 2001 From: zixuanweeei Date: Fri, 27 Dec 2019 15:02:56 +0800 Subject: [PATCH] MKL-DNN RNN backward path enhancement * Flush memory before RNN backward primitive * Add gluon rnn unit test for gradients check * Cache reorder * Re-write rnn supporting check * Update OpSignature.AddSign to avoid potential hash collision for rnn-packed memory --- src/operator/nn/mkldnn/mkldnn_base-inl.h | 9 +- src/operator/nn/mkldnn/mkldnn_rnn-inl.h | 11 +- src/operator/nn/mkldnn/mkldnn_rnn.cc | 352 +++++++++++------------ src/operator/operator_common.h | 18 ++ src/operator/rnn.cc | 6 +- tests/python/unittest/test_gluon_rnn.py | 124 ++++++++ 6 files changed, 324 insertions(+), 196 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index da9b8166a849..ad935556065b 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -132,9 +132,12 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) { return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4); } -static inline bool SupportMKLDNNRNN(const NDArray &input) { - int ndim = input.shape().ndim(); - return (input.dtype() == mshadow::kFloat32) && (ndim == 3); +static inline bool SupportMKLDNNRnn(const NDArray &input) { + if (input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 3 + && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) { + return true; + } + return false; } static inline bool SupportMKLDNNQuantize(int dtype) { diff --git a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h index 314106b98eb9..2b54b9f5bbc2 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h @@ -365,13 +365,14 @@ class MKLDNNRnnBackward { fwd_ptr_(&fwd) { } void FetchDataWeightsMem(const MKLDNNRnnForwardTraining& fwd); - void SetWeightsGradsMem(); + void SetWeightsGradsMem(const int dtype = mshadow::kFloat32); void SetDataGradsMem(void* diff_src, void* diff_state, void* diff_statecell, void* diff_out, void* diff_state_out, void* diff_statecell_out, const int dtype = mshadow::kFloat32); - void CommitWeightsDiff(void* diff_weights, void* diff_bias, - const OpReqType req, - const int dtype = mshadow::kFloat32); + void SetNativeWeightsGrads() const; + void CommitWeightsGrads(void* diff_weights, void* diff_bias, + const OpReqType req, + const int dtype = mshadow::kFloat32); const mkldnn::primitive& GetBwd() const { return *bwd_.primitive_; } const mkldnn_args_map_t& GetArgsMap() const { return net_args_; } @@ -386,6 +387,8 @@ class MKLDNNRnnBackward { mkldnn_shared_mem_t diff_weights_layer_; mkldnn_shared_mem_t diff_weights_iter_; + mkldnn_shared_mem_t diff_weights_layer_r_; + mkldnn_shared_mem_t diff_weights_iter_r_; mkldnn_shared_mem_t diff_bias_; mkldnn_args_map_t net_args_; diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc b/src/operator/nn/mkldnn/mkldnn_rnn.cc index 6da8f3b8a58a..294d36b0dfc3 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn.cc +++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc @@ -364,18 +364,38 @@ void MKLDNNRnnForward::SetNewDataMem(void* x, void* hx, void* cx, } } +inline void MKLDNNMemoryReorder(const mkldnn::memory& src, + const mkldnn::memory& dst) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map reorderPrimitives; +#else + static MX_THREAD_LOCAL std::unordered_map reorderPrimitives; +#endif + OpSignature key{}; + key.AddSign(src); + key.AddSign(dst); + + auto it = 
reorderPrimitives.find(key);
+  if (it == reorderPrimitives.end()) {
+    auto reorder = mkldnn::reorder(src, dst);
+    it = AddToCache(&reorderPrimitives, key, reorder);
+  }
+
+  mkldnn_args_map_t net_args;
+  net_args.emplace(MKLDNN_ARG_SRC, src);
+  net_args.emplace(MKLDNN_ARG_DST, dst);
+  MKLDNNStream::Get()->RegisterPrimArgs(it->second, net_args);
+}
+
 /*
  * Reorder the concatenated weights memory to an efficient memory block
  * with the primitive-preferred format.
  */
 void MKLDNNRnnForward::ReorderWeights() {
-  auto& cpu_engine = CpuEngine::Get()->get_engine();
-  mkldnn::stream s(cpu_engine);
-  mkldnn::reorder(*weights_layer_r_, *weights_layer_)
-      .execute(s, *weights_layer_r_, *weights_layer_);
-  mkldnn::reorder(*weights_iter_r_, *weights_iter_)
-      .execute(s, *weights_iter_r_, *weights_iter_);
-  s.wait();
+  MKLDNNMemoryReorder(*weights_layer_r_, *weights_layer_);
+  MKLDNNMemoryReorder(*weights_iter_r_, *weights_iter_);
 }
 
 void AdjustGruGateOrder(char* weight,
@@ -394,7 +414,7 @@ void AdjustGruGateOrder(char* weight,
  * Fuse uni-directional bias among single layer.
  */
 template <typename DType>
-void FuseBias(DType* fuse_bias, DType* naive_bias,
+void FuseBias(DType* fuse_bias, DType* native_bias,
               const int mode, const size_t state_size) {
   const size_t ngates = GetRnnGatesNum(mode);
   const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
@@ -403,8 +423,8 @@ void FuseBias(DType* fuse_bias, DType* naive_bias,
   // OpenMP 'for' statement.
   const int state_size_ = static_cast<int>(state_size);
   const int single_b_sz = static_cast<int>(nbias * state_size);
-  DType* bx = naive_bias;
-  DType* bh = naive_bias + state_size * ngates;
+  DType* bx = native_bias;
+  DType* bh = native_bias + state_size * ngates;
   if (mode == rnn_enum::kGru) {
     // While mxnet gru gate order is reset, update and new gates,
     // mkldnn gru gate order is update, reset and new gates. So
@@ -528,12 +548,6 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_
       }
     }
   }
-  // Reorder after adjustment only when is_train == false. When is_train == true, i.e.
-  // in forward training path, we use plain memory (ldxxx) as the space for weights and
-  // their gradients. Then, forward training primitives could fetch them from the scope
-  // of forward inference. And from there, we don't need to reorder the plain memory to
-  // the optimal rnn-packed memory for forward inference.
-  ReorderWeights();
 
   // Process bias
   MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
@@ -553,7 +567,15 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_
   EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_WEIGHTS_ITER, this->weights_iter_);
   EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_BIAS, this->bias_);
 
-  initialized_ = true;
+  if (!is_train) {
+    // Reorder after adjustment only when is_train == false. When is_train == true, i.e.
+    // in the forward training path, we use plain memory (ldxxx) as the space for weights
+    // and their gradients. The forward training primitives can then fetch them from the
+    // scope of forward inference, so we don't need to reorder the plain memory to the
+    // optimal rnn-packed memory for forward inference.
+ ReorderWeights(); + initialized_ = true; + } } void MKLDNNRnnForwardTraining::SetTrnMem(const MKLDNNRnnForward& fwd) { @@ -572,17 +594,14 @@ void MKLDNNRnnForwardTraining::SetTrnMem(const MKLDNNRnnForward& fwd) { if (fwd.weights_layer_r_->get_desc() == fwd_trn_.GetLayerDesc()) { weights_layer_->set_data_handle(fwd.weights_layer_r_->get_data_handle()); } else { - mkldnn::reorder(*fwd.weights_layer_r_, *weights_layer_) - .execute(s, *fwd.weights_layer_r_, *weights_layer_); + MKLDNNMemoryReorder(*fwd.weights_layer_r_, *weights_layer_); } if (fwd.weights_iter_r_->get_desc() == fwd_trn_.GetIterDesc()) { weights_iter_->set_data_handle(fwd.weights_iter_r_->get_data_handle()); } else { - mkldnn::reorder(*fwd.weights_iter_r_, *weights_iter_) - .execute(s, *fwd.weights_iter_r_, *weights_iter_); + MKLDNNMemoryReorder(*fwd.weights_iter_r_, *weights_iter_); } - s.wait(); // bias are always in format_tag::ldgo this->bias_ = fwd.bias_; @@ -687,18 +706,17 @@ void MKLDNNRnnOp::Init(const OpContext &ctx, {fwd->GetParam().dst_dims, get_mkldnn_type(data_dtype), format_tag::tnc})); } - initialized_ = true; + if (!is_training) initialized_ = true; } void MKLDNNRnnBackward::FetchDataWeightsMem(const MKLDNNRnnForwardTraining& fwd) { using memory = mkldnn::memory; auto& cpu_engine = CpuEngine::Get()->get_engine(); - auto s = mkldnn::stream(cpu_engine); - if (this->weights_layer_ == nullptr) + if (this->weights_layer_ == nullptr || this-> weights_iter_ == nullptr) { this->weights_layer_ = mkldnn_shared_mem_t(new memory(bwd_.weights_layer_desc_, cpu_engine)); - if (this->weights_iter_ == nullptr) this->weights_iter_ = mkldnn_shared_mem_t(new memory(bwd_.weights_iter_desc_, cpu_engine)); + } for (auto& kv : fwd.net_args_) { const mkldnn::memory* valid_mem; @@ -707,17 +725,15 @@ void MKLDNNRnnBackward::FetchDataWeightsMem(const MKLDNNRnnForwardTraining& fwd) if (bwd_.weights_layer_desc_ == fwd.fwd_trn_.GetLayerDesc()) { this->weights_layer_->set_data_handle(kv.second.get_data_handle()); } else { - mkldnn::reorder(*fwd.weights_layer_, *this->weights_layer_) - .execute(s, *fwd.weights_layer_, *this->weights_layer_); + MKLDNNMemoryReorder(*fwd.weights_layer_, *this->weights_layer_); } valid_mem = this->weights_layer_.get(); } break; case MKLDNN_ARG_WEIGHTS_ITER: { - if (bwd_.weights_iter_desc_ == fwd.fwd_trn_.GetLayerDesc()) { + if (bwd_.weights_iter_desc_ == fwd.fwd_trn_.GetIterDesc()) { this->weights_iter_->set_data_handle(kv.second.get_data_handle()); } else { - mkldnn::reorder(*fwd.weights_iter_, *this->weights_iter_) - .execute(s, *fwd.weights_iter_, *this->weights_iter_); + MKLDNNMemoryReorder(*fwd.weights_iter_, *this->weights_iter_); } valid_mem = this->weights_iter_.get(); } break; @@ -727,20 +743,49 @@ void MKLDNNRnnBackward::FetchDataWeightsMem(const MKLDNNRnnForwardTraining& fwd) } EmplaceNetArgs(&this->net_args_, kv.first, valid_mem); } - s.wait(); } -void MKLDNNRnnBackward::SetWeightsGradsMem() { - auto& cpu_engine = CpuEngine::Get()->get_engine(); - if (this->diff_weights_layer_ == nullptr) - this->diff_weights_layer_ = std::make_shared( - bwd_.diff_weights_layer_desc_, cpu_engine); - if (this->diff_weights_iter_ == nullptr) - this->diff_weights_iter_ = std::make_shared( - bwd_.diff_weights_iter_desc_, cpu_engine); - if (this->diff_bias_ == nullptr) +void MKLDNNRnnBackward::SetWeightsGradsMem(int dtype) { + using tag = mkldnn::memory::format_tag; + + if (this->diff_weights_layer_ == nullptr + || this->diff_weights_iter_ == nullptr + || this->diff_bias_ == nullptr) { + const auto& cpu_engine = 
CpuEngine::Get()->get_engine();
+    const MKLDNNRnnLayerParam& param = fwd_ptr_->GetParam();
+    const auto& mkldnn_type = get_mkldnn_type(dtype);
+
+    auto native_layer_desc = mkldnn::memory::desc(param.weight_layer_dims, mkldnn_type, tag::ldgoi);
+    auto native_iter_desc = mkldnn::memory::desc(param.weight_iter_dims, mkldnn_type, tag::ldgoi);
+
+    this->diff_weights_layer_r_ = std::make_shared<mkldnn::memory>(
+        native_layer_desc, cpu_engine);
+    this->diff_weights_iter_r_ = std::make_shared<mkldnn::memory>(
+        native_iter_desc, cpu_engine);
+
+    if (native_layer_desc == bwd_.diff_weights_layer_desc_) {
+      this->diff_weights_layer_ = std::make_shared<mkldnn::memory>(
+          bwd_.diff_weights_layer_desc_, cpu_engine, diff_weights_layer_r_->get_data_handle());
+    } else {
+      this->diff_weights_layer_ = std::make_shared<mkldnn::memory>(
+          bwd_.diff_weights_layer_desc_, cpu_engine);
+    }
+    if (native_iter_desc == bwd_.diff_weights_iter_desc_) {
+      this->diff_weights_iter_ = std::make_shared<mkldnn::memory>(
+          bwd_.diff_weights_iter_desc_, cpu_engine, diff_weights_iter_r_->get_data_handle());
+    } else {
+      this->diff_weights_iter_ = std::make_shared<mkldnn::memory>(
+          bwd_.diff_weights_iter_desc_, cpu_engine);
+    }
     this->diff_bias_ = std::make_shared<mkldnn::memory>(
         bwd_.diff_bias_desc_, cpu_engine);
+  }
+  std::memset(this->diff_weights_layer_->get_data_handle(), 0,
+      bwd_.diff_weights_layer_desc_.get_size());
+  std::memset(this->diff_weights_iter_->get_data_handle(), 0,
+      bwd_.diff_weights_iter_desc_.get_size());
+  std::memset(this->diff_bias_->get_data_handle(), 0,
+      bwd_.diff_bias_desc_.get_size());
 
   EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_DIFF_WEIGHTS_LAYER,
       this->diff_weights_layer_.get());
   EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_DIFF_WEIGHTS_ITER,
@@ -776,23 +821,40 @@ void MKLDNNRnnBackward::SetDataGradsMem(
   }
 }
 
-void MKLDNNRnnBackward::CommitWeightsDiff(void* diff_weights, void* diff_bias,
-                                          const OpReqType req, const int dtype) {
-  using tag = mkldnn::memory::format_tag;
-  auto& cpu_engine = CpuEngine::Get()->get_engine();
-  auto s = mkldnn::stream(cpu_engine);
+void MKLDNNRnnBackward::SetNativeWeightsGrads() const {
+  if (this->diff_weights_layer_->get_desc() != this->diff_weights_layer_r_->get_desc()) {
+    MKLDNNMemoryReorder(*this->diff_weights_layer_, *this->diff_weights_layer_r_);
+  }
+  if (this->diff_weights_iter_->get_desc() != this->diff_weights_iter_r_->get_desc()) {
+    MKLDNNMemoryReorder(*this->diff_weights_iter_, *this->diff_weights_iter_r_);
+  }
+}
+
+#define OPREQTYPE_SWITCH(ReqType, DType, FWrapper, ...)               \
+  std::function<void(DType *, DType *, size_t)> FWrapper = nullptr;   \
+  if (kWriteTo == ReqType || kWriteInplace == ReqType)                \
+    FWrapper = common::ParallelCopy<DType>;                           \
+  else                                                                \
+    FWrapper = common::ParallelAdd<DType>;                            \
+  {__VA_ARGS__}
 
+void MKLDNNRnnBackward::CommitWeightsGrads(void* diff_weights, void* diff_bias,
+                                           const OpReqType req, const int dtype) {
   const MKLDNNRnnLayerParam& param = fwd_ptr_->GetParam();
+
+  void* diff_weights_layer_ptr = this->diff_weights_layer_->get_data_handle();
+  void* diff_weights_iter_ptr = this->diff_weights_iter_->get_data_handle();
+  if (this->diff_weights_layer_->get_desc() != this->diff_weights_layer_r_->get_desc())
+    diff_weights_layer_ptr = this->diff_weights_layer_r_->get_data_handle();
+  if (this->diff_weights_iter_->get_desc() != this->diff_weights_iter_r_->get_desc())
+    diff_weights_iter_ptr = this->diff_weights_iter_r_->get_data_handle();
+
   const int num_layer = param.num_layer;
   const int direction = param.bidirectional ? 
2 : 1; const int ngates = GetRnnGatesNum(param.mode); - const size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); const size_t wxh_size = param.single_w_size; const size_t wx_size = param.input_size * param.state_size * ngates; const size_t wh_size = param.state_size * param.state_size * ngates; - const size_t wxh_bytes = param.single_w_size * dtype_bytes; - const size_t wx_bytes = param.input_size * param.state_size * ngates * dtype_bytes; - const size_t wh_bytes = param.state_size * param.state_size * ngates * dtype_bytes; /* naive weights layout is: 1st-layer: | wx_lr | wh_lr | wx_rl | wh_rl | @@ -800,162 +862,81 @@ void MKLDNNRnnBackward::CommitWeightsDiff(void* diff_weights, void* diff_bias, size: | wxh_bytes | |wx_bytes|wh_bytes| */ - if (kWriteTo == req) { - char* naive_weights = static_cast(diff_weights); - char* diff_wx_ptr = static_cast(diff_weights_layer_->get_data_handle()); - char* diff_wh_ptr = static_cast(diff_weights_iter_->get_data_handle()); - if (param.mode != rnn_enum::kGru) { - for (int shift = 0; shift < num_layer * direction; ++shift) { - std::memcpy(naive_weights + shift * wxh_bytes, - diff_wx_ptr + shift * wx_bytes, wx_bytes); - } - // align naive_weights to weights_iter memory - naive_weights += wx_bytes; - for (int shift = 0; shift < num_layer * direction; ++shift) { - std::memcpy(naive_weights + shift * wxh_bytes, - diff_wh_ptr + shift * wh_bytes, wh_bytes); - } - } else { - const size_t wx_bytes_per_gate = param.input_size * param.state_size * dtype_bytes; - const size_t wh_bytes_per_gate = param.state_size * param.state_size * dtype_bytes; - for (int shift = 0; shift < num_layer * direction; ++shift) { - std::memcpy(naive_weights + shift * wxh_bytes + wx_bytes_per_gate, - diff_wx_ptr + shift * wx_bytes, wx_bytes_per_gate); - std::memcpy(naive_weights + shift * wxh_bytes, - diff_wx_ptr + shift * wx_bytes + wx_bytes_per_gate, wx_bytes_per_gate); - std::memcpy(naive_weights + shift * wxh_bytes + 2 * wx_bytes_per_gate, - diff_wx_ptr + shift * wx_bytes + 2 * wx_bytes_per_gate, wx_bytes_per_gate); - } - // align naive_weights to weights_iter memory - naive_weights += wx_bytes; - for (int shift = 0; shift < num_layer * direction; ++shift) { - std::memcpy(naive_weights + shift * wxh_bytes + wh_bytes_per_gate, - diff_wh_ptr + shift * wh_bytes, wh_bytes_per_gate); - std::memcpy(naive_weights + shift * wxh_bytes, - diff_wh_ptr + shift * wh_bytes + wh_bytes_per_gate, wh_bytes_per_gate); - std::memcpy(naive_weights + shift * wxh_bytes + 2 * wh_bytes_per_gate, - diff_wh_ptr + shift * wh_bytes + 2 * wh_bytes_per_gate, wh_bytes_per_gate); - } - } - } else if (kAddTo == req) { - if (param.mode != rnn_enum::kGru) { - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - DType* naive_weights = static_cast(diff_weights); - DType* diff_wx_ptr = static_cast(diff_weights_layer_->get_data_handle()); - DType* diff_wh_ptr = static_cast(diff_weights_iter_->get_data_handle()); + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + DType* native_weights = static_cast(diff_weights); + DType* diff_wx_ptr = static_cast(diff_weights_layer_ptr); + DType* diff_wh_ptr = static_cast(diff_weights_iter_ptr); + OPREQTYPE_SWITCH(req, DType, FAccGrad, { + if (param.mode != rnn_enum::kGru) { for (int shift = 0; shift < num_layer * direction; ++shift) { - common::ParallelAdd(naive_weights + shift * wxh_size, - diff_wx_ptr + shift * wx_size, wx_size); + FAccGrad(native_weights + shift * wxh_size, diff_wx_ptr + shift * wx_size, wx_size); } - // align naive_weights to weights_iter memory - naive_weights += wx_size; + // align 
native_weights to weights_iter memory + native_weights += wx_size; for (int shift = 0; shift < num_layer * direction; ++shift) { - common::ParallelAdd(naive_weights + shift * wxh_size, - diff_wh_ptr + shift * wh_size, wh_size); + FAccGrad(native_weights + shift * wxh_size, diff_wh_ptr + shift * wh_size, wh_size); } - }); - } else { - const size_t wx_size_per_gate = param.input_size * param.state_size; - const size_t wh_size_per_gate = param.state_size * param.state_size; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - DType* naive_weights = static_cast(diff_weights); - DType* diff_wx_ptr = static_cast(diff_weights_layer_->get_data_handle()); - DType* diff_wh_ptr = static_cast(diff_weights_iter_->get_data_handle()); + } else { + const size_t wx_size_per_gate = param.input_size * param.state_size; + const size_t wh_size_per_gate = param.state_size * param.state_size; for (int shift = 0; shift < num_layer * direction; ++shift) { - common::ParallelAdd(naive_weights + shift * wxh_size + wx_size_per_gate, + FAccGrad(native_weights + shift * wxh_size + wx_size_per_gate, diff_wx_ptr + shift * wx_size, wx_size_per_gate); - common::ParallelAdd(naive_weights + shift * wxh_size, + FAccGrad(native_weights + shift * wxh_size, diff_wx_ptr + shift * wx_size + wx_size_per_gate, wx_size_per_gate); - common::ParallelAdd(naive_weights + shift * wxh_size + 2 * wx_size_per_gate, + FAccGrad(native_weights + shift * wxh_size + 2 * wx_size_per_gate, diff_wx_ptr + shift * wx_size + 2 * wx_size_per_gate, wx_size_per_gate); } - // align naive_weights to weights_iter memory - naive_weights += wx_size; + // align native_weights to weights_iter memory + native_weights += wx_size; for (int shift = 0; shift < num_layer * direction; ++shift) { - common::ParallelAdd(naive_weights + shift * wxh_size + wh_size_per_gate, + FAccGrad(native_weights + shift * wxh_size + wh_size_per_gate, diff_wh_ptr + shift * wh_size, wh_size_per_gate); - common::ParallelAdd(naive_weights + shift * wxh_size, + FAccGrad(native_weights + shift * wxh_size, diff_wh_ptr + shift * wh_size + wh_size_per_gate, wh_size_per_gate); - common::ParallelAdd(naive_weights + shift * wxh_size + 2 * wh_size_per_gate, + FAccGrad(native_weights + shift * wxh_size + 2 * wh_size_per_gate, diff_wh_ptr + shift * wh_size + 2 * wh_size_per_gate, wh_size_per_gate); } - }); - } - } - - if (kWriteTo == req) { - const size_t bias_bytes = param.single_b_size * dtype_bytes; - const size_t naive_bias_bytes = param.naive_single_b_size * dtype_bytes; - char* naive_bias = static_cast(diff_bias); - char* diff_bias_ptr = static_cast(this->diff_bias_->get_data_handle()); - if (param.mode != rnn_enum::kGru) { - for (int shift = 0; shift < num_layer * direction; ++shift) { - std::memcpy(naive_bias + shift * naive_bias_bytes, - diff_bias_ptr + shift * bias_bytes, bias_bytes); - std::memcpy(naive_bias + shift * naive_bias_bytes + bias_bytes, - diff_bias_ptr + shift * bias_bytes, bias_bytes); - } - } else { - const size_t bias_bytes_per_gate = param.state_size * dtype_bytes; - for (int shift = 0; shift < num_layer * direction; ++shift) { - char* naive_reset = naive_bias + shift * naive_bias_bytes; - char* naive_update = naive_reset + bias_bytes_per_gate; - char* update = diff_bias_ptr + shift * bias_bytes; - char* reset = update + bias_bytes_per_gate; - - std::memcpy(naive_update, update, bias_bytes_per_gate); - std::memcpy(naive_reset, reset, bias_bytes_per_gate); - std::memcpy(naive_update + naive_bias_bytes / 2, update, bias_bytes_per_gate); - std::memcpy(naive_reset + naive_bias_bytes / 2, 
reset, bias_bytes_per_gate); - - char* naive_new_bx = naive_update + bias_bytes_per_gate; - char* naive_new_bh = naive_new_bx + naive_bias_bytes / 2; - char* new_bx = reset + bias_bytes_per_gate; - char* new_bh = new_bx + bias_bytes_per_gate; - std::memcpy(naive_new_bx, new_bx, bias_bytes_per_gate); - std::memcpy(naive_new_bh, new_bh, bias_bytes_per_gate); } - } - } else if (kAddTo == req) { - const size_t bias_size = param.single_b_size; - const size_t naive_bias_size = param.naive_single_b_size; - if (param.mode != rnn_enum::kGru) { - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - DType* naive_bias = static_cast(diff_bias); - DType* diff_bias_ptr = static_cast(this->diff_bias_->get_data_handle()); + }); + }); + + const size_t bias_size = param.single_b_size; + const size_t naive_bias_size = param.naive_single_b_size; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + DType* native_bias = static_cast(diff_bias); + DType* diff_bias_ptr = static_cast(this->diff_bias_->get_data_handle()); + OPREQTYPE_SWITCH(req, DType, FAccGrad, { + if (param.mode != rnn_enum::kGru) { for (int shift = 0; shift < num_layer * direction; ++shift) { - common::ParallelAdd(naive_bias + shift * naive_bias_size, + FAccGrad(native_bias + shift * naive_bias_size, diff_bias_ptr + shift * bias_size, bias_size); - common::ParallelAdd(naive_bias + shift * naive_bias_size + bias_size, + FAccGrad(native_bias + shift * naive_bias_size + bias_size, diff_bias_ptr + shift * bias_size, bias_size); } - }); - } else { - const size_t bias_size_per_gate = param.state_size; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - DType* naive_bias = static_cast(diff_bias); - DType* diff_bias_ptr = static_cast(this->diff_bias_->get_data_handle()); + } else { + const size_t bias_size_per_gate = param.state_size; for (int shift = 0; shift < num_layer * direction; ++shift) { - DType* naive_reset = naive_bias + shift * naive_bias_size; - DType* naive_update = naive_reset + bias_size_per_gate; + DType* native_reset = native_bias + shift * naive_bias_size; + DType* native_update = native_reset + bias_size_per_gate; DType* update = diff_bias_ptr + shift * bias_size; DType* reset = update + bias_size_per_gate; - common::ParallelAdd(naive_update, update, bias_size_per_gate); - common::ParallelAdd(naive_reset, reset, bias_size_per_gate); - common::ParallelAdd(naive_update + naive_bias_size / 2, update, bias_size_per_gate); - common::ParallelAdd(naive_reset + naive_bias_size / 2, reset, bias_size_per_gate); + FAccGrad(native_update, update, bias_size_per_gate); + FAccGrad(native_reset, reset, bias_size_per_gate); + FAccGrad(native_update + naive_bias_size / 2, update, bias_size_per_gate); + FAccGrad(native_reset + naive_bias_size / 2, reset, bias_size_per_gate); - DType* naive_new_bx = naive_update + bias_size_per_gate; - DType* naive_new_bh = naive_new_bx + naive_bias_size / 2; + DType* native_new_bx = native_update + bias_size_per_gate; + DType* native_new_bh = native_new_bx + naive_bias_size / 2; DType* new_bx = reset + bias_size_per_gate; DType* new_bh = new_bx + bias_size_per_gate; - common::ParallelAdd(naive_new_bx, new_bx, bias_size_per_gate); - common::ParallelAdd(naive_new_bh, new_bh, bias_size_per_gate); + FAccGrad(native_new_bx, new_bx, bias_size_per_gate); + FAccGrad(native_new_bh, new_bh, bias_size_per_gate); } - }); - } - } + } + }); + }); } template @@ -966,6 +947,7 @@ inline void RegisterMKLDNNRnn(MKLDNNRnnX const& rnn) { template <> inline void RegisterMKLDNNRnn(MKLDNNRnnBackward const& rnn) { MKLDNNStream::Get()->RegisterPrimArgs(rnn.GetBwd(), 
rnn.GetArgsMap()); + rnn.SetNativeWeightsGrads(); } void MKLDNNRnnOp::Forward(const OpContext &ctx, @@ -984,8 +966,8 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx, } // Check if weights NDArray was changed. If so, reset initialized_ - if (weights_version_ != inputs[rnn_enum::kParams].version() && - fwd_inf_vec_.size() > 0) { + if (!is_training && fwd_inf_vec_.size() > 0 + && weights_version_ != inputs[rnn_enum::kParams].version()) { initialized_ = false; for (auto& fwd : fwd_inf_vec_) fwd.Reset(); weights_version_ = inputs[rnn_enum::kParams].version(); @@ -1097,6 +1079,8 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx, using tag = mkldnn::memory::format_tag; TmpMemMgr::Get()->Init(ctx.requested[0]); const RNNParam& default_param = full_param_.default_param; + const int data_dtype = inputs[rnn_enum::kData].dtype(); + const int w_dtype = inputs[rnn_enum::kParams].dtype(); // Initialize the bwd_vec_ if (bwd_vec_.size() != fwd_inf_vec_.size()) { @@ -1110,11 +1094,9 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx, LOG(FATAL) << "MKL-DNN RNN fusion error."; for (size_t lyr = 0; lyr < bwd_vec_.size(); ++lyr) { bwd_vec_.at(lyr).FetchDataWeightsMem(fwd_trn_vec_.at(lyr)); - bwd_vec_.at(lyr).SetWeightsGradsMem(); + bwd_vec_.at(lyr).SetWeightsGradsMem(w_dtype); } - const int data_dtype = inputs[rnn_enum::kData].dtype(); - const int w_dtype = inputs[rnn_enum::kParams].dtype(); const size_t w_bytes = mshadow::mshadow_sizeof(w_dtype); // Get temporary memory for diff_src, diff_state, diff_statecell const int num_layers = default_param.num_layers; @@ -1204,7 +1186,7 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx, // Commit weights diff if (req[rnn_enum::kParams] != kNullOp) { for (size_t lyr = 0; lyr < bwd_vec_.size(); ++lyr) { - bwd_vec_.at(lyr).CommitWeightsDiff(dw, db, req[rnn_enum::kParams], w_dtype); + bwd_vec_.at(lyr).CommitWeightsGrads(dw, db, req[rnn_enum::kParams], w_dtype); dw += full_param_.layer_params.at(lyr).single_w_size * w_bytes; db += full_param_.layer_params.at(lyr).single_b_size * w_bytes; } diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index f6af58bce995..a715bfc2f0cc 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -561,6 +561,24 @@ class OpSignature { case mkldnn_format_kind_rnn_packed: hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.format; eles.push_back(desc.data.format_desc.rnn_packed_desc.format); + hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.n_parts; + eles.push_back(desc.data.format_desc.rnn_packed_desc.n_parts); + for (int i = 0; i < desc.data.format_desc.rnn_packed_desc.n_parts; ++i) { + hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.parts[i]; + hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.part_pack_size[i]; + hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.pack_part[i]; + eles.push_back(desc.data.format_desc.rnn_packed_desc.parts[i]); + eles.push_back(desc.data.format_desc.rnn_packed_desc.part_pack_size[i]); + eles.push_back(desc.data.format_desc.rnn_packed_desc.pack_part[i]); + } + hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.n; + hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.ldb; + hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.offset_compensation; + hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.size; + eles.push_back(desc.data.format_desc.rnn_packed_desc.n); + eles.push_back(desc.data.format_desc.rnn_packed_desc.ldb); + eles.push_back(desc.data.format_desc.rnn_packed_desc.offset_compensation); + 
eles.push_back(desc.data.format_desc.rnn_packed_desc.size); break; default: // nothing need to add diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 542968ef0a2c..a8e1b128f773 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -274,8 +274,7 @@ static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - if ((inputs[0].dtype() == mshadow::kFloat32 || inputs[0].dtype() == mshadow::kFloat16) && - inputs[0].shape().ndim() == 3 && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) { + if (SupportMKLDNNRnn(inputs[0])) { MKLDNNRnnOp& op = state_ptr.get_state(); op.Forward(ctx, inputs, req, outputs); } else { @@ -288,8 +287,7 @@ static void RNNStatefulGradComputeExCPU(const OpStatePtr& state_ptr, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - if ((inputs[0].dtype() == mshadow::kFloat32 || inputs[0].dtype() == mshadow::kFloat16) && - inputs[0].shape().ndim() == 3 && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) { + if (SupportMKLDNNRnn(inputs[0])) { MKLDNNRnnOp& op = state_ptr.get_state(); op.Backward(ctx, inputs, req, outputs); } else { diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index 309756b122e7..0f27f53f83a8 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -19,6 +19,8 @@ from mxnet import gluon, nd import numpy as np import copy +from itertools import product +from functools import partial from numpy.testing import assert_allclose import unittest from mxnet.test_utils import almost_equal, assert_almost_equal @@ -545,6 +547,128 @@ def test_rnn_layers_fp16(): run_rnn_layers('float16', 'float32', mx.gpu()) +def check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=False, rtol=1e-2, atol=1e-4): + fused_begin_state = fused_layer.begin_state(1) + stack_state = stack_layer.begin_state(batch_size=1) + x = nd.random.normal(shape=(1, 5, input_size)) + x.attach_grad() + y = nd.random.normal(shape=(1, 5, hidden_size * 2 if bidirectional else hidden_size)) + + with mx.autograd.record(): + fused_out, fused_state = fused_layer(x, fused_begin_state) + l = loss(fused_out, y).mean() + l.backward() + fused_grads = dict([(name, p.grad()) for name, p in fused_layer.collect_params().items()]) + fused_input_grad = x.grad.asnumpy() + + with mx.autograd.record(): + stack_out, stack_state = stack_layer.unroll(5, x, stack_state, merge_outputs=True) + l = loss(stack_out, y).mean() + l.backward() + stack_grads = dict([(name, p.grad()) for name, p in stack_layer.collect_params().items()]) + stack_input_grad = x.grad.asnumpy() + + assert_allclose(fused_out.asnumpy(), stack_out.asnumpy(), rtol=rtol, atol=atol) + assert_allclose(fused_input_grad, stack_input_grad, rtol=rtol, atol=atol) + for key, value in fused_grads.items(): + assert_allclose(value.asnumpy(), stack_grads[key].asnumpy(), rtol=rtol, atol=atol) + + +def create_op_by_mode(mode): + if mode == 'lstm': + fused_op = gluon.rnn.LSTM + stack_op = gluon.rnn.LSTMCell + recurrent_block_prefix = 'lstm0_' + elif mode == 'gru': + fused_op = gluon.rnn.GRU + stack_op = gluon.rnn.GRUCell + recurrent_block_prefix = 'gru0_' + elif mode == 'rnn_relu': + fused_op = partial(gluon.rnn.RNN, activation='relu') + stack_op = partial(gluon.rnn.RNNCell, activation='relu') + recurrent_block_prefix = 'rnn0_' + elif mode == 'rnn_tanh': + fused_op = partial(gluon.rnn.RNN, activation='tanh') + stack_op = 
partial(gluon.rnn.RNNCell, activation='tanh') + recurrent_block_prefix = 'rnn0_' + + return fused_op, stack_op, recurrent_block_prefix + + +def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, loss): + fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode) + # ==== Single layer ==== + fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', bidirectional=False) + fused_layer.collect_params().initialize() + + params = fused_layer.collect_params() + stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix, params=params) + with stack_layer.name_scope(): + stack_layer.add(stack_op(hidden_size, prefix='l0_')) + stack_layer.initialize() + + check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size) + + # ==== Multiple layer ==== + fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', bidirectional=False) + fused_layer.collect_params().initialize() + + params = fused_layer.collect_params() + stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix, params=params) + with stack_layer.name_scope(): + stack_layer.add(stack_op(hidden_size, prefix='l0_')) + stack_layer.add(stack_op(hidden_size, prefix='l1_')) + stack_layer.add(stack_op(hidden_size, prefix='l2_')) + stack_layer.initialize() + + check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size) + + +def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss): + fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode) + # ==== Single layer ==== + fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', bidirectional=True) + fused_layer.collect_params().initialize() + + params = fused_layer.collect_params() + stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix, params=params) + with stack_layer.name_scope(): + stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l0_'), + stack_op(hidden_size, prefix='r0_'))) + stack_layer.initialize() + + check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True) + + # ==== Multiple layer ==== + fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', bidirectional=True) + fused_layer.collect_params().initialize() + + params = fused_layer.collect_params() + stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix, params=params) + with stack_layer.name_scope(): + stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l0_'), + stack_op(hidden_size, prefix='r0_'))) + stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l1_'), + stack_op(hidden_size, prefix='r1_'))) + stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l2_'), + stack_op(hidden_size, prefix='r2_'))) + stack_layer.initialize() + + check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True) + + +@assert_raises_cudnn_not_satisfied(min_version='5.1.10') +def test_fused_rnn_layer(): + input_sizes = [128] + hidden_sizes = [128, 256] + modes = ['lstm', 'gru', 'rnn_relu', 'rnn_tanh'] + # single layer + for mode, input_size, hidden_size in product(modes, input_sizes, hidden_sizes): + loss = mx.gluon.loss.L2Loss() + check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, loss) + check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss) + + def test_rnn_unroll_variant_length(): # Test for imperative usage cell_list = []
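
Reviewer note (not part of the patch): the new tests above build a fused gluon.rnn layer and an equivalent stack of RNN cells that share the same parameters, then compare outputs, input gradients, and weight gradients. The sketch below shows how one combination could be exercised by hand; it assumes it is run from inside tests/python/unittest/test_gluon_rnn.py so the helpers defined above are in scope, and the MXNET_USE_MKLDNN_RNN toggle is the environment variable read by the new SupportMKLDNNRnn() check.

    # Sketch only: assumes the helpers added in this patch
    # (check_rnn_unidir_layer_gradients, check_rnn_bidir_layer_gradients)
    # are importable, i.e. this runs within test_gluon_rnn.py.
    import os
    import mxnet as mx

    # SupportMKLDNNRnn() reads this variable (default 1); set it to '0'
    # before running the checks to fall back to the native CPU RNN path.
    os.environ['MXNET_USE_MKLDNN_RNN'] = '1'

    loss = mx.gluon.loss.L2Loss()
    check_rnn_unidir_layer_gradients('lstm', input_size=128, hidden_size=128, loss=loss)
    check_rnn_bidir_layer_gradients('gru', input_size=128, hidden_size=256, loss=loss)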