diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md index bc98c39d9570..57ab27630a8f 100644 --- a/docs/static_site/src/pages/api/faq/env_var.md +++ b/docs/static_site/src/pages/api/faq/env_var.md @@ -289,11 +289,11 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`. If no such algorithm exists given other constraints, MXNet will error out. This variable affects the choice of CUDNN convolution algorithms. Please see [CUDNN developer guide](https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html) for more details. -* MXNET_CPU_PARALLEL_COPY_SIZE +* MXNET_CPU_PARALLEL_SIZE - Values: Int ```(default=200000)``` - - The minimum size to call parallel copy by OpenMP in CPU2CPU mode. - - When the array size is bigger than or equal to this threshold, NDArray::Copy(from, to) is implemented by OpenMP with the Recommended OMP Thread Count. - - When the array size is less than this threshold, NDArray::Copy(from , to)) is implemented by memcpy in single thread. + - The minimum size to call parallel operations by OpenMP for CPU context. + - When the array size is bigger than or equal to this threshold, the operation implemented by OpenMP is executed with the Recommended OMP Thread Count. + - When the array size is less than this threshold, the operation is implemented naively in single thread. * MXNET_OPTIMIZER_AGGREGATION_SIZE - Values: Int ```(default=4)``` @@ -349,6 +349,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`. - Values: 0(false) or 1(true) ```(default=1)``` - If this variable is set, MXNet will simplify the computation graph, eliminating duplicated operations on the same inputs. +* MXNET_USE_MKLDNN_RNN + - Values: 0(false) or 1(true) ```(default=1)``` + - This variable controls whether to use the MKL-DNN backend in fused RNN operator for CPU context. There are two fusion implementations of RNN operator in MXNet. The MKL-DNN implementation has a better performance than the naive one, but the latter is more stable in the backward operation currently. + Settings for Minimum Memory Usage --------------------------------- - Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1``` diff --git a/src/common/utils.h b/src/common/utils.h index 9a9c686e73c9..2187ad053b66 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -769,7 +769,7 @@ inline void EmplaceBackZeros(const NDArrayStorageType stype, const mxnet::TShape */ template inline void ParallelCopy(DType* dst, const DType* src, index_t size) { - static index_t copy_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_COPY_SIZE", 200000); + static index_t copy_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_SIZE", 200000); if (size >= copy_block_size) { #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (index_t i = 0; i < size; ++i) { @@ -780,6 +780,24 @@ inline void ParallelCopy(DType* dst, const DType* src, index_t size) { } } +/*! + * \breif parallelize add by OpenMP + */ +template +inline void ParallelAdd(DType* dst, const DType* src, index_t size) { + static index_t add_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_SIZE", 200000); + if (size >= add_block_size) { + #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) + for (index_t i = 0; i < size; ++i) { + dst[i] += src[i]; + } + } else { + for (index_t i = 0; i < size; ++i) { + dst[i] += src[i]; + } + } +} + /*! * \brief If numpy compatibility is turned off (default), the shapes passed in * by users follow the legacy shape definition: diff --git a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h index ad3f7332a8f4..314106b98eb9 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h @@ -120,25 +120,24 @@ class RnnPrimitive { template static RnnPrimitive Create(Args&&... args) { RnnPrimitive rnn_fwd_prim; - rnn_fwd_prim.pd_.reset( - new typename rnn_fwd::desc(std::forward(args)...), - [](typename rnn_fwd::desc* pd) { - delete reinterpret_cast(pd); + auto fwd_desc = typename rnn_fwd::desc(std::forward(args)...); + rnn_fwd_prim.fwd_pd_.reset( + new typename rnn_fwd::primitive_desc(fwd_desc, CpuEngine::Get()->get_engine()), + [](typename rnn_fwd::primitive_desc* pd) { + delete reinterpret_cast(pd); }); - const typename rnn_fwd::desc& fwd_desc = - *(reinterpret_cast(rnn_fwd_prim.pd_.get())); - typename rnn_fwd::primitive_desc fwd_pd(fwd_desc, CpuEngine::Get()->get_engine()); - rnn_fwd_prim.weights_layer_desc_ = fwd_pd.weights_layer_desc(); - rnn_fwd_prim.weights_iter_desc_ = fwd_pd.weights_iter_desc(); - rnn_fwd_prim.workspace_desc_ = fwd_pd.workspace_desc(); + auto fwd_pd = reinterpret_cast(rnn_fwd_prim.fwd_pd_.get()); + rnn_fwd_prim.weights_layer_desc_ = fwd_pd->weights_layer_desc(); + rnn_fwd_prim.weights_iter_desc_ = fwd_pd->weights_iter_desc(); + rnn_fwd_prim.workspace_desc_ = fwd_pd->workspace_desc(); - rnn_fwd_prim.primitive_ = std::shared_ptr(new rnn_fwd(fwd_pd)); + rnn_fwd_prim.primitive_ = std::shared_ptr(new rnn_fwd(*fwd_pd)); return rnn_fwd_prim; } RnnPrimitive() { - this->pd_ = nullptr; + this->fwd_pd_ = nullptr; this->primitive_ = nullptr; this->weights_layer_desc_ = mkldnn::memory::desc(); this->weights_iter_desc_ = mkldnn::memory::desc(); @@ -146,7 +145,7 @@ class RnnPrimitive { } RnnPrimitive(const RnnPrimitive& rnn_fwd_prim) { - this->pd_ = rnn_fwd_prim.pd_; + this->fwd_pd_ = rnn_fwd_prim.fwd_pd_; this->primitive_ = rnn_fwd_prim.primitive_; this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_; this->weights_iter_desc_ = rnn_fwd_prim.weights_iter_desc_; @@ -155,7 +154,7 @@ class RnnPrimitive { RnnPrimitive& operator=(const RnnPrimitive& rnn_fwd_prim) { if (this != &rnn_fwd_prim) { - this->pd_ = rnn_fwd_prim.pd_; + this->fwd_pd_ = rnn_fwd_prim.fwd_pd_; this->primitive_ = rnn_fwd_prim.primitive_; this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_; this->weights_iter_desc_ = rnn_fwd_prim.weights_iter_desc_; @@ -165,7 +164,7 @@ class RnnPrimitive { return *this; } - const void* GetPrimDesc() const { return pd_.get(); } + const void* GetPrimDesc() const { return fwd_pd_.get(); } const mkldnn::primitive& GetPrim() const { return *primitive_; } const mkldnn::memory::desc& GetLayerDesc() const { @@ -181,7 +180,7 @@ class RnnPrimitive { } private: - std::shared_ptr pd_; + std::shared_ptr fwd_pd_; std::shared_ptr primitive_; mkldnn::memory::desc weights_layer_desc_; mkldnn::memory::desc weights_iter_desc_; @@ -370,7 +369,9 @@ class MKLDNNRnnBackward { void SetDataGradsMem(void* diff_src, void* diff_state, void* diff_statecell, void* diff_out, void* diff_state_out, void* diff_statecell_out, const int dtype = mshadow::kFloat32); - void CommitWeightsDiff(void* diff_weights, void* diff_bias, const int dtype = mshadow::kFloat32); + void CommitWeightsDiff(void* diff_weights, void* diff_bias, + const OpReqType req, + const int dtype = mshadow::kFloat32); const mkldnn::primitive& GetBwd() const { return *bwd_.primitive_; } const mkldnn_args_map_t& GetArgsMap() const { return net_args_; } diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc b/src/operator/nn/mkldnn/mkldnn_rnn.cc index e797b649d295..6da8f3b8a58a 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn.cc +++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc @@ -213,13 +213,13 @@ RnnBwdPrimitive GetRnnBwdPrim(const MKLDNNRnnForwardTraining &fwd, auto dst_state_desc = layer_param.state_outputs ? memory::desc( layer_param.state_dims, data_type, tag::ldnc) : memory::desc(); - const void* fwd_desc = fwd.GetPrimDesc(); + const void* fwd_pd = fwd.GetPrimDesc(); auto bwd = RnnBwdPrimitive(); switch (mode) { case rnn_enum::kLstm: { - const lstm_forward::primitive_desc* desc = - reinterpret_cast(fwd_desc); - bwd = RnnBwdPrimitive::Create(*desc, + const lstm_forward::primitive_desc* pd = + reinterpret_cast(fwd_pd); + bwd = RnnBwdPrimitive::Create(*pd, prop, mkldnn_rnn_direction, // data desc src_layer_desc, src_state_desc, src_state_desc, weight_layer_desc, @@ -231,9 +231,9 @@ RnnBwdPrimitive GetRnnBwdPrim(const MKLDNNRnnForwardTraining &fwd, dst_state_desc); } break; case rnn_enum::kGru: { - const lbr_gru_forward::primitive_desc* desc = - reinterpret_cast(fwd_desc); - bwd = RnnBwdPrimitive::Create(*desc, + const lbr_gru_forward::primitive_desc* pd = + reinterpret_cast(fwd_pd); + bwd = RnnBwdPrimitive::Create(*pd, prop, mkldnn_rnn_direction, // data desc src_layer_desc, src_state_desc, weight_layer_desc, @@ -244,10 +244,10 @@ RnnBwdPrimitive GetRnnBwdPrim(const MKLDNNRnnForwardTraining &fwd, } break; case rnn_enum::kRnnRelu: case rnn_enum::kRnnTanh: { - const vanilla_rnn_forward::primitive_desc* desc = - reinterpret_cast(fwd_desc); + const vanilla_rnn_forward::primitive_desc* pd = + reinterpret_cast(fwd_pd); bwd = RnnBwdPrimitive::Create( - *desc, prop, + *pd, prop, mode == rnn_enum::kRnnTanh ? algorithm::eltwise_tanh : algorithm::eltwise_relu, mkldnn_rnn_direction, // data desc @@ -776,16 +776,8 @@ void MKLDNNRnnBackward::SetDataGradsMem( } } -template -void HalveWeightsDiff(DType* w, const size_t size) { - const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); - #pragma omp parallel for num_threads(omp_threads) - for (int i = 0; i < static_cast(size); ++i) { - w[i] *= 0.5; - } -} - -void MKLDNNRnnBackward::CommitWeightsDiff(void* diff_weights, void* diff_bias, int dtype) { +void MKLDNNRnnBackward::CommitWeightsDiff(void* diff_weights, void* diff_bias, + const OpReqType req, const int dtype) { using tag = mkldnn::memory::format_tag; auto& cpu_engine = CpuEngine::Get()->get_engine(); auto s = mkldnn::stream(cpu_engine); @@ -795,11 +787,12 @@ void MKLDNNRnnBackward::CommitWeightsDiff(void* diff_weights, void* diff_bias, i const int direction = param.bidirectional ? 2 : 1; const int ngates = GetRnnGatesNum(param.mode); const size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); + const size_t wxh_size = param.single_w_size; + const size_t wx_size = param.input_size * param.state_size * ngates; + const size_t wh_size = param.state_size * param.state_size * ngates; const size_t wxh_bytes = param.single_w_size * dtype_bytes; const size_t wx_bytes = param.input_size * param.state_size * ngates * dtype_bytes; const size_t wh_bytes = param.state_size * param.state_size * ngates * dtype_bytes; - char* diff_wx_ptr = static_cast(diff_weights_layer_->get_data_handle()); - char* diff_wh_ptr = static_cast(diff_weights_iter_->get_data_handle()); /* naive weights layout is: 1st-layer: | wx_lr | wh_lr | wx_rl | wh_rl | @@ -807,68 +800,109 @@ void MKLDNNRnnBackward::CommitWeightsDiff(void* diff_weights, void* diff_bias, i size: | wxh_bytes | |wx_bytes|wh_bytes| */ - char* naive_weights = static_cast(diff_weights); - if (param.mode != rnn_enum::kGru) { - for (int shift = 0; shift < num_layer * direction; ++shift) { - std::memcpy(naive_weights + shift * wxh_bytes, - diff_wx_ptr + shift * wx_bytes, wx_bytes); - } - // align naive_weights to weights_iter memory - naive_weights += wx_bytes; - for (int shift = 0; shift < num_layer * direction; ++shift) { - std::memcpy(naive_weights + shift * wxh_bytes, - diff_wh_ptr + shift * wh_bytes, wh_bytes); - } - } else { - const size_t wx_bytes_per_gate = param.input_size * param.state_size * dtype_bytes; - const size_t wh_bytes_per_gate = param.state_size * param.state_size * dtype_bytes; - for (int shift = 0; shift < num_layer * direction; ++shift) { - std::memcpy(naive_weights + shift * wxh_bytes + wx_bytes_per_gate, - diff_wx_ptr + shift * wx_bytes, wx_bytes_per_gate); - std::memcpy(naive_weights + shift * wxh_bytes, - diff_wx_ptr + shift * wx_bytes + wx_bytes_per_gate, wx_bytes_per_gate); - std::memcpy(naive_weights + shift * wxh_bytes + 2 * wx_bytes_per_gate, - diff_wx_ptr + shift * wx_bytes + 2 * wx_bytes_per_gate, wx_bytes_per_gate); + if (kWriteTo == req) { + char* naive_weights = static_cast(diff_weights); + char* diff_wx_ptr = static_cast(diff_weights_layer_->get_data_handle()); + char* diff_wh_ptr = static_cast(diff_weights_iter_->get_data_handle()); + if (param.mode != rnn_enum::kGru) { + for (int shift = 0; shift < num_layer * direction; ++shift) { + std::memcpy(naive_weights + shift * wxh_bytes, + diff_wx_ptr + shift * wx_bytes, wx_bytes); + } + // align naive_weights to weights_iter memory + naive_weights += wx_bytes; + for (int shift = 0; shift < num_layer * direction; ++shift) { + std::memcpy(naive_weights + shift * wxh_bytes, + diff_wh_ptr + shift * wh_bytes, wh_bytes); + } + } else { + const size_t wx_bytes_per_gate = param.input_size * param.state_size * dtype_bytes; + const size_t wh_bytes_per_gate = param.state_size * param.state_size * dtype_bytes; + for (int shift = 0; shift < num_layer * direction; ++shift) { + std::memcpy(naive_weights + shift * wxh_bytes + wx_bytes_per_gate, + diff_wx_ptr + shift * wx_bytes, wx_bytes_per_gate); + std::memcpy(naive_weights + shift * wxh_bytes, + diff_wx_ptr + shift * wx_bytes + wx_bytes_per_gate, wx_bytes_per_gate); + std::memcpy(naive_weights + shift * wxh_bytes + 2 * wx_bytes_per_gate, + diff_wx_ptr + shift * wx_bytes + 2 * wx_bytes_per_gate, wx_bytes_per_gate); + } + // align naive_weights to weights_iter memory + naive_weights += wx_bytes; + for (int shift = 0; shift < num_layer * direction; ++shift) { + std::memcpy(naive_weights + shift * wxh_bytes + wh_bytes_per_gate, + diff_wh_ptr + shift * wh_bytes, wh_bytes_per_gate); + std::memcpy(naive_weights + shift * wxh_bytes, + diff_wh_ptr + shift * wh_bytes + wh_bytes_per_gate, wh_bytes_per_gate); + std::memcpy(naive_weights + shift * wxh_bytes + 2 * wh_bytes_per_gate, + diff_wh_ptr + shift * wh_bytes + 2 * wh_bytes_per_gate, wh_bytes_per_gate); + } } - // align naive_weights to weights_iter memory - naive_weights += wx_bytes; - for (int shift = 0; shift < num_layer * direction; ++shift) { - std::memcpy(naive_weights + shift * wxh_bytes + wh_bytes_per_gate, - diff_wh_ptr + shift * wh_bytes, wh_bytes_per_gate); - std::memcpy(naive_weights + shift * wxh_bytes, - diff_wh_ptr + shift * wh_bytes + wh_bytes_per_gate, wh_bytes_per_gate); - std::memcpy(naive_weights + shift * wxh_bytes + 2 * wh_bytes_per_gate, - diff_wh_ptr + shift * wh_bytes + 2 * wh_bytes_per_gate, wh_bytes_per_gate); + } else if (kAddTo == req) { + if (param.mode != rnn_enum::kGru) { + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + DType* naive_weights = static_cast(diff_weights); + DType* diff_wx_ptr = static_cast(diff_weights_layer_->get_data_handle()); + DType* diff_wh_ptr = static_cast(diff_weights_iter_->get_data_handle()); + for (int shift = 0; shift < num_layer * direction; ++shift) { + common::ParallelAdd(naive_weights + shift * wxh_size, + diff_wx_ptr + shift * wx_size, wx_size); + } + // align naive_weights to weights_iter memory + naive_weights += wx_size; + for (int shift = 0; shift < num_layer * direction; ++shift) { + common::ParallelAdd(naive_weights + shift * wxh_size, + diff_wh_ptr + shift * wh_size, wh_size); + } + }); + } else { + const size_t wx_size_per_gate = param.input_size * param.state_size; + const size_t wh_size_per_gate = param.state_size * param.state_size; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + DType* naive_weights = static_cast(diff_weights); + DType* diff_wx_ptr = static_cast(diff_weights_layer_->get_data_handle()); + DType* diff_wh_ptr = static_cast(diff_weights_iter_->get_data_handle()); + for (int shift = 0; shift < num_layer * direction; ++shift) { + common::ParallelAdd(naive_weights + shift * wxh_size + wx_size_per_gate, + diff_wx_ptr + shift * wx_size, wx_size_per_gate); + common::ParallelAdd(naive_weights + shift * wxh_size, + diff_wx_ptr + shift * wx_size + wx_size_per_gate, wx_size_per_gate); + common::ParallelAdd(naive_weights + shift * wxh_size + 2 * wx_size_per_gate, + diff_wx_ptr + shift * wx_size + 2 * wx_size_per_gate, wx_size_per_gate); + } + // align naive_weights to weights_iter memory + naive_weights += wx_size; + for (int shift = 0; shift < num_layer * direction; ++shift) { + common::ParallelAdd(naive_weights + shift * wxh_size + wh_size_per_gate, + diff_wh_ptr + shift * wh_size, wh_size_per_gate); + common::ParallelAdd(naive_weights + shift * wxh_size, + diff_wh_ptr + shift * wh_size + wh_size_per_gate, wh_size_per_gate); + common::ParallelAdd(naive_weights + shift * wxh_size + 2 * wh_size_per_gate, + diff_wh_ptr + shift * wh_size + 2 * wh_size_per_gate, wh_size_per_gate); + } + }); } } - char* naive_bias = static_cast(diff_bias); - char* diff_bias_ptr = static_cast(this->diff_bias_->get_data_handle()); - const size_t bias_bytes = param.single_b_size * dtype_bytes; - const size_t naive_bias_bytes = param.naive_single_b_size * dtype_bytes; - if (param.mode != rnn_enum::kGru) { - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - DType* typed_bias = reinterpret_cast(diff_bias_ptr); - HalveWeightsDiff(typed_bias, num_layer * direction * param.single_b_size); - }); - for (int shift = 0; shift < num_layer * direction; ++shift) { - std::memcpy(naive_bias + shift * naive_bias_bytes, - diff_bias_ptr + shift * bias_bytes, bias_bytes); - std::memcpy(naive_bias + shift * naive_bias_bytes + bias_bytes, - diff_bias_ptr + shift * bias_bytes, bias_bytes); - } - } else { - const size_t bias_bytes_per_gate = param.state_size * dtype_bytes; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + if (kWriteTo == req) { + const size_t bias_bytes = param.single_b_size * dtype_bytes; + const size_t naive_bias_bytes = param.naive_single_b_size * dtype_bytes; + char* naive_bias = static_cast(diff_bias); + char* diff_bias_ptr = static_cast(this->diff_bias_->get_data_handle()); + if (param.mode != rnn_enum::kGru) { + for (int shift = 0; shift < num_layer * direction; ++shift) { + std::memcpy(naive_bias + shift * naive_bias_bytes, + diff_bias_ptr + shift * bias_bytes, bias_bytes); + std::memcpy(naive_bias + shift * naive_bias_bytes + bias_bytes, + diff_bias_ptr + shift * bias_bytes, bias_bytes); + } + } else { + const size_t bias_bytes_per_gate = param.state_size * dtype_bytes; for (int shift = 0; shift < num_layer * direction; ++shift) { char* naive_reset = naive_bias + shift * naive_bias_bytes; char* naive_update = naive_reset + bias_bytes_per_gate; char* update = diff_bias_ptr + shift * bias_bytes; char* reset = update + bias_bytes_per_gate; - DType* typed_update = reinterpret_cast(update); - HalveWeightsDiff(typed_update, param.state_size * 2); - std::memcpy(naive_update, update, bias_bytes_per_gate); std::memcpy(naive_reset, reset, bias_bytes_per_gate); std::memcpy(naive_update + naive_bias_bytes / 2, update, bias_bytes_per_gate); @@ -881,7 +915,46 @@ void MKLDNNRnnBackward::CommitWeightsDiff(void* diff_weights, void* diff_bias, i std::memcpy(naive_new_bx, new_bx, bias_bytes_per_gate); std::memcpy(naive_new_bh, new_bh, bias_bytes_per_gate); } - }); + } + } else if (kAddTo == req) { + const size_t bias_size = param.single_b_size; + const size_t naive_bias_size = param.naive_single_b_size; + if (param.mode != rnn_enum::kGru) { + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + DType* naive_bias = static_cast(diff_bias); + DType* diff_bias_ptr = static_cast(this->diff_bias_->get_data_handle()); + for (int shift = 0; shift < num_layer * direction; ++shift) { + common::ParallelAdd(naive_bias + shift * naive_bias_size, + diff_bias_ptr + shift * bias_size, bias_size); + common::ParallelAdd(naive_bias + shift * naive_bias_size + bias_size, + diff_bias_ptr + shift * bias_size, bias_size); + } + }); + } else { + const size_t bias_size_per_gate = param.state_size; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + DType* naive_bias = static_cast(diff_bias); + DType* diff_bias_ptr = static_cast(this->diff_bias_->get_data_handle()); + for (int shift = 0; shift < num_layer * direction; ++shift) { + DType* naive_reset = naive_bias + shift * naive_bias_size; + DType* naive_update = naive_reset + bias_size_per_gate; + DType* update = diff_bias_ptr + shift * bias_size; + DType* reset = update + bias_size_per_gate; + + common::ParallelAdd(naive_update, update, bias_size_per_gate); + common::ParallelAdd(naive_reset, reset, bias_size_per_gate); + common::ParallelAdd(naive_update + naive_bias_size / 2, update, bias_size_per_gate); + common::ParallelAdd(naive_reset + naive_bias_size / 2, reset, bias_size_per_gate); + + DType* naive_new_bx = naive_update + bias_size_per_gate; + DType* naive_new_bh = naive_new_bx + naive_bias_size / 2; + DType* new_bx = reset + bias_size_per_gate; + DType* new_bh = new_bx + bias_size_per_gate; + common::ParallelAdd(naive_new_bx, new_bx, bias_size_per_gate); + common::ParallelAdd(naive_new_bh, new_bh, bias_size_per_gate); + } + }); + } } } @@ -899,19 +972,11 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { + TmpMemMgr::Get()->Init(ctx.requested[0]); // In the `autograd.record()` context, RNNOp is required to run into // forward_training mode. const bool is_training = (ctx.is_train || ctx.need_grad); - // check output requests - if (kAddTo == req[rnn_enum::kOut]) - LOG(FATAL) << "Currently, `add` operation is not supported by RNNs."; const RNNParam& default_param = full_param_.default_param; - if (default_param.state_outputs) { - if (kAddTo == req[rnn_enum::kStateOut]) - LOG(FATAL) << "Currently, `add` operation is not supported by RNNs."; - if (default_param.mode == rnn_enum::kLstm && kAddTo == req[rnn_enum::kStateCellOut]) - LOG(FATAL) << "Currently, `add` operation against lstm-cell output is not supported."; - } // Initialize weights version if (!initialized_ && weights_version_ == 0) { @@ -932,24 +997,40 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx, // Get data type int data_dtype = inputs[rnn_enum::kData].dtype(); + // Get temporary memory for output, state_out, statecell_out + const int num_layers = default_param.num_layers; + const int seq_length = default_param.seq_length_; + const int batch_size = default_param.batch_size_; + const int state_size = default_param.state_size; + const int directions = default_param.bidirectional ? 2 : 1; + mkldnn::memory::desc dst_desc({seq_length, batch_size, directions * state_size}, + get_mkldnn_type(data_dtype), mkldnn::memory::format_tag::tnc); + mkldnn::memory::desc state_desc({num_layers, directions, batch_size, state_size}, + get_mkldnn_type(data_dtype), mkldnn::memory::format_tag::ldnc); + auto out_mem = CreateMKLDNNMem(outputs[rnn_enum::kOut], dst_desc, req[rnn_enum::kOut]); + mkldnn_output_t stateout_mem; + mkldnn_output_t statecellout_mem; // Get input & output NDArray char *src = static_cast(inputs[rnn_enum::kData].data().dptr_); char *src_state = static_cast(inputs[rnn_enum::kState].data().dptr_); - char *dst = req[rnn_enum::kOut] == kNullOp ? nullptr - : static_cast(outputs[rnn_enum::kOut].data().dptr_); + char *dst = static_cast(out_mem.second->get_data_handle()); char *dst_state = nullptr; // Output state char *src_state_cell = nullptr; // Used in LSTM for cell state char *dst_state_cell = nullptr; // Used in LSTM for cell state if (default_param.state_outputs && req[rnn_enum::kStateOut] != kNullOp) { - dst_state = static_cast(outputs[rnn_enum::kStateOut].data().dptr_); + stateout_mem = CreateMKLDNNMem( + outputs[rnn_enum::kStateOut], state_desc, req[rnn_enum::kStateOut]); + dst_state = static_cast(stateout_mem.second->get_data_handle()); } if (default_param.mode == rnn_enum::kLstm) { src_state_cell = static_cast(inputs[rnn_enum::kStateCell].data().dptr_); if (default_param.state_outputs && req[rnn_enum::kStateCellOut] != kNullOp) { - dst_state_cell = static_cast(outputs[rnn_enum::kStateCellOut].data().dptr_); + statecellout_mem = CreateMKLDNNMem( + outputs[rnn_enum::kStateCellOut], state_desc, req[rnn_enum::kStateCellOut]); + dst_state_cell = static_cast(statecellout_mem.second->get_data_handle()); } } @@ -1000,6 +1081,12 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx, } else { for (auto& inf_lyr : fwd_inf_vec_) RegisterMKLDNNRnn(inf_lyr); } + CommitOutput(outputs[rnn_enum::kOut], out_mem); + if (default_param.state_outputs) { + CommitOutput(outputs[rnn_enum::kStateOut], stateout_mem); + if (default_param.mode == rnn_enum::kLstm) + CommitOutput(outputs[rnn_enum::kStateCellOut], statecellout_mem); + } MKLDNNStream::Get()->Submit(); } @@ -1008,18 +1095,9 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx, const std::vector& req, const std::vector& outputs) { using tag = mkldnn::memory::format_tag; + TmpMemMgr::Get()->Init(ctx.requested[0]); const RNNParam& default_param = full_param_.default_param; - if (kAddTo == req[rnn_enum::kData] || kAddTo == req[rnn_enum::kParams]) - LOG(FATAL) << "Currently, `add` operations against gradients of input and weights" - << " are not supported by RNNs."; - if (default_param.state_outputs) { - if (kAddTo == req[rnn_enum::kStateOut]) - LOG(FATAL) << "Currently, `add` operation against gradients of begining state" - << " is not supported by RNNs."; - if (default_param.mode == rnn_enum::kLstm && req[rnn_enum::kStateCell]) - LOG(FATAL) << "Currently, `add` operation against gradients of begining cell-state" - << " is not supported by LSTM."; - } + // Initialize the bwd_vec_ if (bwd_vec_.size() != fwd_inf_vec_.size()) { bwd_vec_.clear(); @@ -1038,21 +1116,38 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx, const int data_dtype = inputs[rnn_enum::kData].dtype(); const int w_dtype = inputs[rnn_enum::kParams].dtype(); const size_t w_bytes = mshadow::mshadow_sizeof(w_dtype); + // Get temporary memory for diff_src, diff_state, diff_statecell + const int num_layers = default_param.num_layers; + const int seq_length = default_param.seq_length_; + const int batch_size = default_param.batch_size_; + const int input_size = default_param.input_size_; + const int state_size = default_param.state_size; + const int directions = default_param.bidirectional ? 2 : 1; + mkldnn::memory::desc src_desc({seq_length, batch_size, input_size}, + get_mkldnn_type(data_dtype), tag::tnc); + mkldnn::memory::desc state_desc({num_layers, directions, batch_size, state_size}, + get_mkldnn_type(data_dtype), tag::ldnc); + auto diff_input_mem = CreateMKLDNNMem(outputs[rnn_enum::kData], src_desc, req[rnn_enum::kData]); + mkldnn_output_t diff_state_mem; + mkldnn_output_t diff_statecell_mem; // index description of outputs NDArray // 0 1 2 3 // | dx | dw | dhx | dcx| - char* dx = req[rnn_enum::kData] == kNullOp ? nullptr - : static_cast(outputs[rnn_enum::kData].data().dptr_); + char* dx = static_cast(diff_input_mem.second->get_data_handle()); char* dw = static_cast(outputs[rnn_enum::kParams].data().dptr_); char* db = dw + (inputs[rnn_enum::kParams].data().Size() - GetRnnBiasSize(default_param.num_layers, default_param.state_size, default_param.bidirectional + 1, default_param.mode)) * w_bytes; - char* dhx = req[rnn_enum::kState] == kNullOp ? nullptr - : static_cast(outputs[rnn_enum::kState].data().dptr_); + diff_state_mem = CreateMKLDNNMem( + outputs[rnn_enum::kState], state_desc, req[rnn_enum::kState]); + char* dhx = static_cast(diff_state_mem.second->get_data_handle()); char* dcx = nullptr; if (full_param_.default_param.mode == rnn_enum::kLstm - && req[rnn_enum::kStateCell] != kNullOp) - dcx = static_cast(outputs[rnn_enum::kStateCell].data().dptr_); + && req[rnn_enum::kStateCell] != kNullOp) { + diff_statecell_mem = CreateMKLDNNMem( + outputs[rnn_enum::kStateCell], state_desc, req[rnn_enum::kStateCell]); + dcx = static_cast(diff_statecell_mem.second->get_data_handle()); + } // index description of inputs NDArray // 0 1 2 3 4 5 6 7 8 9 @@ -1100,12 +1195,16 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx, RegisterMKLDNNRnn(*bwd); } } + CommitOutput(outputs[rnn_enum::kData], diff_input_mem); + CommitOutput(outputs[rnn_enum::kState], diff_state_mem); + if (full_param_.default_param.mode == rnn_enum::kLstm) + CommitOutput(outputs[rnn_enum::kStateCell], diff_statecell_mem); MKLDNNStream::Get()->Submit(); // Commit weights diff if (req[rnn_enum::kParams] != kNullOp) { for (size_t lyr = 0; lyr < bwd_vec_.size(); ++lyr) { - bwd_vec_.at(lyr).CommitWeightsDiff(dw, db, w_dtype); + bwd_vec_.at(lyr).CommitWeightsDiff(dw, db, req[rnn_enum::kParams], w_dtype); dw += full_param_.layer_params.at(lyr).single_w_size * w_bytes; db += full_param_.layer_params.at(lyr).single_b_size * w_bytes; } diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 6d568c81bc1c..542968ef0a2c 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -181,6 +181,10 @@ static std::vector RNNResourceEx(const NodeAttrs& attrs, const if (param.p != 0 && 1.0f - param.p > 0) { request.emplace_back(ResourceRequest::kCuDNNDropoutDesc); } +#endif + } else { +#if MXNET_USE_MKLDNN == 1 + request.emplace_back(ResourceRequest::kTempSpace); #endif } return request; @@ -243,7 +247,8 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs, #if MXNET_USE_MKLDNN == 1 if ((in_types[0] == mshadow::kFloat32 || in_types[0] == mshadow::kFloat16) - && in_shapes[0].ndim() == 3 && ctx.dev_type == kCPU) { + && in_shapes[0].ndim() == 3 && ctx.dev_type == kCPU + && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) { const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData]; state = OpStatePtr::Create(param, data_shape[0], data_shape[1], data_shape[2]); @@ -270,7 +275,7 @@ static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr, const std::vector& req, const std::vector& outputs) { if ((inputs[0].dtype() == mshadow::kFloat32 || inputs[0].dtype() == mshadow::kFloat16) && - inputs[0].shape().ndim() == 3) { + inputs[0].shape().ndim() == 3 && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) { MKLDNNRnnOp& op = state_ptr.get_state(); op.Forward(ctx, inputs, req, outputs); } else { @@ -284,7 +289,7 @@ static void RNNStatefulGradComputeExCPU(const OpStatePtr& state_ptr, const std::vector& req, const std::vector& outputs) { if ((inputs[0].dtype() == mshadow::kFloat32 || inputs[0].dtype() == mshadow::kFloat16) && - inputs[0].shape().ndim() == 3) { + inputs[0].shape().ndim() == 3 && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) { MKLDNNRnnOp& op = state_ptr.get_state(); op.Backward(ctx, inputs, req, outputs); } else { diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 7b0404d8abb7..66031d20d65b 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -36,15 +36,6 @@ import os def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol=1e-2, atol=1e-4): - if default_context().device_type == 'cpu': - # NOTE(zixuanweeei): Currently, we don't add `add` requests support on fused mkl-dnn rnn operator. - # We tracked this issue by https://github.com/apache/incubator-mxnet/issues/16578 - if isinstance(grad_req, dict) and 'add' in grad_req.values(): - print("Skip the test when requiring `add` operation against gradients on CPU context.") - return - if isinstance(grad_req, str) and grad_req == 'add': - print("Skip the test when requiring `add` operation against gradients on CPU context.") - return dshape = (N, T, I) data = mx.sym.Variable('data') @@ -182,9 +173,9 @@ def test_gru_sym(): stack.add(mx.rnn.GRUCell(H, prefix='l1_')) stack.add(mx.rnn.GRUCell(H, prefix='l2_')) - check_rnn_consistency(fused, stack, T, N, I, H, 'write', atol=2e-4) - check_rnn_consistency(fused, stack, T, N, I, H, 'add', atol=2e-4) - check_rnn_consistency(fused, stack, T, N, I, H, 'null', atol=2e-4) + check_rnn_consistency(fused, stack, T, N, I, H, 'write') + check_rnn_consistency(fused, stack, T, N, I, H, 'add') + check_rnn_consistency(fused, stack, T, N, I, H, 'null') @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') @@ -208,9 +199,9 @@ def test_gru_bidirectional(): mx.rnn.GRUCell(H, prefix='r1_'), output_prefix='bi_gru_1_')) - check_rnn_consistency(fused, stack, T, N, I, H, 'write', atol=2e-4) - check_rnn_consistency(fused, stack, T, N, I, H, 'add', atol=2e-4) - check_rnn_consistency(fused, stack, T, N, I, H, 'null', atol=2e-4) + check_rnn_consistency(fused, stack, T, N, I, H, 'write') + check_rnn_consistency(fused, stack, T, N, I, H, 'add') + check_rnn_consistency(fused, stack, T, N, I, H, 'null') @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10')