diff --git a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h new file mode 100644 index 000000000000..7557723499a9 --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h @@ -0,0 +1,444 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file mkldnn_rnn-inl.h + * \brief Common functions used by MKLDNN RNN operator + * \author Zixuan Wei +*/ + +#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_INL_H_ +#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_INL_H_ + +#if MXNET_USE_MKLDNN == 100 + +#include +#include "../../rnn-inl.h" +#include "./mkldnn_base-inl.h" + +namespace mxnet { +namespace op { + +struct MKLDNNRnnLayerParam { + using memory = mkldnn::memory; + using dims = mkldnn::memory::dims; + + int mode; + bool bidirectional; + bool state_outputs; + int num_layer; + int batch_size; + int input_size; + int state_size; + int seq_len; + + dims src_dims; // Dimensions of source input in format_tag::tnc + dims weight_layer_dims; // Dimensions of layer weights in format_tag::ldigo + dims weight_iter_dims; // Dimensions of iter weights in format_tag::ldigo + dims bias_dims; // Dimensions of bias in format_tag::ldgo + dims dst_dims; // Dimensions of output in format_tag::tnc + dims state_dims; // Dimensions of the state cell in format_tag::ldnc + + size_t workspace_size; // used for the cached mkl-dnn memory in Forward inference + size_t reserve_size; // used for the reserved cached memory in Backward + size_t single_w_size; // weights size of a single cell + size_t single_b_size; // bias size of a single cell from mkl-dnn + size_t naive_single_b_size; // bias size of a single cell from framework + size_t single_state_size; // state size of a single cell, hy, cy + + MKLDNNRnnLayerParam(int num_layer, int batch_size, int seq_len, + int input_size, int state_size, + int mode, bool bidirectional = true) + : mode(mode), bidirectional(bidirectional), state_outputs(true), + num_layer(num_layer), batch_size(batch_size), input_size(input_size), + state_size(state_size), seq_len(seq_len) { } + + void SetDims(); +}; + +typedef std::vector LayerParamVector; +struct MKLDNNRnnFullParam { + RNNParam default_param; + LayerParamVector layer_params; +}; + +MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const RNNParam& rnn_param, const int seq_len, + const int batch_size, const int input_size); + +/* + * Use this to allocate memory from MKLDNNRnnOp temporary space. 
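+ * It pre-allocates a single aligned workspace NDArray and hands out
+ * mkldnn::memory chunks from it; when the remaining space runs out, Alloc
+ * falls back to a standalone allocation that is kept alive in mem_holder.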
+ */ +class MKLDNNRnnMemMgr { + // The memory buffer in NDArray life-cycle + NDArray workspace_; + // This points to the memory buffer from a NDArray + char *curr_mem; + // The total bytes of the workspace of a MKLDNNRnnOp + size_t mem_size = 0; + // The current available memory bytes + size_t curr_size = 0; + const size_t alignment = kMKLDNNAlign; + const mkldnn::engine& cpu_engine = CpuEngine::Get()->get_engine(); + // Here we hold all memory related to the stateful RNN operators + std::vector > mem_holder; + + public: + void Init(dim_t size, const Context& ctx, int dtype = mshadow::kFloat32); + + void RegisterMem(std::shared_ptr mem) { + mem_holder.push_back(mem); + } + + mkldnn::memory *Alloc(const mkldnn::memory::desc &md); +}; + +/* + * Rnn Primitive. + */ +class RnnPrimitive { + public: + /* Create a RnnPrimitive with rnn type: + * lstm_forward, lbr_gru_forward, vanilla_rnn_forward + */ + template + static RnnPrimitive Create(Args&&... args) { + RnnPrimitive rnn_fwd_prim; + rnn_fwd_prim.pd_.reset( + new typename rnn_fwd::desc(std::forward(args)...), + [](typename rnn_fwd::desc* pd) { + delete reinterpret_cast(pd); + }); + const typename rnn_fwd::desc& fwd_desc = + *(reinterpret_cast(rnn_fwd_prim.pd_.get())); + typename rnn_fwd::primitive_desc fwd_pd(fwd_desc, CpuEngine::Get()->get_engine()); + rnn_fwd_prim.weights_layer_desc_ = fwd_pd.weights_layer_desc(); + rnn_fwd_prim.weights_iter_desc_ = fwd_pd.weights_iter_desc(); + rnn_fwd_prim.workspace_desc_ = fwd_pd.workspace_desc(); + + rnn_fwd_prim.primitive_ = std::shared_ptr(new rnn_fwd(fwd_pd)); + + return rnn_fwd_prim; + } + + RnnPrimitive() { + this->pd_ = nullptr; + this->primitive_ = nullptr; + this->weights_layer_desc_ = mkldnn::memory::desc(); + this->weights_iter_desc_ = mkldnn::memory::desc(); + this->workspace_desc_ = mkldnn::memory::desc(); + } + + RnnPrimitive(const RnnPrimitive& rnn_fwd_prim) { + this->pd_ = rnn_fwd_prim.pd_; + this->primitive_ = rnn_fwd_prim.primitive_; + this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_; + this->weights_iter_desc_ = rnn_fwd_prim.weights_iter_desc_; + this->workspace_desc_ = rnn_fwd_prim.workspace_desc_; + } + + RnnPrimitive& operator=(const RnnPrimitive& rnn_fwd_prim) { + if (this != &rnn_fwd_prim) { + this->pd_ = rnn_fwd_prim.pd_; + this->primitive_ = rnn_fwd_prim.primitive_; + this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_; + this->weights_iter_desc_ = rnn_fwd_prim.weights_iter_desc_; + this->workspace_desc_ = rnn_fwd_prim.workspace_desc_; + } + + return *this; + } + + const void* GetPrimDesc() const { return pd_.get(); } + const mkldnn::primitive& GetPrim() const { return *primitive_; } + + const mkldnn::memory::desc& GetLayerDesc() const { + return weights_layer_desc_; + } + + const mkldnn::memory::desc& GetIterDesc() const { + return weights_iter_desc_; + } + + const mkldnn::memory::desc& GetWorkspaceDesc() const { + return workspace_desc_; + } + + private: + std::shared_ptr pd_; + std::shared_ptr primitive_; + mkldnn::memory::desc weights_layer_desc_; + mkldnn::memory::desc weights_iter_desc_; + mkldnn::memory::desc workspace_desc_; +}; + +RnnPrimitive GetRnnFwdPrim(const MKLDNNRnnLayerParam &layer_param, const bool is_train, + const NDArray &data, const NDArray ¶ms); + +/* + * Use this to manage memory and primitive of MKL-DNN RNN forward inference. 
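+ * It holds the reordered weights/bias memory of one (possibly fused) layer and
+ * the net_args_ map that is passed to the forward primitive at execution time.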
+ */ +class MKLDNNRnnForward { + public: + MKLDNNRnnForward(const MKLDNNRnnLayerParam &layer_param, const bool is_train, + const NDArray &data, const NDArray ¶ms) + : initialized_(false), param_(layer_param), + fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params)) { } + + void SetNewDataMem(void* x, void* hx, void* cx, + void* y, void* hy, void* cy, + const int dtype = mshadow::kFloat32); + void SetWeightsMem(MKLDNNRnnMemMgr* mgr, void* w_ptr, void* b_ptr, + const bool is_train = false, + const int dtype = mshadow::kFloat32); + void ReorderWeights(); + + const mkldnn::primitive& GetFwd() const { return fwd_inf_.GetPrim(); } + + const size_t GetSize(int dtype) const { + size_t bytes = mshadow::mshadow_sizeof(dtype); + size_t size = 0; + size += fwd_inf_.GetLayerDesc().get_size(); + size += fwd_inf_.GetIterDesc().get_size(); + return size / bytes + 1; + } + + const MKLDNNRnnLayerParam &GetParam() const { return param_; } + + const mkldnn_args_map_t &GetArgsMap() const { return net_args_; } + + const bool IsInitialized() const { return initialized_; } + void Reset() { initialized_ = false; } + + private: + bool initialized_; + MKLDNNRnnLayerParam param_; + RnnPrimitive fwd_inf_; // forward inference primitive + + mkldnn::memory *weights_layer_ = nullptr; + mkldnn::memory *weights_iter_ = nullptr; + mkldnn::memory *bias_ = nullptr; + + mkldnn::memory *weights_layer_r_ = nullptr; + mkldnn::memory *weights_iter_r_ = nullptr; + + /* + * net_args must contain some keys as below: + * MKLDNN_ARG_SRC + * MKLDNN_ARG_SRC_ITER + * MKLDNN_WEIGHTS_LAYER + * MKLDNN_WEIGHTS_ITER + * MKLDNN_BIAS + * MKLDNN_ARG_DST + * MKLDNN_ARG_DST_ITER + * if mode == Lstm, it also needs two additional key: + * MKLDNN_ARG_SRC_ITER_C + * MKLDNN_ARG_DST_ITER_C + */ + mkldnn_args_map_t net_args_; + + friend class MKLDNNRnnForwardTraining; +}; + +typedef std::shared_ptr mkldnn_shared_mem_t; +/* + * Use this to manage memory and primitive of MKL-DNN RNN forward training. + */ +class MKLDNNRnnForwardTraining { + public: + MKLDNNRnnForwardTraining(const MKLDNNRnnLayerParam &layer_param, const bool is_train, + const NDArray &data, const NDArray ¶ms) + : fwd_trn_(GetRnnFwdPrim(layer_param, is_train, data, params)), + param_(&layer_param) { } + + void SetTrnMem(const MKLDNNRnnForward& fwd); + void FetchData(const MKLDNNRnnForward& fwd); + + const MKLDNNRnnLayerParam& GetParam() const { return *param_; } + const void* GetPrimDesc() const { return fwd_trn_.GetPrimDesc(); } + const mkldnn::primitive& GetFwd() const { return fwd_trn_.GetPrim(); } + const mkldnn_args_map_t& GetArgsMap() const { return net_args_; } + + private: + RnnPrimitive fwd_trn_; + const MKLDNNRnnLayerParam* param_; + + mkldnn_shared_mem_t weights_layer_ = nullptr; + mkldnn_shared_mem_t weights_iter_ = nullptr; + mkldnn::memory* bias_ = nullptr; + + mkldnn_shared_mem_t workspace_ = nullptr; + + // Key MKLDNN_ARGS_WORKSPACE must be included in forward training + mkldnn_args_map_t net_args_; + + friend class MKLDNNRnnBackward; +}; + +/* + * Rnn Backward primitive + */ +class RnnBwdPrimitive { + public: + template + static RnnBwdPrimitive Create(typename rnn_fwd::primitive_desc const & fwd_pd, Args&&... 
args) { + RnnBwdPrimitive bwd_prim; + typename rnn_bwd::desc bwd_desc(std::forward(args)...); + typename rnn_bwd::primitive_desc bwd_pd(bwd_desc, CpuEngine::Get()->get_engine(), fwd_pd); + bwd_prim.weights_layer_desc_ = bwd_pd.weights_layer_desc(); + bwd_prim.weights_iter_desc_ = bwd_pd.weights_iter_desc(); + bwd_prim.diff_weights_layer_desc_ = bwd_pd.diff_weights_layer_desc(); + bwd_prim.diff_weights_iter_desc_ = bwd_pd.diff_weights_iter_desc(); + bwd_prim.diff_bias_desc_ = bwd_pd.diff_bias_desc(); + + bwd_prim.primitive_ = std::shared_ptr(new rnn_bwd(bwd_pd)); + + return bwd_prim; + } + + RnnBwdPrimitive() { + this->primitive_ = nullptr; + this->weights_layer_desc_ = mkldnn::memory::desc(); + this->weights_iter_desc_ = mkldnn::memory::desc(); + this->diff_weights_layer_desc_ = mkldnn::memory::desc(); + this->diff_weights_iter_desc_ = mkldnn::memory::desc(); + this->diff_bias_desc_ = mkldnn::memory::desc(); + } + + RnnBwdPrimitive(const RnnBwdPrimitive& bwd) { + this->primitive_ = bwd.primitive_; + this->weights_layer_desc_ = bwd.weights_layer_desc_; + this->weights_iter_desc_ = bwd.weights_iter_desc_; + this->diff_weights_layer_desc_ = bwd.diff_weights_layer_desc_; + this->diff_weights_iter_desc_ = bwd.diff_weights_iter_desc_; + this->diff_bias_desc_ = bwd.diff_bias_desc_; + } + + RnnBwdPrimitive& operator=(const RnnBwdPrimitive& bwd) { + if (this != &bwd) { + this->primitive_ = bwd.primitive_; + this->weights_layer_desc_ = bwd.weights_layer_desc_; + this->weights_iter_desc_ = bwd.weights_iter_desc_; + this->diff_weights_layer_desc_ = bwd.diff_weights_layer_desc_; + this->diff_weights_iter_desc_ = bwd.diff_weights_iter_desc_; + this->diff_bias_desc_ = bwd.diff_bias_desc_; + } + + return *this; + } + + private: + std::shared_ptr primitive_; + mkldnn::memory::desc weights_layer_desc_; + mkldnn::memory::desc weights_iter_desc_; + mkldnn::memory::desc diff_weights_layer_desc_; + mkldnn::memory::desc diff_weights_iter_desc_; + mkldnn::memory::desc diff_bias_desc_; + friend class MKLDNNRnnBackward; +}; +RnnBwdPrimitive GetRnnBwdPrim(const MKLDNNRnnForwardTraining& fwd, + const NDArray& data, const NDArray& params); + +/* + * Use this to manage memory and primitive of MKL-DNN RNN backward. + */ +class MKLDNNRnnBackward { + public: + MKLDNNRnnBackward(const MKLDNNRnnForwardTraining& fwd, + const NDArray& data, const NDArray& params) + : bwd_(GetRnnBwdPrim(fwd, data, params)), + fwd_ptr_(&fwd) { } + + void FetchDataWeightsMem(const MKLDNNRnnForwardTraining& fwd); + void SetWeightsGradsMem(); + void SetDataGradsMem(void* diff_src, void* diff_state, void* diff_statecell, + void* diff_out, void* diff_state_out, void* diff_statecell_out, + const int dtype = mshadow::kFloat32); + void CommitWeightsDiff(void* diff_weights, void* diff_bias, const int dtype = mshadow::kFloat32); + + const mkldnn::primitive& GetBwd() const { return *bwd_.primitive_; } + const mkldnn_args_map_t& GetArgsMap() const { return net_args_; } + + private: + bool initialized_; + RnnBwdPrimitive bwd_; + const MKLDNNRnnForwardTraining* fwd_ptr_; + + mkldnn_shared_mem_t weights_layer_; + mkldnn_shared_mem_t weights_iter_; + + mkldnn_shared_mem_t diff_weights_layer_; + mkldnn_shared_mem_t diff_weights_iter_; + mkldnn_shared_mem_t diff_bias_; + + mkldnn_args_map_t net_args_; +}; + +/* + * Use MKLDNNRnnOp to manage fused or unfused RNN layers. A MKLDNNRnnOp contains + * the parameter(s) and primitive(s) of RNN layer(s). 
According to the direction, + * input size, and state size, multple layers could be inplemented by unfused and + * fused layers - MKLDNNRnnForward, which holds the memory and forward primitive + * of MKL-DNN. + */ +class MKLDNNRnnOp { + public: + explicit MKLDNNRnnOp(const RNNParam ¶m, const int seq_len, + const int batch_size, const int input_size) + : initialized_(false), weights_version_(0), + full_param_(MKLDNNRnnFullParamParser(param, seq_len, batch_size, input_size)) { } + + void Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + + void Backward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + + const RNNParam& GetParam() const { return full_param_.default_param; } + + private: + bool initialized_; + size_t weights_version_; + MKLDNNRnnFullParam full_param_; + MKLDNNRnnMemMgr mgr_; + std::vector fwd_inf_vec_; // forward inference layers + std::vector fwd_trn_vec_; // forward training layers + std::vector bwd_vec_; // backward layers + + // Used to store the intermediate results of multi-layer + std::vector dst_; + + // Used to store the intermediate diff_src of multi_layer + mkldnn_shared_mem_t diff_src; + + void Init(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); +}; + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 100 +#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc b/src/operator/nn/mkldnn/mkldnn_rnn.cc new file mode 100644 index 000000000000..2d1bde97121c --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc @@ -0,0 +1,1109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file mkldnn_rnn.cc + * \brief Common functions used by MKLDNN RNN operator + * \author Zixuan Wei +*/ + +#if MXNET_USE_MKLDNN == 100 + +#include +#include "./mkldnn_rnn-inl.h" + +namespace mxnet { +namespace op { + +inline int GetRnnGatesNum(int mode) { + switch (mode) { + case rnn_enum::kLstm: + return 4; + case rnn_enum::kGru: + return 3; + case rnn_enum::kRnnRelu: + case rnn_enum::kRnnTanh: + return 1; + default: + LOG(FATAL) << "unsupported RNN mode:" << mode; + return -1; + } +} + +void MKLDNNRnnLayerParam::SetDims() { + const int ngates = GetRnnGatesNum(mode); + //* NOTES: LBR-GRU's new gate formula needs two bias. So it has one more bias with LBR-GRU + const int nbias = mode == rnn_enum::kGru ? (ngates + 1) : ngates; + const int num_direction = bidirectional ? 
2 : 1; + + src_dims.assign({seq_len, batch_size, input_size}); + weight_layer_dims.assign({num_layer, num_direction, input_size, ngates, state_size}); + weight_iter_dims.assign({num_layer, num_direction, state_size, ngates, state_size}); + bias_dims.assign({num_layer, num_direction, nbias, state_size}); + dst_dims.assign({seq_len, batch_size, state_size * num_direction}); + state_dims.assign({num_layer, num_direction, batch_size, state_size}); + + // unidirectional size of a single cell + single_w_size = (input_size + state_size) * ngates * state_size; + single_b_size = nbias * state_size; + naive_single_b_size = ngates * state_size * 2; // naive RNN variants have double bias + single_state_size = batch_size * state_size; + + // Get workspace size for cached weights memory + // multiplication of tensor dimensions + static auto tz_volume = [](const memory::dims& tz_dims) { + return std::accumulate(tz_dims.begin(), tz_dims.end(), static_cast(1), + std::multiplies()); + }; + + workspace_size = tz_volume(weight_layer_dims) + tz_volume(weight_iter_dims) + + tz_volume(bias_dims); + reserve_size = 0; +} + +MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const RNNParam& rnn_param, const int seq_len, + const int batch_size, const int input_size) { + MKLDNNRnnFullParam full_param; + full_param.default_param = rnn_param; + size_t state_size = rnn_param.state_size; + LayerParamVector &layer_params = full_param.layer_params; + + full_param.default_param.seq_length_ = seq_len; + full_param.default_param.batch_size_ = batch_size; + full_param.default_param.input_size_ = input_size; + // Set basic size by constructing MKLDNNRnnLayerParam instance(s) + if (rnn_param.bidirectional) { // unfused bidirectional multi-layer RNN + layer_params.emplace_back(1, batch_size, seq_len, input_size, state_size, rnn_param.mode); + for (size_t layer = 1; layer < rnn_param.num_layers; ++layer) { + layer_params.emplace_back(1, batch_size, seq_len, state_size * 2, state_size, + rnn_param.mode); + } + } else if (input_size == static_cast(state_size)) { // fused multi-layer RNN + layer_params.emplace_back(rnn_param.num_layers, batch_size, seq_len, input_size, + state_size, rnn_param.mode, false); + } else { // unfused 1st layer, plus fused 2-end layers + layer_params.emplace_back(1, batch_size, seq_len, input_size, state_size, rnn_param.mode, + false); + if (rnn_param.num_layers > 1) + layer_params.emplace_back(rnn_param.num_layers - 1, batch_size, seq_len, state_size, + state_size, rnn_param.mode, false); + } + + // Set dims, workspace size, and state_outputs flag + for (auto& layer_param : layer_params) { + layer_param.SetDims(); + layer_param.state_outputs = rnn_param.state_outputs; + } + return full_param; +} + +void MKLDNNRnnMemMgr::Init(dim_t size, const Context& ctx, int dtype) { + workspace_ = NDArray(TShape({size}), ctx, false, dtype); + curr_mem = static_cast(workspace_.data().dptr_); + mem_size = size * mshadow::mshadow_sizeof(dtype); + curr_size = size * mshadow::mshadow_sizeof(dtype); +} + +mkldnn::memory *MKLDNNRnnMemMgr::Alloc(const mkldnn::memory::desc &md) { + if (curr_mem == nullptr) { + curr_mem = static_cast(workspace_.data().dptr_); + } + + mkldnn_mem_ptr ret(new mkldnn::memory()); + size_t addr = reinterpret_cast(curr_mem); + size_t last_chunk = addr % alignment; + size_t padding = alignment - last_chunk; + addr += padding; + CHECK_EQ(addr % alignment, 0); + + curr_size -= (md.get_size() + padding); + if (curr_size < 0) { + ret.reset(new mkldnn::memory(md, cpu_engine)); + } else { + curr_mem += (md.get_size() 
+ padding); + ret.reset(new mkldnn::memory(md, cpu_engine, reinterpret_cast(addr))); + } + RegisterMem(ret); + return ret.get(); +} + +RnnPrimitive GetRnnFwdPrim( + const MKLDNNRnnLayerParam &layer_param, const bool is_train, + const NDArray &data, const NDArray ¶ms) { + using namespace mkldnn; + using tag = mkldnn::memory::format_tag; + const int mode = layer_param.mode; + memory::data_type data_type = get_mkldnn_type(data.dtype()); + memory::data_type weight_type = get_mkldnn_type(params.dtype()); + const prop_kind prop = is_train ? prop_kind::forward_training : prop_kind::forward_inference; + const rnn_direction mkldnn_rnn_direction = layer_param.bidirectional ? + rnn_direction::bidirectional_concat : rnn_direction::unidirectional; + + auto src_layer_desc = memory::desc(layer_param.src_dims, data_type, tag::tnc); + auto weight_layer_desc = memory::desc(layer_param.weight_layer_dims, weight_type, tag::any); + auto weight_iter_desc = memory::desc(layer_param.weight_iter_dims, weight_type, tag::any); + auto bias_desc = memory::desc(layer_param.bias_dims, data_type, tag::ldgo); + auto dst_layer_desc = memory::desc(layer_param.dst_dims, data_type, tag::tnc); + auto src_state_desc = memory::desc(layer_param.state_dims, data_type, tag::ldnc); + auto dst_state_desc = layer_param.state_outputs ? memory::desc( + layer_param.state_dims, data_type, tag::ldnc) : memory::desc(); + + auto fwd = RnnPrimitive(); + switch (mode) { + case rnn_enum::kLstm: + fwd = RnnPrimitive::Create(prop, mkldnn_rnn_direction, + src_layer_desc, src_state_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc, + dst_state_desc); + break; + case rnn_enum::kGru: + fwd = RnnPrimitive::Create(prop, mkldnn_rnn_direction, + src_layer_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc); + break; + case rnn_enum::kRnnRelu: + case rnn_enum::kRnnTanh: + fwd = RnnPrimitive::Create(prop, + mode == rnn_enum::kRnnTanh ? algorithm::eltwise_tanh : algorithm::eltwise_relu, + mkldnn_rnn_direction, src_layer_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc); + break; + default: + LOG(FATAL) << "unsupported RNN mode:" << mode; + break; + } + return fwd; +} + +RnnBwdPrimitive GetRnnBwdPrim(const MKLDNNRnnForwardTraining &fwd, + const NDArray &data, const NDArray ¶ms) { + using namespace mkldnn; + using tag = mkldnn::memory::format_tag; + const MKLDNNRnnLayerParam& layer_param = fwd.GetParam(); + const int mode = layer_param.mode; + memory::data_type data_type = get_mkldnn_type(data.dtype()); + memory::data_type weight_type = get_mkldnn_type(params.dtype()); + const prop_kind prop = prop_kind::backward; + rnn_direction mkldnn_rnn_direction = layer_param.bidirectional ? + rnn_direction::bidirectional_concat : rnn_direction::unidirectional; + + auto src_layer_desc = memory::desc(layer_param.src_dims, data_type, tag::tnc); + auto weight_layer_desc = memory::desc(layer_param.weight_layer_dims, weight_type, tag::any); + auto weight_iter_desc = memory::desc(layer_param.weight_iter_dims, weight_type, tag::any); + auto bias_desc = memory::desc(layer_param.bias_dims, data_type, tag::ldgo); + auto dst_layer_desc = memory::desc(layer_param.dst_dims, data_type, tag::tnc); + auto src_state_desc = memory::desc(layer_param.state_dims, data_type, tag::ldnc); + auto dst_state_desc = layer_param.state_outputs ? 
memory::desc( + layer_param.state_dims, data_type, tag::ldnc) : memory::desc(); + + const void* fwd_desc = fwd.GetPrimDesc(); + auto bwd = RnnBwdPrimitive(); + switch (mode) { + case rnn_enum::kLstm: { + const lstm_forward::primitive_desc* desc = + reinterpret_cast(fwd_desc); + bwd = RnnBwdPrimitive::Create(*desc, + prop, mkldnn_rnn_direction, + // data desc + src_layer_desc, src_state_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc, + dst_state_desc, + // diff desc + src_layer_desc, src_state_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc, + dst_state_desc); + } break; + case rnn_enum::kGru: { + const lbr_gru_forward::primitive_desc* desc = + reinterpret_cast(fwd_desc); + bwd = RnnBwdPrimitive::Create(*desc, + prop, mkldnn_rnn_direction, + // data desc + src_layer_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc, + // diff desc + src_layer_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc); + } break; + case rnn_enum::kRnnRelu: + case rnn_enum::kRnnTanh: { + const vanilla_rnn_forward::primitive_desc* desc = + reinterpret_cast(fwd_desc); + bwd = RnnBwdPrimitive::Create( + *desc, prop, + mode == rnn_enum::kRnnTanh ? algorithm::eltwise_tanh : algorithm::eltwise_relu, + mkldnn_rnn_direction, + // data desc + src_layer_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc, + // diff desc + src_layer_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc); + } break; + default: + LOG(FATAL) << "unsupported RNN mode:" << mode; + break; + } + return bwd; +} + +/* + * Naive weights layout is: + * | l0_l2r_wx | l0_l2r_wh | l0_r2l_wx | l0_r2l_wh | + * | l1_l2r_wx | l1_l2r_wh | l1_r2l_wx | l1_r2l_wh | + * ... + * + * We need concat them to be: + * | l0_l2r_wx | l0_r2l_wx | l1_l2r_wx | l1_r2l_wx | + * | l0_l2r_wh | l0_r2l_wh | l1_l2r_wh | l1_r2l_wh | + * ... + * + * All the memory blocks are in goi format. + */ +static void ConcatWeights(const mkldnn::memory &dst, + const int concat_dimension, + const std::vector &src_ptrs, + const mkldnn::memory::format_tag src_format) { + using memory = mkldnn::memory; + auto cpu_engine = dst.get_engine(); + mkldnn::stream s(cpu_engine); + const memory::desc& dst_desc = dst.get_desc(); + // Use dst memory dims to initialize src memory dims, then set the concat + // dim to 1. And Rnn weights are 5-dimension tensor. 
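+  // Illustration (hypothetical sizes, not from the patch): for one
+  // bidirectional LSTM layer with input_size = state_size = 512, dst_desc
+  // holds dims {1, 2, 512, 4, 512}; each per-direction source block then
+  // gets dims {1, 1, 512, 4, 512} and is stacked along concat_dimension = 1.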
+ memory::dims src_dims(dst_desc.data.dims, dst_desc.data.dims + 5); + src_dims.at(concat_dimension) = 1; + std::vector src_descs; + std::unordered_map concat_args; + + for (size_t i = 0; i < src_ptrs.size(); ++i) { + src_descs.emplace_back(src_dims, + static_cast(dst_desc.data.data_type), src_format); + concat_args.emplace(MKLDNN_ARG_MULTIPLE_SRC + i, + memory(src_descs.back(), cpu_engine, src_ptrs.at(i))); + } + concat_args.emplace(MKLDNN_ARG_DST, dst); + + auto concat_pd = mkldnn::concat::primitive_desc(dst.get_desc(), + concat_dimension, src_descs, cpu_engine); + mkldnn::concat(concat_pd).execute(s, concat_args); +} + +#define RNN_HANDLE_FUNC_NAME set_handle +#define RNN_HANDLE_FUNC(RNN_FUNC_NAME) \ +auto RNN_FUNC_NAME = [&cpu_engine, &args](int arg_name, const desc& md, \ + void* handle) { \ + if (args.find(arg_name) != args.end()) { \ + if (handle != nullptr) args.at(arg_name).set_data_handle(handle); \ + } else { \ + args[arg_name] = handle ? mkldnn::memory(md, cpu_engine, handle) \ + : mkldnn::memory(md, cpu_engine); \ + } \ +} + +#define RNN_FWD_SET(NAME, DIMS, TAG, HANDLE, DTYPE) \ +RNN_FWD_SET_(RNN_HANDLE_FUNC_NAME, NAME, DIMS, TAG, HANDLE, DTYPE) + +#define RNN_FWD_SET_(FUNC, NAME, DIMS, TAG, HANDLE, DTYPE) \ +FUNC(MKLDNN_ARG_##NAME, {DIMS, get_mkldnn_type(DTYPE), TAG}, HANDLE) + +#define RNN_BWD_SET(NAME, ARGS, HANDLE) \ +RNN_BWD_SET_(RNN_HANDLE_FUNC_NAME, NAME, ARGS, HANDLE) + +#define RNN_BWD_SET_(FUNC, NAME, ARGS, HANDLE) \ +FUNC(MKLDNN_ARG_DIFF_##NAME, ARGS.at(MKLDNN_ARG_##NAME).get_desc(), HANDLE) + +/* + * Set new src data handler to Forward memory. The memory primitives are + * not initialized until SetNewDataMem is first invoked. Src data handler + * must not be nullptr, except for cx with LSTM mode. If either hy, cy is + * nullptr, it may run with non-state_ouput or non-LSTM mode. Thus, the + * corresponding memory should be a empty mkldnn::memory(). + */ +void MKLDNNRnnForward::SetNewDataMem(void* x, void* hx, void* cx, + void* y, void* hy, void* cy, + const int dtype) { + using dims = mkldnn::memory::dims; + using desc = mkldnn::memory::desc; + using format_tag = mkldnn::memory::format_tag; + auto& cpu_engine = CpuEngine::Get()->get_engine(); + mkldnn_args_map_t& args = net_args_; + + RNN_HANDLE_FUNC(RNN_HANDLE_FUNC_NAME); + + // Set various data memory + RNN_FWD_SET(SRC, param_.src_dims, format_tag::tnc, x, dtype); + RNN_FWD_SET(DST, param_.dst_dims, format_tag::tnc, y, dtype); + RNN_FWD_SET(SRC_ITER, param_.state_dims, format_tag::ldnc, hx, dtype); + + if (param_.state_outputs) { + RNN_FWD_SET(DST_ITER, param_.state_dims, format_tag::ldnc, hy, dtype); + } + + if (param_.mode == rnn_enum::kLstm) { + RNN_FWD_SET(SRC_ITER_C, param_.state_dims, format_tag::ldnc, cx, dtype); + if (param_.state_outputs) { + RNN_FWD_SET(DST_ITER_C, param_.state_dims, format_tag::ldnc, cy, dtype); + } + } +} + +/* + * Reorder the concatenated weights memory to a efficient memory block + * with primitive-prefered format. 
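+ * The source memory is the plain ldgoi block filled by SetWeightsMem; the
+ * destination uses whatever format the forward primitive descriptor selected.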
+ */ +void MKLDNNRnnForward::ReorderWeights() { + auto& cpu_engine = CpuEngine::Get()->get_engine(); + mkldnn::stream s(cpu_engine); + mkldnn::reorder(*weights_layer_r_, *weights_layer_) + .execute(s, *weights_layer_r_, *weights_layer_); + mkldnn::reorder(*weights_iter_r_, *weights_iter_) + .execute(s, *weights_iter_r_, *weights_iter_); + s.wait(); +} + +void AdjustGruGateOrder(char* weight, + const size_t input_size, + const size_t hidden_size, + const int dtype) { + // mxnet gru gate order is reset, update and new gates + // mkldnn gru gate order is update, reset and new gates + size_t single_weight_bytes = input_size * hidden_size * mshadow::mshadow_sizeof(dtype); + char* weight_reset = weight; + char* weight_update = weight + single_weight_bytes; + std::swap_ranges(weight_reset, weight_update, weight_update); +} + +/* + * Fuse uni-directional bias among single layer. + */ +template +void FuseBias(DType* fuse_bias, DType* naive_bias, + const int mode, const size_t state_size) { + const size_t ngates = GetRnnGatesNum(mode); + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + const size_t nbias = mode == rnn_enum::kGru ? ngates + 1 : ngates; + // MSVC-14.0 (OpenMP 2.0 compatible) doesn't support unsigned integral type in + // OpenMP 'for' statement. + const int state_size_ = static_cast(state_size); + const int single_b_sz = static_cast(nbias * state_size); + DType* bx = naive_bias; + DType* bh = naive_bias + state_size * ngates; + if (mode == rnn_enum::kGru) { + // While mxnet gru gate order is reset, update and new gates, + // mkldnn gru gate order is update, reset and new gates. So + // we need to swap the order of reset and update from mxnet. + #pragma omp parallel for num_threads(omp_threads) + for (int j = 0; j < state_size_; j++) { + // Swap summed reset, update bias + fuse_bias[j + state_size] = bx[j] + bh[j]; + fuse_bias[j] = bx[j + state_size] + bh[j + state_size]; + + // Memcpy two new gates + fuse_bias[j + 2 * state_size] = bx[j + 2 * state_size]; + fuse_bias[j + 3 * state_size] = bh[j + 2 * state_size]; + } + } else { + #pragma omp parallel for num_threads(omp_threads) + for (int j = 0; j < single_b_sz; ++j) { + // Sum two bias + fuse_bias[j] = bx[j] + bh[j]; + } + } +} + +inline void EmplaceNetArgs(mkldnn_args_map_t* net_args, const int arg_name, + const mkldnn::memory* mem) { + if (net_args->find(arg_name) != net_args->end()) { + if (net_args->at(arg_name).get_data_handle() == mem->get_data_handle()) { + return; + } else { + net_args->at(arg_name).set_data_handle(mem->get_data_handle()); + } + } else { + net_args->emplace(arg_name, *mem); + } +} + +/* + * Copy naive memory to mkldnn-format memory. It will initialize the memory + * when first invoked. Then, the naive weight_layer and weight_iter are + * concatenated to xxx_xx_r memory. Per the different gates order of GRU, + * it will swap the memory blocks of gates among concatenated memory + * inplace. From then on, the xxx_xx_r memory is reordered to target + * memory with preferred format_tag. Finally, naive bias is fused to MKLDNN + * bias memory. 
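+ *
+ * For example (an illustrative case, not prescribed by the patch): a 3-layer
+ * unidirectional LSTM with input_size == state_size stays fused in a single
+ * MKLDNNRnnForward, so the three per-layer wx/wh blocks are concatenated
+ * along the layer axis before the inference-path reorder to the packed format.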
+ */ +void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_ptr, + const bool is_train, const int dtype) { + using format_tag = mkldnn::memory::format_tag; + auto mkldnn_dtype = get_mkldnn_type(dtype); + // Get the weights' memory for RNN forward primitive + if (weights_layer_ == nullptr) { + weights_layer_ = mgr->Alloc(fwd_inf_.GetLayerDesc()); + } + if (weights_iter_ == nullptr) { + weights_iter_ = mgr->Alloc(fwd_inf_.GetIterDesc()); + } + if (bias_ == nullptr) { + bias_ = mgr->Alloc( + {param_.bias_dims, mkldnn_dtype, format_tag::ldgo}); + } + + // Get the intermediate memory for weights concat & reorder + if (weights_layer_r_ == nullptr) { + weights_layer_r_ = mgr->Alloc( + {param_.weight_layer_dims, mkldnn_dtype, format_tag::ldgoi}); + } + if (weights_iter_r_ == nullptr) { + weights_iter_r_ = mgr->Alloc( + {param_.weight_iter_dims, mkldnn_dtype, format_tag::ldgoi}); + } + + // Get the bytes of a real type + size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); + + // convert void* to char* for arithmetic operations + char *weights_ptr = static_cast(w_ptr); + size_t wx_bytes = GetRnnGatesNum(param_.mode) * param_.state_size * + param_.input_size * dtype_bytes; //* DIMS: ngates x state_size x input_size + char *l2r_wx = weights_ptr; + char *l2r_wh = l2r_wx + wx_bytes; //* DIMS: ngates x state_size * state_size + + if (param_.num_layer == 1 && param_.bidirectional) { + //* single bidirectinal layer, concat weights on direction axis + char *r2l_wx = weights_ptr + param_.single_w_size * dtype_bytes; + char *r2l_wh = r2l_wx + wx_bytes; //* DIMS: ngates x state_size * state_size + ConcatWeights(*weights_layer_r_, 1, {l2r_wx, r2l_wx}, format_tag::ldgoi); + ConcatWeights(*weights_iter_r_, 1, {l2r_wh, r2l_wh}, format_tag::ldgoi); + } else if (param_.num_layer == 1 && !param_.bidirectional) { + //* single uni-directional layer, no concatenate operator needed + weights_layer_r_->set_data_handle(l2r_wx); + weights_iter_r_->set_data_handle(l2r_wh); + } else if (param_.num_layer > 1 && !param_.bidirectional) { + //* concat fused multi-layer weights on layer axis + std::vector l2r_wx_ptrs; + std::vector l2r_wh_ptrs; + for (int lyr = 0; lyr < param_.num_layer; ++lyr) { + char *lth_wx = l2r_wx + lyr * param_.single_w_size * dtype_bytes; + char *lth_wh = lth_wx + wx_bytes; + l2r_wx_ptrs.push_back(lth_wx); + l2r_wh_ptrs.push_back(lth_wh); + } + ConcatWeights(*weights_layer_r_, 0, l2r_wx_ptrs, format_tag::ldgoi); + ConcatWeights(*weights_iter_r_, 0, l2r_wh_ptrs, format_tag::ldgoi); + } else { + LOG(FATAL) << "Undifined RNN fusion workflow for num_layer = " << param_.num_layer + << ", and bidirectional is " << param_.bidirectional; + } + + // Adjust gates order of LBR-GRU among concatenated memory inplace. + //* DIMS: ngates x state_size x state_size (ngates = 3, when mode == gru) + size_t wh_bytes = 3 * param_.state_size * param_.state_size * dtype_bytes; + char* fused_wx = static_cast(weights_layer_r_->get_data_handle()); + char* fused_wh = static_cast(weights_iter_r_->get_data_handle()); + if (param_.mode == rnn_enum::kGru) { + for (size_t lyr = 0; lyr < static_cast(param_.num_layer); ++lyr) { + for (size_t d = 0; d < param_.bidirectional + 1U; ++d) { + AdjustGruGateOrder(fused_wx, param_.input_size, param_.state_size, dtype); + AdjustGruGateOrder(fused_wh, param_.state_size, param_.state_size, dtype); + fused_wx += wx_bytes; + fused_wh += wh_bytes; + } + } + } + // Reorder after adjustment only when is_train == false. When is_train == true, i.e. 
+ // in forward training path, we use plain memory (ldxxx) as the space for weights and + // their gradients. Then, forward training primitives could fetch them from the scope + // of forward inference. And from there, we don't need to reorder the plain memory to + // the optimal rnn-packed memory for forward inference. + if (!is_train) + ReorderWeights(); + + // Process bias + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + DType* naive_b_ptr = static_cast(b_ptr); + DType* fused_bias = static_cast(bias_->get_data_handle()); + for (int lyr = 0; lyr < param_.num_layer; ++lyr) { + for (int d = 0; d < param_.bidirectional + 1; ++d) { + FuseBias(fused_bias, naive_b_ptr, param_.mode, param_.state_size); + fused_bias += param_.single_b_size; + naive_b_ptr += param_.naive_single_b_size; + } + } + }); + + // insert weights into net_args + if (!is_train) { + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_WEIGHTS_LAYER, this->weights_layer_); + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_WEIGHTS_ITER, this->weights_iter_); + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_BIAS, this->bias_); + } + initialized_ = true; +} + +void MKLDNNRnnForwardTraining::SetTrnMem(const MKLDNNRnnForward& fwd) { + using memory = mkldnn::memory; + const auto& cpu_engine = CpuEngine::Get()->get_engine(); + auto s = mkldnn::stream(cpu_engine); + // Prepare mkldnn::memorys for weights_layer, weight_iter, and workspace + if (workspace_ == nullptr) + workspace_ = mkldnn_shared_mem_t(new memory(fwd_trn_.GetWorkspaceDesc(), cpu_engine)); + if (weights_layer_ == nullptr) + weights_layer_ = mkldnn_shared_mem_t(new memory(fwd_trn_.GetLayerDesc(), cpu_engine)); + if (weights_iter_ == nullptr) + weights_iter_ = mkldnn_shared_mem_t(new memory(fwd_trn_.GetIterDesc(), cpu_engine)); + + // fill weights memory using the reordered weights of fwd_inference primitive + if (fwd.weights_layer_r_->get_desc() == fwd_trn_.GetLayerDesc()) { + weights_layer_->set_data_handle(fwd.weights_layer_r_->get_data_handle()); + } else { + mkldnn::reorder(*fwd.weights_layer_r_, *weights_layer_) + .execute(s, *fwd.weights_layer_r_, *weights_layer_); + } + + if (fwd.weights_iter_r_->get_desc() == fwd_trn_.GetIterDesc()) { + weights_iter_->set_data_handle(fwd.weights_iter_r_->get_data_handle()); + } else { + mkldnn::reorder(*fwd.weights_iter_r_, *weights_iter_) + .execute(s, *fwd.weights_iter_r_, *weights_iter_); + } + s.wait(); + + // bias are always in format_tag::ldgo + this->bias_ = fwd.bias_; + + // insert weights into net_args + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_WEIGHTS_LAYER, this->weights_layer_.get()); + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_WEIGHTS_ITER, this->weights_iter_.get()); + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_BIAS, this->bias_); + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_WORKSPACE, this->workspace_.get()); +} + +void MKLDNNRnnForwardTraining::FetchData(const MKLDNNRnnForward& fwd) { + for (auto& kv : fwd.net_args_) { + switch (kv.first) { + case MKLDNN_ARG_WEIGHTS_LAYER: + case MKLDNN_ARG_WEIGHTS_ITER: + case MKLDNN_ARG_BIAS: + case MKLDNN_ARG_WORKSPACE: + continue; + + default: + EmplaceNetArgs(&this->net_args_, kv.first, &kv.second); + } + } +} + +void MKLDNNRnnOp::Init(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using memory = mkldnn::memory; + using format_tag = mkldnn::memory::format_tag; + + const size_t num_fusion = full_param_.layer_params.size(); + if (fwd_inf_vec_.size() < num_fusion) { + size_t buffer_size = 0; // Element number, instead of bytes, 
in the buffer + for (auto& layer_param : full_param_.layer_params) { + buffer_size += layer_param.workspace_size + layer_param.reserve_size; + } + buffer_size += outputs[rnn_enum::kOut].data().Size() * (num_fusion - 1); + buffer_size += kMKLDNNAlign * num_fusion * 5; // Add margin for alignment + + for (auto& layer_param : full_param_.layer_params) { + fwd_inf_vec_.emplace_back(layer_param, + ctx.is_train, inputs[rnn_enum::kData], inputs[rnn_enum::kParams]); + buffer_size += fwd_inf_vec_.back().GetSize(inputs[rnn_enum::kParams].dtype()); + } + mgr_.Init(buffer_size, ctx.run_ctx.ctx, inputs[rnn_enum::kParams].dtype()); + } + + if (ctx.is_train && fwd_trn_vec_.size() < num_fusion) { + for (auto& layer_param : full_param_.layer_params) { + fwd_trn_vec_.emplace_back(layer_param, + true, inputs[rnn_enum::kData], inputs[rnn_enum::kParams]); + } + } + + // Get the bytes of a real type + const NDArray &weights = inputs[rnn_enum::kParams]; + int dtype = weights.dtype(); + size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); + + const RNNParam &default_param = full_param_.default_param; + char *weights_ptr = static_cast(weights.data().dptr_); + char *bias_ptr = weights_ptr + (weights.data().Size() - + GetRnnBiasSize(default_param.num_layers, default_param.state_size, + default_param.bidirectional + 1, default_param.mode)) * dtype_bytes; + for (auto& fwd_layer : fwd_inf_vec_) { + size_t single_w_bytes = fwd_layer.GetParam().single_w_size * dtype_bytes; + size_t single_b_bytes = fwd_layer.GetParam().naive_single_b_size * dtype_bytes; + size_t directions = fwd_layer.GetParam().bidirectional ? 2 : 1; + size_t layer_weights_bytes = single_w_bytes * directions; + size_t layer_bias_bytes = single_b_bytes * directions; // Naive MXNet has double bias + + if (!fwd_layer.IsInitialized() || ctx.is_train) + fwd_layer.SetWeightsMem(&(this->mgr_), weights_ptr, bias_ptr, dtype); + weights_ptr += layer_weights_bytes; + bias_ptr += layer_bias_bytes; + } + + if (ctx.is_train) { + CHECK_EQ(fwd_trn_vec_.size(), fwd_inf_vec_.size()) << + "Layers' configurations of forward inference and forward training are disparate."; + for (size_t lyr = 0; lyr < fwd_inf_vec_.size(); ++lyr) + fwd_trn_vec_.at(lyr).SetTrnMem(fwd_inf_vec_.at(lyr)); + } + + CHECK_EQ(num_fusion, fwd_inf_vec_.size()) << + "Layer vector's size has a different value than the number of fusion."; + if (dst_.size() < num_fusion - 1) { + int data_dtype = outputs[rnn_enum::kOut].dtype(); + for (auto fwd = fwd_inf_vec_.begin(); fwd < fwd_inf_vec_.end() - 1; ++fwd) + dst_.push_back(mgr_.Alloc( + {fwd->GetParam().dst_dims, get_mkldnn_type(data_dtype), format_tag::tnc})); + } + + initialized_ = true; +} + +void MKLDNNRnnBackward::FetchDataWeightsMem(const MKLDNNRnnForwardTraining& fwd) { + using memory = mkldnn::memory; + auto& cpu_engine = CpuEngine::Get()->get_engine(); + auto s = mkldnn::stream(cpu_engine); + + if (this->weights_layer_ == nullptr) + this->weights_layer_ = mkldnn_shared_mem_t(new memory(bwd_.weights_layer_desc_, cpu_engine)); + if (this->weights_iter_ == nullptr) + this->weights_iter_ = mkldnn_shared_mem_t(new memory(bwd_.weights_iter_desc_, cpu_engine)); + + for (auto& kv : fwd.net_args_) { + const mkldnn::memory* valid_mem; + switch (kv.first) { + case MKLDNN_ARG_WEIGHTS_LAYER: { + if (bwd_.weights_layer_desc_ == fwd.fwd_trn_.GetLayerDesc()) { + this->weights_layer_->set_data_handle(kv.second.get_data_handle()); + } else { + mkldnn::reorder(*fwd.weights_layer_, *this->weights_layer_) + .execute(s, *fwd.weights_layer_, *this->weights_layer_); + } + 
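+        // The backward primitive may expect a different weights format than
+        // forward training does, hence the conditional reorder above.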
valid_mem = this->weights_layer_.get(); + } break; + case MKLDNN_ARG_WEIGHTS_ITER: { + if (bwd_.weights_iter_desc_ == fwd.fwd_trn_.GetLayerDesc()) { + this->weights_iter_->set_data_handle(kv.second.get_data_handle()); + } else { + mkldnn::reorder(*fwd.weights_iter_, *this->weights_iter_) + .execute(s, *fwd.weights_iter_, *this->weights_iter_); + } + valid_mem = this->weights_iter_.get(); + } break; + + default: + valid_mem = &kv.second; + } + EmplaceNetArgs(&this->net_args_, kv.first, valid_mem); + } + s.wait(); +} + +void MKLDNNRnnBackward::SetWeightsGradsMem() { + auto& cpu_engine = CpuEngine::Get()->get_engine(); + if (this->diff_weights_layer_ == nullptr) + this->diff_weights_layer_ = std::make_shared( + bwd_.diff_weights_layer_desc_, cpu_engine); + if (this->diff_weights_iter_ == nullptr) + this->diff_weights_iter_ = std::make_shared( + bwd_.diff_weights_iter_desc_, cpu_engine); + if (this->diff_bias_ == nullptr) + this->diff_bias_ = std::make_shared( + bwd_.diff_bias_desc_, cpu_engine); + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_DIFF_WEIGHTS_LAYER, + this->diff_weights_layer_.get()); + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_DIFF_WEIGHTS_ITER, + this->diff_weights_iter_.get()); + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_DIFF_BIAS, + this->diff_bias_.get()); +} + +void MKLDNNRnnBackward::SetDataGradsMem( + void* diff_src, void* diff_state, void* diff_statecell, + void* diff_dst, void* diff_state_out, void* diff_statecell_out, + const int dtype) { + using desc = mkldnn::memory::desc; + auto& cpu_engine = CpuEngine::Get()->get_engine(); + mkldnn_args_map_t& args = this->net_args_; + + RNN_HANDLE_FUNC(RNN_HANDLE_FUNC_NAME); + + // Set various diff memory + auto& fwd_args = fwd_ptr_->GetArgsMap(); + RNN_BWD_SET(SRC, fwd_args, diff_src); + RNN_BWD_SET(SRC_ITER, fwd_args, diff_state); + RNN_BWD_SET(DST, fwd_args, diff_dst); + + if (fwd_ptr_->GetParam().state_outputs) + RNN_BWD_SET(DST_ITER, fwd_args, diff_state_out); + + if (fwd_ptr_->GetParam().mode == rnn_enum::kLstm) { + RNN_BWD_SET(SRC_ITER_C, fwd_args, diff_statecell); + if (fwd_ptr_->GetParam().state_outputs) { + RNN_BWD_SET(DST_ITER_C, fwd_args, diff_statecell_out); + } + } +} + +template +void HalveWeightsDiff(DType* w, const size_t size) { + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < static_cast(size); ++i) { + w[i] *= 0.5; + } +} + +void MKLDNNRnnBackward::CommitWeightsDiff(void* diff_weights, void* diff_bias, int dtype) { + using tag = mkldnn::memory::format_tag; + auto& cpu_engine = CpuEngine::Get()->get_engine(); + auto s = mkldnn::stream(cpu_engine); + + const MKLDNNRnnLayerParam& param = fwd_ptr_->GetParam(); + const int num_layer = param.num_layer; + const int direction = param.bidirectional ? 
2 : 1; + const int ngates = GetRnnGatesNum(param.mode); + const size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); + const size_t wxh_bytes = param.single_w_size * dtype_bytes; + const size_t wx_bytes = param.input_size * param.state_size * ngates * dtype_bytes; + const size_t wh_bytes = param.state_size * param.state_size * ngates * dtype_bytes; + char* diff_wx_ptr = static_cast(diff_weights_layer_->get_data_handle()); + char* diff_wh_ptr = static_cast(diff_weights_iter_->get_data_handle()); + + /* naive weights layout is: + 1st-layer: | wx_lr | wh_lr | wx_rl | wh_rl | + 2st-layer: | wx_lr | wh_lr | wx_rl | wh_rl | + size: | wxh_bytes | + |wx_bytes|wh_bytes| + */ + char* naive_weights = static_cast(diff_weights); + if (param.mode != rnn_enum::kGru) { + for (int shift = 0; shift < num_layer * direction; ++shift) { + std::memcpy(naive_weights + shift * wxh_bytes, + diff_wx_ptr + shift * wx_bytes, wx_bytes); + } + // align naive_weights to weights_iter memory + naive_weights += wx_bytes; + for (int shift = 0; shift < num_layer * direction; ++shift) { + std::memcpy(naive_weights + shift * wxh_bytes, + diff_wh_ptr + shift * wh_bytes, wh_bytes); + } + } else { + const size_t wx_bytes_per_gate = param.input_size * param.state_size * dtype_bytes; + const size_t wh_bytes_per_gate = param.state_size * param.state_size * dtype_bytes; + for (int shift = 0; shift < num_layer * direction; ++shift) { + std::memcpy(naive_weights + shift * wxh_bytes + wx_bytes_per_gate, + diff_wx_ptr + shift * wx_bytes, wx_bytes_per_gate); + std::memcpy(naive_weights + shift * wxh_bytes, + diff_wx_ptr + shift * wx_bytes + wx_bytes_per_gate, wx_bytes_per_gate); + std::memcpy(naive_weights + shift * wxh_bytes + 2 * wx_bytes_per_gate, + diff_wx_ptr + shift * wx_bytes + 2 * wx_bytes_per_gate, wx_bytes_per_gate); + } + // align naive_weights to weights_iter memory + naive_weights += wx_bytes; + for (int shift = 0; shift < num_layer * direction; ++shift) { + std::memcpy(naive_weights + shift * wxh_bytes + wh_bytes_per_gate, + diff_wh_ptr + shift * wh_bytes, wh_bytes_per_gate); + std::memcpy(naive_weights + shift * wxh_bytes, + diff_wh_ptr + shift * wh_bytes + wh_bytes_per_gate, wh_bytes_per_gate); + std::memcpy(naive_weights + shift * wxh_bytes + 2 * wh_bytes_per_gate, + diff_wh_ptr + shift * wh_bytes + 2 * wh_bytes_per_gate, wh_bytes_per_gate); + } + } + + char* naive_bias = static_cast(diff_bias); + char* diff_bias_ptr = static_cast(this->diff_bias_->get_data_handle()); + const size_t bias_bytes = param.single_b_size * dtype_bytes; + const size_t naive_bias_bytes = param.naive_single_b_size * dtype_bytes; + if (param.mode != rnn_enum::kGru) { + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + DType* typed_bias = reinterpret_cast(diff_bias_ptr); + HalveWeightsDiff(typed_bias, num_layer * direction * param.single_b_size); + }); + for (int shift = 0; shift < num_layer * direction; ++shift) { + std::memcpy(naive_bias + shift * naive_bias_bytes, + diff_bias_ptr + shift * bias_bytes, bias_bytes); + std::memcpy(naive_bias + shift * naive_bias_bytes + bias_bytes, + diff_bias_ptr + shift * bias_bytes, bias_bytes); + } + } else { + const size_t bias_bytes_per_gate = param.state_size * dtype_bytes; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + for (int shift = 0; shift < num_layer * direction; ++shift) { + char* naive_reset = naive_bias + shift * naive_bias_bytes; + char* naive_update = naive_reset + bias_bytes_per_gate; + char* update = diff_bias_ptr + shift * bias_bytes; + char* reset = update + bias_bytes_per_gate; + + DType* typed_update 
= reinterpret_cast(update); + HalveWeightsDiff(typed_update, param.state_size * 2); + + std::memcpy(naive_update, update, bias_bytes_per_gate); + std::memcpy(naive_reset, reset, bias_bytes_per_gate); + std::memcpy(naive_update + naive_bias_bytes / 2, update, bias_bytes_per_gate); + std::memcpy(naive_reset + naive_bias_bytes / 2, reset, bias_bytes_per_gate); + + char* naive_new_bx = naive_update + bias_bytes_per_gate; + char* naive_new_bh = naive_new_bx + naive_bias_bytes / 2; + char* new_bx = reset + bias_bytes_per_gate; + char* new_bh = new_bx + bias_bytes_per_gate; + std::memcpy(naive_new_bx, new_bx, bias_bytes_per_gate); + std::memcpy(naive_new_bh, new_bh, bias_bytes_per_gate); + } + }); + } +} + +template +inline void RegisterMKLDNNRnn(MKLDNNRnnX const& rnn) { + MKLDNNStream::Get()->RegisterPrimArgs(rnn.GetFwd(), rnn.GetArgsMap()); +} + +template <> +inline void RegisterMKLDNNRnn(MKLDNNRnnBackward const& rnn) { + MKLDNNStream::Get()->RegisterPrimArgs(rnn.GetBwd(), rnn.GetArgsMap()); +} + +void MKLDNNRnnOp::Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + // check output requests + if (kAddTo == req[rnn_enum::kOut]) + LOG(FATAL) << "Currently, `add` operation is not supported by RNNs."; + const RNNParam& default_param = full_param_.default_param; + if (default_param.state_outputs) { + if (kAddTo == req[rnn_enum::kStateOut]) + LOG(FATAL) << "Currently, `add` operation is not supported by RNNs."; + if (default_param.mode == rnn_enum::kLstm && kAddTo == req[rnn_enum::kStateCellOut]) + LOG(FATAL) << "Currently, `add` operation against lstm-cell output is not supported."; + } + + // Initialize weights version + if (!initialized_ && weights_version_ == 0) { + weights_version_ = inputs[rnn_enum::kParams].version(); + } + + // Check if weights NDArray was changed. If so, reset initialized_ + if (weights_version_ != inputs[rnn_enum::kParams].version() && + fwd_inf_vec_.size() > 0) { + initialized_ = false; + for (auto& fwd : fwd_inf_vec_) fwd.Reset(); + weights_version_ = inputs[rnn_enum::kParams].version(); + } + + if (!initialized_ || ctx.is_train || fwd_trn_vec_.size() == 0) { + Init(ctx, inputs, req, outputs); + } + + // Get data type + int data_dtype = inputs[rnn_enum::kData].dtype(); + + // Get input & output NDArray + char *src = static_cast(inputs[rnn_enum::kData].data().dptr_); + char *src_state = static_cast(inputs[rnn_enum::kState].data().dptr_); + char *dst = req[rnn_enum::kOut] == kNullOp ? 
nullptr + : static_cast(outputs[rnn_enum::kOut].data().dptr_); + char *dst_state = nullptr; // Output state + char *src_state_cell = nullptr; // Used in LSTM for cell state + char *dst_state_cell = nullptr; // Used in LSTM for cell state + + if (default_param.state_outputs && req[rnn_enum::kStateOut] != kNullOp) { + dst_state = static_cast(outputs[rnn_enum::kStateOut].data().dptr_); + } + + if (default_param.mode == rnn_enum::kLstm) { + src_state_cell = static_cast(inputs[rnn_enum::kStateCell].data().dptr_); + if (default_param.state_outputs && req[rnn_enum::kStateCellOut] != kNullOp) { + dst_state_cell = static_cast(outputs[rnn_enum::kStateCellOut].data().dptr_); + } + } + + if (fwd_inf_vec_.size() == 1) { + fwd_inf_vec_.front().SetNewDataMem(src, src_state, src_state_cell, + dst, dst_state, dst_state_cell, data_dtype); + if (ctx.is_train) { + fwd_trn_vec_.front().FetchData(fwd_inf_vec_.front()); + } + } else { + CHECK_EQ(fwd_inf_vec_.size(), dst_.size() + 1) << "Output memory error."; + size_t cell_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ * + default_param.state_size * mshadow::mshadow_sizeof(data_dtype); + + // Set input data memory for the first layer. This stores intermediate output + // results in this->xxx, used as the source input of the next layer. + fwd_inf_vec_.front().SetNewDataMem(src, src_state, src_state_cell, + this->dst_.front()->get_data_handle(), dst_state, dst_state_cell, data_dtype); + if (ctx.is_train) { + fwd_trn_vec_.front().FetchData(fwd_inf_vec_.front()); + } + // 1st_lyr -> dst_handle -> next_lyr -> dst_handle -> next_lyr -> ... + for (size_t lyr = 1; lyr < fwd_inf_vec_.size() - 1; ++lyr) { + src_state += cell_bytes; + if (src_state_cell) src_state_cell += cell_bytes; + if (dst_state) dst_state += cell_bytes; + if (dst_state_cell) dst_state_cell += cell_bytes; + fwd_inf_vec_.at(lyr).SetNewDataMem(this->dst_.at(lyr - 1)->get_data_handle(), + src_state, src_state_cell, + this->dst_.at(lyr)->get_data_handle(), dst_state, dst_state_cell, data_dtype); + if (ctx.is_train) { + fwd_trn_vec_.at(lyr).FetchData(fwd_inf_vec_.at(lyr)); + } + } + // Set output data memory for the last layer. 
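+    // Each fused layer consumes one cell_bytes slice of the stacked hx/cx
+    // (and writes one slice of hy/cy), so the state pointers advance once
+    // more before the last layer binds the user-visible output buffers.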
+ src_state += cell_bytes; + if (src_state_cell) src_state_cell += cell_bytes; + if (dst_state) dst_state += cell_bytes; + if (dst_state_cell) dst_state_cell += cell_bytes; + fwd_inf_vec_.back().SetNewDataMem(this->dst_.back()->get_data_handle(), + src_state, src_state_cell, dst, dst_state, dst_state_cell, data_dtype); + if (ctx.is_train) { + fwd_trn_vec_.back().FetchData(fwd_inf_vec_.back()); + } + } + if (ctx.is_train) { + for (auto& trn_lyr : fwd_trn_vec_) RegisterMKLDNNRnn(trn_lyr); + } else { + for (auto& inf_lyr : fwd_inf_vec_) RegisterMKLDNNRnn(inf_lyr); + } + MKLDNNStream::Get()->Submit(); +} + +void MKLDNNRnnOp::Backward(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using tag = mkldnn::memory::format_tag; + const RNNParam& default_param = full_param_.default_param; + if (kAddTo == req[rnn_enum::kData] || kAddTo == req[rnn_enum::kParams]) + LOG(FATAL) << "Currently, `add` operations against gradients of input and weights" + << " are not supported by RNNs."; + if (default_param.state_outputs) { + if (kAddTo == req[rnn_enum::kStateOut]) + LOG(FATAL) << "Currently, `add` operation against gradients of begining state" + << " is not supported by RNNs."; + if (default_param.mode == rnn_enum::kLstm && req[rnn_enum::kStateCell]) + LOG(FATAL) << "Currently, `add` operation against gradients of begining cell-state" + << " is not supported by LSTM."; + } + // Initialize the bwd_vec_ + if (bwd_vec_.size() != fwd_inf_vec_.size()) { + bwd_vec_.clear(); + for (size_t lyr = 0; lyr < fwd_inf_vec_.size(); ++lyr) + bwd_vec_.emplace_back(fwd_trn_vec_.at(lyr), inputs[rnn_enum::kData], + inputs[rnn_enum::kParams]); + } + // Fetch weights, src and dst from Forward layer + if (bwd_vec_.size() != fwd_trn_vec_.size()) + LOG(FATAL) << "MKL-DNN RNN fusion error."; + for (size_t lyr = 0; lyr < bwd_vec_.size(); ++lyr) { + bwd_vec_.at(lyr).FetchDataWeightsMem(fwd_trn_vec_.at(lyr)); + bwd_vec_.at(lyr).SetWeightsGradsMem(); + } + + const int data_dtype = inputs[rnn_enum::kData].dtype(); + const int w_dtype = inputs[rnn_enum::kParams].dtype(); + const size_t w_bytes = mshadow::mshadow_sizeof(w_dtype); + // index description of outputs NDArray + // 0 1 2 3 + // | dx | dw | dhx | dcx| + char* dx = req[rnn_enum::kData] == kNullOp ? nullptr + : static_cast(outputs[rnn_enum::kData].data().dptr_); + char* dw = static_cast(outputs[rnn_enum::kParams].data().dptr_); + char* db = dw + (inputs[rnn_enum::kParams].data().Size() - + GetRnnBiasSize(default_param.num_layers, default_param.state_size, + default_param.bidirectional + 1, default_param.mode)) * w_bytes; + char* dhx = req[rnn_enum::kState] == kNullOp ? 
nullptr + : static_cast(outputs[rnn_enum::kState].data().dptr_); + char* dcx = nullptr; + if (full_param_.default_param.mode == rnn_enum::kLstm + && req[rnn_enum::kStateCell] != kNullOp) + dcx = static_cast(outputs[rnn_enum::kStateCell].data().dptr_); + + // index description of inputs NDArray + // 0 1 2 3 4 5 6 7 8 9 + // | x | w | hx | y | dy | hy | dhy | cx | cy | dcy | + char* dy = static_cast(inputs[4].data().dptr_); + char* dhy = nullptr; + if (default_param.state_outputs) + dhy = static_cast(inputs[6].data().dptr_); + + char* dcy = nullptr; + if ((default_param.mode == rnn_enum::kLstm) && default_param.state_outputs) + dcy = static_cast(inputs[9].data().dptr_); + + if (bwd_vec_.size() == 1) { + bwd_vec_.back().SetDataGradsMem(dx, dhx, dcx, dy, dhy, dcy, data_dtype); + RegisterMKLDNNRnn(bwd_vec_.back()); + } else { + const size_t cell_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ * + default_param.state_size * mshadow::mshadow_sizeof(data_dtype); + if (diff_src == nullptr) { + auto desc = mkldnn::memory::desc(full_param_.layer_params.back().src_dims, + get_mkldnn_type(data_dtype), tag::tnc); + diff_src = std::make_shared(desc, CpuEngine::Get()->get_engine()); + } + // Sets primitives from bottom to top, then submits them in reversed order. + bwd_vec_.front().SetDataGradsMem(dx, dhx, dcx, + diff_src->get_data_handle(), dhy, dcy, data_dtype); + for (size_t lyr = 1; lyr < bwd_vec_.size() - 1; ++lyr) { + if (dhx) dhx += cell_bytes; + if (dcx) dcx += cell_bytes; + if (dhy) dhy += cell_bytes; + if (dcy) dcy += cell_bytes; + bwd_vec_.at(lyr).SetDataGradsMem(diff_src->get_data_handle(), dhx, dcx, + diff_src->get_data_handle(), dhy, dcy, data_dtype); + } + if (dhx) dhx += cell_bytes; + if (dcx) dcx += cell_bytes; + if (dhy) dhy += cell_bytes; + if (dcy) dcy += cell_bytes; + bwd_vec_.back().SetDataGradsMem(diff_src->get_data_handle(), dhx, dcx, + dy, dhy, dcy, data_dtype); + + for (std::vector::const_reverse_iterator bwd = bwd_vec_.rbegin(); + bwd < bwd_vec_.rend(); ++bwd) { + RegisterMKLDNNRnn(*bwd); + } + } + MKLDNNStream::Get()->Submit(); + + // Commit weights diff + if (req[rnn_enum::kParams] != kNullOp) { + for (size_t lyr = 0; lyr < bwd_vec_.size(); ++lyr) { + bwd_vec_.at(lyr).CommitWeightsDiff(dw, db, w_dtype); + dw += full_param_.layer_params.at(lyr).single_w_size * w_bytes; + db += full_param_.layer_params.at(lyr).single_b_size * w_bytes; + } + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 100 diff --git a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h deleted file mode 100644 index ea8e07ea617c..000000000000 --- a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h +++ /dev/null @@ -1,740 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_ -#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_ -#if MXNET_USE_MKLDNN == 1 -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "../../math_functions-inl.h" -#include "../../operator_common.h" -#include "../../rnn_impl.h" -#include "../../rnn-inl.h" -#include "mkldnn.hpp" -#include "./mkldnn_base-inl.h" - -namespace mxnet { -namespace op { - -static algorithm GetMKLDNNRNNAlgo(int mode, - int* ngates, - int* nstates) { - algorithm algo = algorithm::vanilla_rnn; - switch (mode) { - case rnn_enum::kLstm: - *ngates = 4; - *nstates = 2; - algo = algorithm::vanilla_lstm; - break; - case rnn_enum::kGru: - *ngates = 3; - *nstates = 1; - algo = algorithm::vanilla_gru; - break; - case rnn_enum::kRnnRelu: - case rnn_enum::kRnnTanh: - *ngates = 1; - *nstates = 1; - algo = algorithm::vanilla_rnn; - break; - default: - LOG(FATAL) << "unsupported RNN mode:" << mode; - break; - } - return algo; -} - -static void ConcatData(mkldnn::memory::format src_format, - mkldnn::memory::format dst_format, - std::vector srcs_cds, - mkldnn::memory::dims dst_cds, - mkldnn::memory::data_type mkldnn_dtype, - int concat_dimension, - std::vector srcs_data, - const mkldnn::memory &dst) { - auto cpu_engine = CpuEngine::Get()->get_engine(); - std::vector srcs_pd; - std::vector srcs; - for (size_t i = 0; i < srcs_cds.size(); i++) { - auto desc = mkldnn::memory::desc(srcs_cds[i], mkldnn_dtype, src_format); - auto mpd = mkldnn::memory::primitive_desc(desc, cpu_engine); - auto src_memory = mkldnn::memory(mpd, srcs_data[i]); - srcs_pd.push_back(mpd); - srcs.push_back(src_memory); - } - std::vector inputs; - for (size_t i = 0; i < srcs_cds.size(); i++) { - inputs.push_back(srcs[i]); - } - auto dst_desc = mkldnn::memory::desc(dst_cds, mkldnn_dtype, dst_format); - auto concat_pd = concat::primitive_desc(dst_desc, concat_dimension, srcs_pd); - MKLDNNStream::Get()->RegisterPrim(concat(concat_pd, inputs, dst)); - MKLDNNStream::Get()->Submit(); -} - -// cached mkldnn memory -// first layer wx, wh with next L - 1 layers wx and wh -// with L layers hx and cx, src and dst data/iter etc. -// it will prepare memory on before and after reorder and concat. -// for unidirectional, it will fused as dim like 1 + (L - 1) when I != H. 
-// for bidirectional, it will fused as data + back_data (weight, bias, iter etc), -// also need to identify first layer and next layers -static size_t GetMKLDNNRNNCacheMemorySize(int L, - int D, - int T, - int N, - int I, - int H, - int mode) { - size_t size = 0; - switch (mode) { - case rnn_enum::kLstm: - size = 2 * (D * (I + H) * 4 * H + (L - 1) * D * (D * H + H) * 4 * H + - L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 4 * H + (L + 2) * D * 2 * N * H + - 6 * D * (I + H + 2) * 4 * H + T * N * I * 2; - break; - case rnn_enum::kGru: - size = 2 * (D * (I + H) * 3 * H + (L - 1) * D * (D * H + H) * 3 * H + - L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 3 * H + (L + 2) * D * 2 * N * H + - 6 * D * (I + H + 2) * 3 * H + T * N * I * 2; - break; - case rnn_enum::kRnnRelu: - case rnn_enum::kRnnTanh: - size = 2 * (D * (I + H) * 1 * H + (L - 1) * D * (D * H + H) * 1 * H + - L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 1 * H + (L + 2) * D * 2 * N * H + - 6 * D * (I + H + 2) * 1 * H + T * N * I * 2; - break; - default: - LOG(FATAL) << "unknown RNN mode " << mode; - break; - } - return size; -} - -template -static void AdjustGruWeightGateOrder(DType* weight, - const int I, - const int H) { - // mxnet gru gate order is reset, update and new gates - // mkldnn gru gate order is update, reset and new gates - const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); - DType* weight_reset = weight; - DType* weight_update = weight + I * H; - #pragma omp parallel for num_threads(omp_threads) - for (int i = 0; i < I * H; i++) { - DType tmp = weight_update[i]; - weight_update[i] = weight_reset[i]; - weight_reset[i] = tmp; - } -} - -template -static void AdjustGruBiasGateOrder(DType* bias, - const int H) { - // mxnet gru gate order is reset, update and new gates - // mkldnn gru gate order is update, reset and new gates - const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); - DType* bias_reset = bias; - DType* bias_update = bias + H; - #pragma omp parallel for num_threads(omp_threads) - for (int i = 0; i < H; i++) { - DType tmp = bias_update[i]; - bias_update[i] = bias_reset[i]; - bias_reset[i] = tmp; - } -} -// since there is different sematics of MKLDNN's Fused RNN and MXNet FusedRNN, -// bidirectional will be fused layer by layer, -// unidirectional will be done by fused 1 + fused (L - 1) layers or fused L layers(when I = H) - -template -static void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, - const int T, - const int N, - const int I, - const int H, - DType* x_ptr, - mkldnn::memory *user_src_layer_memory, - DType* hx_ptr, - DType* cx_ptr, - DType* w_ptr, - DType* b_ptr, - DType* y_ptr, - DType* hy_ptr, - DType* cy_ptr, - std::vector *concat_weight_memory, - std::vector *concat_iter_memory, - std::vector *x_memory, - std::vector *hcx_memory, - std::vector *wx_memory, - std::vector *wh_memory, - std::vector *bias_memory, - std::vector *y_memory, - std::vector *hcy_memory, - std::vector *rnn_forward_prim, - int layer_index, - bool *has_cache, - int lvalue, - int dtype, - bool is_train, - int mode) { - int ngates = 0, nstates = 0; - algorithm nalgorithm = GetMKLDNNRNNAlgo(mode, &ngates, &nstates); - mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); - const int single_cell_size = N * H; - const int single_b_size = ngates * H; - DType* wx = w_ptr; // ngates * H, I - DType* wh = w_ptr + I * H * ngates; // ngates * H, H - DType* back_wx = w_ptr + ngates * H * (I + H); - DType* back_wh = back_wx + I * H * ngates; - DType* bx = 
b_ptr; - DType* bh = b_ptr + H * ngates; - DType* back_bx = b_ptr + single_b_size * 2; - DType* back_bh = back_bx + H * ngates; - const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); - auto cpu_engine = CpuEngine::Get()->get_engine(); - auto null_memory_ = null_memory(cpu_engine); - int offset1 = 0, offset2 = 0; - bool initialized = *has_cache; - mkldnn::memory::dims src_layer_tz = {T, N, I}; - mkldnn::memory::dims dst_layer_tz = {T, N, 2 * H}; - mkldnn::memory::dims weights_layer_tz = {1, 2, I, ngates, H}; // ldigo - mkldnn::memory::dims weights_layer_r_tz = {1, 1, I, ngates, H}; // ldigo for reorder - mkldnn::memory::dims weights_iter_tz = {1, 2, H, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_r_tz = {1, 1, H, ngates, H}; // ldigo for reorder - mkldnn::memory::dims bias_tz = {1, 2, ngates, H}; - mkldnn::memory::dims src_iter_tz = {1, 2, nstates, N, H}; // ldsnc - mkldnn::memory::dims dst_iter_tz = {1, 2, nstates, N, H}; // ldsnc - - if (!initialized) { - if (mode == rnn_enum::kGru) { - AdjustGruWeightGateOrder(wx, I, H); - AdjustGruWeightGateOrder(back_wx, I, H); - AdjustGruWeightGateOrder(wh, H, H); - AdjustGruWeightGateOrder(back_wh, H, H); - AdjustGruBiasGateOrder(bx, H); - AdjustGruBiasGateOrder(back_bx, H); - AdjustGruBiasGateOrder(bh, H); - AdjustGruBiasGateOrder(back_bh, H); - } - auto src_wx = (*concat_weight_memory)[2 * layer_index]; - auto src_wh = (*concat_weight_memory)[2 * layer_index + 1]; - std::vector srcs_data1; - srcs_data1.push_back(wx); - srcs_data1.push_back(back_wx); - ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, - {weights_layer_r_tz, weights_layer_r_tz}, weights_layer_tz, - mkldnn_dtype, 1, srcs_data1, src_wx); - srcs_data1.clear(); - srcs_data1.push_back(wh); - srcs_data1.push_back(back_wh); - ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, - {weights_iter_r_tz, weights_iter_r_tz}, weights_iter_tz, - mkldnn_dtype, 1, srcs_data1, src_wh); - int tmpvalue = 0; - if (lvalue > 0) { - tmpvalue = lvalue + 1; - } - MKLDNNStream::Get()->RegisterPrim(reorder(src_wx, (*wx_memory)[tmpvalue])); - MKLDNNStream::Get()->RegisterPrim(reorder(src_wh, (*wh_memory)[tmpvalue])); - - DType* user_bias = reinterpret_cast - ((*bias_memory)[tmpvalue].get_data_handle()); - #pragma omp parallel for num_threads(omp_threads) - for (int j = 0; j < single_b_size; j++) { - user_bias[j] = bx[j] + bh[j]; - user_bias[single_b_size + j] = back_bx[j] + back_bh[j]; - } - } - if (lvalue > 0) { - (*wx_memory)[layer_index].set_data_handle((*wx_memory)[lvalue + 1].get_data_handle()); - (*wh_memory)[layer_index].set_data_handle((*wh_memory)[lvalue + 1].get_data_handle()); - (*bias_memory)[layer_index].set_data_handle((*bias_memory)[lvalue + 1].get_data_handle()); - } - - auto src_layer_md = mkldnn::memory::desc( - { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto weight_layer_md = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto weight_iter_md = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto dst_layer_md = mkldnn::memory::desc( - { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto dst_iter_md = mkldnn::memory::desc( - { dst_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - auto src_iter_md = mkldnn::memory::desc( - {src_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc); - auto bias_md = mkldnn::memory::desc({bias_tz}, - mkldnn_dtype, 
mkldnn::memory::format::ldgo); - - auto user_src_iter_memory = (*concat_iter_memory)[2]; - if (mode == rnn_enum::kLstm) { - std::vector srcs_data1; - srcs_data1.push_back(hx_ptr); - srcs_data1.push_back(cx_ptr); - auto tmp1_src_iter_memory = (*concat_iter_memory)[0]; - ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, - {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, 2, - srcs_data1, tmp1_src_iter_memory); - std::vector srcs_data2; - srcs_data2.push_back(hx_ptr + single_cell_size); - srcs_data2.push_back(cx_ptr + single_cell_size); - auto tmp2_src_iter_memory = (*concat_iter_memory)[1]; - ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, - {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, 2, - srcs_data2, tmp2_src_iter_memory); - std::vector srcs_data3; - srcs_data3.push_back(reinterpret_cast(tmp1_src_iter_memory.get_data_handle())); - srcs_data3.push_back(reinterpret_cast(tmp2_src_iter_memory.get_data_handle())); - ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, - {{1, 1, nstates, N, H}, {1, 1, nstates, N, H}}, {1, 2, nstates, N, H}, - mkldnn_dtype, 1, srcs_data3, user_src_iter_memory); - } else { - user_src_iter_memory.set_data_handle(hx_ptr); - } - (*hcx_memory)[layer_index].set_data_handle(user_src_iter_memory.get_data_handle()); - - rnn_cell::desc rnn_cell(nalgorithm, - mode == rnn_enum::kRnnRelu ? algorithm::eltwise_relu : algorithm::eltwise_tanh); - - rnn_forward::desc layer_desc(prop_kind::forward_inference, rnn_cell, - rnn_direction::bidirectional_concat, src_layer_md, - src_iter_md, weight_layer_md, weight_iter_md, - bias_md, dst_layer_md, dst_iter_md); - - auto prim_desc - = rnn_forward::primitive_desc(layer_desc, cpu_engine); - - if (x_ptr && layer_index == 0) { - (*x_memory)[layer_index].set_data_handle(x_ptr); - } else { - (*x_memory)[layer_index].set_data_handle((*user_src_layer_memory).get_data_handle()); - } - (*y_memory)[layer_index].set_data_handle(y_ptr); - - if (rnn_forward_prim->size() <= (size_t)layer_index) { - primitive rnn_prim = rnn_forward(prim_desc, (*x_memory)[layer_index], - (*hcx_memory)[layer_index], (*wx_memory)[layer_index], - (*wh_memory)[layer_index], (*bias_memory)[layer_index], - (*y_memory)[layer_index], - (*hcy_memory)[layer_index], null_memory_); - rnn_forward_prim->push_back(rnn_prim); - } - MKLDNNStream::Get()->RegisterPrim((*rnn_forward_prim)[layer_index]); - MKLDNNStream::Get()->Submit(); - - if (state_outputs) { - DType* dst_hcy = reinterpret_cast ((*hcy_memory)[layer_index].get_data_handle()); - if (mode == rnn_enum::kLstm) { - offset1 = nstates * single_cell_size; - offset2 = (nstates + 1) * single_cell_size; - #pragma omp parallel for num_threads(omp_threads) - for (int n = 0; n < single_cell_size; n++) { - hy_ptr[n] = dst_hcy[n]; - hy_ptr[n + single_cell_size] = dst_hcy[n + offset1]; - cy_ptr[n] = dst_hcy[n + single_cell_size]; - cy_ptr[n + single_cell_size] = dst_hcy[n + offset2]; - } - } else { - #pragma omp parallel for num_threads(omp_threads) - for (int n = 0; n < 2 * single_cell_size; n++) { - hy_ptr[n] = dst_hcy[n]; - } - } - } -} - - -template -static void MKLDNNRNNForwardUnidi(bool state_outputs, - const int L, - const int T, - const int N, - const int I, - const int H, - DType* x_ptr, - mkldnn::memory *user_src_layer_memory, - DType* hx_ptr, - DType* cx_ptr, - DType* w_ptr, - DType* b_ptr, - DType* y_ptr, - DType* hy_ptr, - DType* cy_ptr, - std::vector *concat_weight_memory, - std::vector *concat_iter_memory, - 
std::vector *x_memory, - std::vector *hcx_memory, - std::vector *wx_memory, - std::vector *wh_memory, - std::vector *bias_memory, - std::vector *y_memory, - std::vector *hcy_memory, - std::vector *rnn_forward_prim, - int layer_index, - bool *has_cache, - int dtype, - bool is_train, - int mode) { - int ngates = 0, nstates = 0; - algorithm nalgorithm = GetMKLDNNRNNAlgo(mode, &ngates, &nstates); - mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); - const int cell_size = N * H; - const int single_cell_size = N * H; - const int single_b_size = ngates * H; - int w_size = (I + H) * H * ngates; - const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); - auto cpu_engine = CpuEngine::Get()->get_engine(); - auto null_memory_ = null_memory(cpu_engine); - int offset1 = 0, offset2 = 0; - bool initialized = *has_cache; - - mkldnn::memory::dims src_layer_tz = {T, N, I}; - mkldnn::memory::dims dst_layer_tz = {T, N, H}; - mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz = {L, 1, ngates, H}; - mkldnn::memory::dims src_iter_tz = {L, 1, nstates, N, H}; // ldsnc - mkldnn::memory::dims dst_iter_tz = {L, 1, nstates, N, H}; // ldsnc - mkldnn::memory::dims weights_layer_r_tz = {1, 1, I, ngates, H}; // ldigo for reorder - mkldnn::memory::dims weights_iter_r_tz = {1, 1, H, ngates, H}; // ldigo for reorder - - auto weight_layer_md = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto weight_iter_md = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto src_layer_md = mkldnn::memory::desc( - { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto dst_layer_md = mkldnn::memory::desc( - {dst_layer_tz}, mkldnn_dtype, mkldnn::memory::format::tnc); - auto src_iter_md = mkldnn::memory::desc( - {src_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc); - auto bias_md = mkldnn::memory::desc({bias_tz}, - mkldnn_dtype, mkldnn::memory::format::ldgo); - auto dst_iter_md = mkldnn::memory::desc( - {dst_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc); - - for (int l = 0; l < L; l++) { - if (mode == rnn_enum::kLstm) { - std::vector srcs_data; - srcs_data.push_back(hx_ptr); - srcs_data.push_back(cx_ptr); - auto tmp_src_iter_memory = (*concat_iter_memory)[l + layer_index]; - ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, - {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, - 2, srcs_data, tmp_src_iter_memory); - } else { - (*concat_iter_memory)[l + layer_index].set_data_handle(hx_ptr); - } - hx_ptr += cell_size; - if (mode == rnn_enum::kLstm) { - cx_ptr += cell_size; - } - } - - auto user_src_iter_memory = null_memory_; - if (L == 1) { - user_src_iter_memory = (*concat_iter_memory)[layer_index]; - } else { - user_src_iter_memory = (*concat_iter_memory)[L + layer_index]; - std::vector src_l_data; - std::vector src_l_dim; - for (int l = 0; l < L; l++) { - src_l_data.push_back(reinterpret_cast - ((*concat_iter_memory)[l + layer_index].get_data_handle())); - src_l_dim.push_back({1, 1, nstates, N, H}); - } - ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, src_l_dim, - {L, 1, nstates, N, H}, mkldnn_dtype, 0, src_l_data, user_src_iter_memory); - } - (*hcx_memory)[layer_index].set_data_handle(user_src_iter_memory.get_data_handle()); - - auto src_wx_f = (*concat_weight_memory)[2 * layer_index]; 
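/* Editor's sketch (not part of the original patch): the src_iter blob assembled above uses
 * MKL-DNN's ldsnc layout, i.e. dims {layers, directions, states, batch, hidden}. For LSTM
 * (nstates == 2) the hidden state h sits at state index 0 and the cell state c at index 1,
 * so a hypothetical helper that locates one element in a unidirectional blob would be:
 *
 *   inline size_t LdsncOffset(int l, int s, int n, int h,
 *                             int nstates, int N, int H) {
 *     return ((static_cast<size_t>(l) * nstates + s) * N + n) * H + h;
 *   }
 *
 * which mirrors how ConcatData packs hx_ptr/cx_ptr layer by layer into the memory whose
 * handle hcx_memory picks up just above. */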
- auto src_wh_f = (*concat_weight_memory)[2 * layer_index + 1]; - - std::vector srcs_data_x; - std::vector srcs_data_h; - std::vector src_l_dim_x; - std::vector src_l_dim_h; - if (!initialized) { - if (L == 1) { - DType* wx = w_ptr; - DType* wh = w_ptr + I * H * ngates; - if (mode == rnn_enum::kGru) { - AdjustGruWeightGateOrder(wx, I, H); - AdjustGruWeightGateOrder(wh, H, H); - AdjustGruBiasGateOrder(b_ptr, H); - AdjustGruBiasGateOrder(b_ptr + H * ngates, H); - } - src_wx_f.set_data_handle(wx); - src_wh_f.set_data_handle(wh); - } else { - for (int l = 0; l < L; l++) { - DType* wx = w_ptr; - DType* wh = w_ptr + I * H * ngates; - DType* bx = b_ptr + l * ngates * H * 2; - DType* bh = b_ptr + l * ngates * H * 2 + H * ngates; - if (mode == rnn_enum::kGru) { - AdjustGruWeightGateOrder(wx, I, H); - AdjustGruWeightGateOrder(wh, H, H); - AdjustGruBiasGateOrder(bx, H); - AdjustGruBiasGateOrder(bh, H); - } - srcs_data_x.push_back(wx); - srcs_data_h.push_back(wh); - src_l_dim_x.push_back(weights_layer_r_tz); - src_l_dim_h.push_back(weights_iter_r_tz); - w_ptr = w_ptr + w_size; - } - ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, - src_l_dim_x, weights_layer_tz, mkldnn_dtype, 0, srcs_data_x, src_wx_f); - ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, - src_l_dim_h, weights_iter_tz, mkldnn_dtype, 0, srcs_data_h, src_wh_f); - } - MKLDNNStream::Get()->RegisterPrim(reorder(src_wx_f, (*wx_memory)[layer_index])); - MKLDNNStream::Get()->RegisterPrim(reorder(src_wh_f, (*wh_memory)[layer_index])); - - DType* user_bias_f = reinterpret_cast ((*bias_memory)[layer_index].get_data_handle()); - #pragma omp parallel for num_threads(omp_threads) - for (int j = 0; j < L * single_b_size; j++) { - int k = j / single_b_size; - user_bias_f[j] = b_ptr[j + k * single_b_size] + b_ptr[j + k * single_b_size + single_b_size]; - } - } - - rnn_cell::desc rnn_cell(nalgorithm, - mode == rnn_enum::kRnnRelu ? 
algorithm::eltwise_relu : algorithm::eltwise_tanh); - - rnn_forward::desc layer_desc(prop_kind::forward_inference, rnn_cell, - rnn_direction::unidirectional, src_layer_md, - src_iter_md, weight_layer_md, weight_iter_md, - bias_md, dst_layer_md, dst_iter_md); - - auto prim_desc - = rnn_forward::primitive_desc(layer_desc, cpu_engine); - - if (x_ptr && layer_index == 0) { - (*x_memory)[layer_index].set_data_handle(x_ptr); - } else { - (*x_memory)[layer_index].set_data_handle((*user_src_layer_memory).get_data_handle()); - } - (*y_memory)[layer_index].set_data_handle(y_ptr); - - if (rnn_forward_prim->size() <= (size_t)layer_index) { - primitive rnn_prim = rnn_forward(prim_desc, (*x_memory)[layer_index], - (*hcx_memory)[layer_index], (*wx_memory)[layer_index], - (*wh_memory)[layer_index], (*bias_memory)[layer_index], - (*y_memory)[layer_index], - (*hcy_memory)[layer_index], null_memory_); - rnn_forward_prim->push_back(rnn_prim); - } - MKLDNNStream::Get()->RegisterPrim((*rnn_forward_prim)[layer_index]); - MKLDNNStream::Get()->Submit(); - - if (state_outputs) { - DType* dst_hcy = reinterpret_cast ((*hcy_memory)[layer_index].get_data_handle()); - if (mode == rnn_enum::kLstm) { - for (int l = 0; l < L; l++) { - offset1 = l * single_cell_size; - offset2 = l * nstates * single_cell_size; - #pragma omp parallel for num_threads(omp_threads) - for (int n = 0; n < single_cell_size; n++) { - hy_ptr[offset1 + n] = dst_hcy[offset2 + n]; - cy_ptr[offset1 + n] = dst_hcy[offset2 + n + single_cell_size]; - } - } - } else { - #pragma omp parallel for num_threads(omp_threads) - for (int n = 0; n < L * single_cell_size; n++) { - hy_ptr[n] = dst_hcy[n]; - } - } - } -} - -template -static void MKLDNNRNNForward(bool state_outputs, - const int L, - const int D, - const int T, - const int N, - const int I, - const int H, - DType* x_ptr, - DType* hx_ptr, - DType* cx_ptr, - DType* w_ptr, - DType* b_ptr, - DType* y_ptr, - DType* hy_ptr, - DType* cy_ptr, - std::vector *concat_weight_memory, - std::vector *concat_iter_memory, - std::vector *x_memory, - std::vector *hcx_memory, - std::vector *wx_memory, - std::vector *wh_memory, - std::vector *bias_memory, - std::vector *y_memory, - std::vector *hcy_memory, - std::vector *rnn_forward_prim, - bool *has_cache, - int dtype, - bool is_train, - int mode) { - int ngates = 0, nstates = 0; - GetMKLDNNRNNAlgo(mode, &ngates, &nstates); - const int b_size = 2 * H * ngates * D; - const int cell_size = N * H * D; - // First layer - int w_size = (I + H) * H * ngates * D; - auto cpu_engine = CpuEngine::Get()->get_engine(); - auto null_memory_ = null_memory(cpu_engine); - DType* tmpNull = NULL; - // when D = 1 and I == H, L layers can be fused together - if (D == 1 && I == H) { - MKLDNNRNNForwardUnidi(state_outputs, L, T, N, I, H, x_ptr, &null_memory_, - hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, - concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, - bias_memory, y_memory, hcy_memory, rnn_forward_prim, - 0, has_cache, dtype, is_train, mode); - } else { - auto user_src_layer_memory_l = null_memory_; - if (D == 2) { - MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, I, H, x_ptr, &user_src_layer_memory_l, - hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, - concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, - bias_memory, y_memory, hcy_memory, rnn_forward_prim, - 0, has_cache, 0, dtype, is_train, mode); - } else { - MKLDNNRNNForwardUnidi(state_outputs, 1, T, N, I, H, x_ptr, &user_src_layer_memory_l, - hx_ptr, 
cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, - concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, - bias_memory, y_memory, hcy_memory, rnn_forward_prim, - 0, has_cache, dtype, is_train, mode); - } - if (L > 1) { - user_src_layer_memory_l = (*y_memory)[0]; - // go to next L - 1 layers. - // If D = 2, do it layer by layer. If D = 1, fused L - 1 layers - w_ptr += w_size; - b_ptr += b_size; - if (D == 2) { - w_size = (H * D + H) * H * ngates * D; - for (int l = 0; l < L - 1; l++) { - if (state_outputs) { - hy_ptr += cell_size; - if (mode == rnn_enum::kLstm) { - cy_ptr += cell_size; - } - } - hx_ptr += cell_size; - if (mode == rnn_enum::kLstm) { - cx_ptr += cell_size; - } - MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, D * H, H, tmpNull, - &user_src_layer_memory_l, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, - cy_ptr, concat_weight_memory, concat_iter_memory, x_memory, - hcx_memory, wx_memory, wh_memory, bias_memory, - y_memory, hcy_memory, rnn_forward_prim, - 1, has_cache, l + 1, dtype, is_train, mode); - user_src_layer_memory_l = (*y_memory)[1]; - w_ptr += w_size; - b_ptr += b_size; - } - } - if (D == 1) { - if (state_outputs) { - hy_ptr += cell_size; - if (mode == rnn_enum::kLstm) { - cy_ptr += cell_size; - } - } - w_size = (H + H) * H * ngates; - MKLDNNRNNForwardUnidi(state_outputs, L - 1, T, N, H, H, tmpNull, &user_src_layer_memory_l, - hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, - concat_iter_memory, x_memory, hcx_memory, wx_memory, - wh_memory, bias_memory, y_memory, hcy_memory, - rnn_forward_prim, 1, has_cache, dtype, is_train, mode); - } - } - } - *has_cache = true; -} - -template -static void MKLDNNRNNForwardInference(bool state_outputs, - const int num_layers, - const int direction, - const int seq_length, - const int batch_size, - const int input_size, - const int state_size, - DType* x_ptr, - DType* hx_ptr, - DType* cx_ptr, - DType* w_ptr, - DType* b_ptr, - DType* y_ptr, - DType* hy_ptr, - DType* cy_ptr, - std::vector* concat_weight_memory, - std::vector* concat_iter_memory, - std::vector *x_memory, - std::vector *hcx_memory, - std::vector *wx_memory, - std::vector *wh_memory, - std::vector *bias_memory, - std::vector *y_memory, - std::vector *hcy_memory, - std::vector *rnn_forward_prim, - bool *has_cache, - int dtype, - bool is_train, - int mode) { - switch (mode) { - case rnn_enum::kLstm: - case rnn_enum::kGru: - case rnn_enum::kRnnTanh: - case rnn_enum::kRnnRelu: - MKLDNNRNNForward(state_outputs, num_layers, direction, seq_length, - batch_size, input_size, state_size, x_ptr, hx_ptr, - cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, - concat_weight_memory, concat_iter_memory, x_memory, - hcx_memory, wx_memory, wh_memory, - bias_memory, y_memory, hcy_memory, rnn_forward_prim, - has_cache, dtype, is_train, mode); - break; - default: - LOG(FATAL) << "unknown RNN mode" << mode; - break; - } -} - -} // namespace op -} // namespace mxnet -#endif // MXNET_USE_MKLDNN == 1 -#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_ diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index cf6fe10fd328..fe1488c2ad88 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -26,11 +26,6 @@ #ifndef MXNET_OPERATOR_RNN_INL_H_ #define MXNET_OPERATOR_RNN_INL_H_ -#if MXNET_USE_CUDNN == 1 -STATIC_ASSERT_CUDNN_VERSION_GE(7000); -#endif -#define MXNET_USE_CUDNN_GE_7200 MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200 - #include #include #include @@ -46,13 +41,87 @@ STATIC_ASSERT_CUDNN_VERSION_GE(7000); #include 
"./math_functions-inl.h" #include "./operator_common.h" #include "./rnn_impl.h" -#if MXNET_USE_MKLDNN == 1 -#include "./nn/mkldnn/mkldnn_rnn_impl.h" + +#if MXNET_USE_CUDNN == 1 +STATIC_ASSERT_CUDNN_VERSION_GE(7000); #endif +#define MXNET_USE_CUDNN_GE_7200 MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200 namespace mxnet { namespace op { +namespace rnn_enum { + enum RNNOpInputs {kData, kParams, kState, kStateCell, kSequenceLength}; + enum RNNOpOutputs {kOut, kStateOut, kStateCellOut}; + enum RNNModeType {kRnnRelu, kRnnTanh, kLstm, kGru}; + enum RNNOpResource {kTempSpace, kCuDNNDropoutDescSpace}; +} + +struct RNNParam : public dmlc::Parameter { + uint32_t state_size; + uint32_t num_layers; + bool bidirectional, state_outputs; + int mode; + float p; + int seq_length_, batch_size_, input_size_; + + bool use_sequence_length; + dmlc::optional projection_size; + dmlc::optional lstm_state_clip_min, lstm_state_clip_max; + bool lstm_state_clip_nan; + + DMLC_DECLARE_PARAMETER(RNNParam) { + DMLC_DECLARE_FIELD(state_size) + .describe("size of the state for each layer"); + + DMLC_DECLARE_FIELD(num_layers) + .describe("number of stacked layers"); + + DMLC_DECLARE_FIELD(bidirectional).set_default(false) + .describe("whether to use bidirectional recurrent layers"); + + DMLC_DECLARE_FIELD(mode) + .add_enum("rnn_relu", rnn_enum::kRnnRelu) + .add_enum("rnn_tanh", rnn_enum::kRnnTanh) + .add_enum("lstm", rnn_enum::kLstm) + .add_enum("gru", rnn_enum::kGru) + .describe("the type of RNN to compute"); + + DMLC_DECLARE_FIELD(p).set_default(0.) + .set_range(0, 1) + .describe("drop rate of the dropout on the outputs of each RNN layer, except the last layer."); + + DMLC_DECLARE_FIELD(state_outputs).set_default(false) + .describe("Whether to have the states as symbol outputs."); + + DMLC_DECLARE_FIELD(projection_size) + .set_default(dmlc::optional()) + .describe("size of project size"); + + DMLC_DECLARE_FIELD(lstm_state_clip_min) + .set_default(dmlc::optional()) + .describe("Minimum clip value of LSTM states. This option must be used together with " + "lstm_state_clip_max."); + + DMLC_DECLARE_FIELD(lstm_state_clip_max) + .set_default(dmlc::optional()) + .describe("Maximum clip value of LSTM states. This option must be used together with " + "lstm_state_clip_min."); + + DMLC_DECLARE_FIELD(lstm_state_clip_nan) + .set_default(false) + .describe("Whether to stop NaN from propagating in state by clipping it to min/max. " + "If clipping range is not specified, this option is ignored."); + + DMLC_DECLARE_FIELD(use_sequence_length) + .set_default(false) + .describe( + "If set to true, this layer takes in an extra input parameter " + "`sequence_length` " + "to specify variable length sequence"); + } +}; + inline int GetRnnParamSize(int num_layer, int input_size, int state_size, @@ -86,9 +155,9 @@ inline int GetRnnParamSize(int num_layer, } inline int GetRnnBiasSize(int num_layer, - int state_size, - int direction, - int mode) { + int state_size, + int direction, + int mode) { int size = 2 * state_size * direction * num_layer; switch (mode) { case rnn_enum::kRnnRelu: @@ -104,6 +173,15 @@ inline int GetRnnBiasSize(int num_layer, return size; } +/* + * Calculate the space size of the intermediate results for RNN inference. + * The inference procedure of a fusion RNN operator calculates the outputs + * layer by layer. 
In one layer calculation, the steps are: + * - wx[1...Ngates] * x[1...T] among all time stamp(sz: TxNxHxNgates) + * - wh[1...Ngates] * h[t] time by time(sz: NxHxNgates) + * - output -> h[t](, c[t] additionally with Lstm) time by time(sz: NxH(x2)) + * - intermediate y[1...T] as next layer's inputs(sz: TxNxHxD) + */ inline size_t GetRNNWorkspaceSize(int seq_length, int batch_size, int hidden_size, @@ -112,15 +190,19 @@ inline size_t GetRNNWorkspaceSize(int seq_length, size_t size = 0; switch (mode) { case rnn_enum::kLstm: - size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size * 2 - + seq_length * batch_size * hidden_size * direction + hidden_size * seq_length * 8; + size = seq_length * batch_size * hidden_size * (4 + direction) + // wx*x + inter-y + batch_size * hidden_size * 6 + // wh*h + h + c + seq_length * hidden_size * 8; // Used in Backward, Δbx, Δbh break; case rnn_enum::kGru: - size = seq_length * batch_size * hidden_size * direction * 4 + batch_size * hidden_size * 8; + // Differs with Lstm, the outputs of three gates are also held in memory + size = seq_length * batch_size * hidden_size * direction * (3 + 1) + // wx*x + inter-y + batch_size * hidden_size * (6 + direction); // wh*h + h + Ngates break; case rnn_enum::kRnnRelu: case rnn_enum::kRnnTanh: - size = seq_length * batch_size * hidden_size * direction * 2 + batch_size * hidden_size * 4; + size = seq_length * batch_size * hidden_size * direction * 2 + // wx*x + inter-y + batch_size * hidden_size * (1 + direction); // h + Ngates break; default: LOG(FATAL) << "unknown RNN mode " << mode; @@ -158,71 +240,6 @@ inline size_t GetRNNReserveSpaceSize(int num_layer, return size; } -struct RNNParam : public dmlc::Parameter { - uint32_t state_size; - uint32_t num_layers; - bool bidirectional, state_outputs; - int mode; - float p; - int seq_length_, batch_size_, input_size_; - - bool use_sequence_length; - dmlc::optional projection_size; - dmlc::optional lstm_state_clip_min, lstm_state_clip_max; - bool lstm_state_clip_nan; - - DMLC_DECLARE_PARAMETER(RNNParam) { - DMLC_DECLARE_FIELD(state_size) - .describe("size of the state for each layer"); - - DMLC_DECLARE_FIELD(num_layers) - .describe("number of stacked layers"); - - DMLC_DECLARE_FIELD(bidirectional).set_default(false) - .describe("whether to use bidirectional recurrent layers"); - - DMLC_DECLARE_FIELD(mode) - .add_enum("rnn_relu", rnn_enum::kRnnRelu) - .add_enum("rnn_tanh", rnn_enum::kRnnTanh) - .add_enum("lstm", rnn_enum::kLstm) - .add_enum("gru", rnn_enum::kGru) - .describe("the type of RNN to compute"); - - DMLC_DECLARE_FIELD(p).set_default(0.) - .set_range(0, 1) - .describe("drop rate of the dropout on the outputs of each RNN layer, except the last layer."); - - DMLC_DECLARE_FIELD(state_outputs).set_default(false) - .describe("Whether to have the states as symbol outputs."); - - DMLC_DECLARE_FIELD(projection_size) - .set_default(dmlc::optional()) - .describe("size of project size"); - - DMLC_DECLARE_FIELD(lstm_state_clip_min) - .set_default(dmlc::optional()) - .describe("Minimum clip value of LSTM states. This option must be used together with " - "lstm_state_clip_max."); - - DMLC_DECLARE_FIELD(lstm_state_clip_max) - .set_default(dmlc::optional()) - .describe("Maximum clip value of LSTM states. This option must be used together with " - "lstm_state_clip_min."); - - DMLC_DECLARE_FIELD(lstm_state_clip_nan) - .set_default(false) - .describe("Whether to stop NaN from propagating in state by clipping it to min/max. 
" - "If clipping range is not specified, this option is ignored."); - - DMLC_DECLARE_FIELD(use_sequence_length) - .set_default(false) - .describe( - "If set to true, this layer takes in an extra input parameter " - "`sequence_length` " - "to specify variable length sequence"); - } -}; - inline size_t GetNumInputArguments(RNNParam param_) { size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4U : 3U; if (param_.use_sequence_length) num_inputs += 1U; @@ -398,30 +415,11 @@ class RNNOp { public: RNNParam param_; Context ctx_; -#if MXNET_USE_MKLDNN == 1 - std::vector concat_weight_memory; - std::vector concat_iter_memory; - std::vector rnn_forward_prim; - std::vector x_memory; - std::vector hcx_memory; - std::vector wx_memory; - std::vector wh_memory; - std::vector bias_memory; - std::vector y_memory; - std::vector hcy_memory; - size_t weights_version; - bool has_cache; - bool init_mem_; - size_t reserve_mem_size_; - NDArray mem_space_; -#endif + explicit RNNOp(RNNParam param, Context ctx) { this->param_ = param; this->ctx_ = ctx; -#if MXNET_USE_MKLDNN == 1 - init_mem_ = false; - reserve_mem_size_ = 0; -#endif + #if MXNET_USE_CUDNN == 1 init_cudnn_ = false; dtype_ = mshadow::DataType::kCudnnFlag; @@ -500,7 +498,7 @@ class RNNOp { CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_)); CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_)); CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dy_data_desc_)); -#endif +#endif // MXNET_USE_CUDNN_GE_7200 #else if (ctx_.dev_type == kGPU) { LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment."; @@ -886,64 +884,23 @@ class RNNOp { param_.p, param_.mode); } else { -#if MXNET_USE_MKLDNN == 1 - if (dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1) && param_.mode != rnn_enum::kGru) { - // TODO(zixuanweeei): MKLDNN GRU has precision issue. A stable one - // will be added to MXNet when we figure out the issue. 
- int dtype = in_data[rnn_enum::kData].type_flag_; - MKLDNNRNNForwardInference(param_.state_outputs, - param_.num_layers, - direction, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - param_.state_size, - x.dptr_, - hx.dptr_, - cx_ptr, - w.dptr_, - b_ptr, - y.dptr_, - hy_ptr, - cy_ptr, - &concat_weight_memory, - &concat_iter_memory, - &x_memory, - &hcx_memory, - &wx_memory, - &wh_memory, - &bias_memory, - &y_memory, - &hcy_memory, - &rnn_forward_prim, - &has_cache, - dtype, - ctx.is_train, - param_.mode); - } else { -#endif // MXNET_USE_MKLDNN == 1 - // Before integrating MKLDNN GRU fp32 inference - // using below code for keep func being OK - RNNForwardInference(work_cpu_space, - param_.state_outputs, - param_.num_layers, - direction, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - param_.state_size, - x.dptr_, - hx.dptr_, - cx_ptr, - w.dptr_, - b_ptr, - y.dptr_, - hy_ptr, - cy_ptr, - param_.mode); -#if MXNET_USE_MKLDNN == 1 - } -#endif + RNNForwardInference(work_cpu_space, + param_.state_outputs, + param_.num_layers, + direction, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + param_.state_size, + x.dptr_, + hx.dptr_, + cx_ptr, + w.dptr_, + b_ptr, + y.dptr_, + hy_ptr, + cy_ptr, + param_.mode); } } } @@ -1493,6 +1450,10 @@ class RNNOp { } #endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__) } + // naive private variables used in CPU Context + bool init_space_, temp_init_space_; + size_t reserve_cpu_space_size_, temp_cpu_space_size_; + NDArray reserve_cpu_space_, temp_cpu_space_; #if MXNET_USE_CUDNN == 1 && defined(__CUDACC__) // cuDNN versions up to and including v7.6.4 did not sync a last dgrad kernel back to the main @@ -1537,39 +1498,8 @@ class RNNOp { cudaEvent_t dgrad_sync_event_; bool dgrad_sync_needed_ = false; #endif // MXNET_USE_CUDNN - bool init_space_, temp_init_space_; - size_t reserve_cpu_space_size_, temp_cpu_space_size_; - NDArray reserve_cpu_space_, temp_cpu_space_; }; // class RNNOp -static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs, - const Context ctx, - const mxnet::ShapeVector &in_shapes, - const std::vector &in_types) { - const RNNParam& param = nnvm::get(attrs.parsed); - OpStatePtr state = OpStatePtr(); - int dtype = in_types[rnn_enum::kData]; - int itype = dtype; - if (param.use_sequence_length) { - size_t seq_len_input_idx = rnn_enum::kSequenceLength; - if (param.mode != rnn_enum::kLstm) { - seq_len_input_idx -= 1; - } - itype = in_types[seq_len_input_idx]; - } - - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - MSHADOW_TYPE_SWITCH(itype, IType, { - if (ctx.dev_type == kGPU) { - state = OpStatePtr::Create>(param, ctx); - } else { - state = OpStatePtr::Create>(param, ctx); - } - }); - }); - return state; -} - template void RNNStatefulCompute(const OpStatePtr& state, const OpContext& ctx, @@ -1581,14 +1511,14 @@ void RNNStatefulCompute(const OpStatePtr& state, // Hacky. This relies on fact that seq-len type is either the last input, // or we aren't using seq-len input and this type should be same as dtype. // Would prefer direct access to RNNParam object here but not sure how to get. 
- int itype = inputs[inputs.size()-1].type_flag_; + int itype = inputs[inputs.size() - 1].type_flag_; MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - MSHADOW_TYPE_SWITCH(itype, IType, { - RNNOp& op = state.get_state>(); - op.Forward(ctx, inputs, req, outputs); - }); + MSHADOW_TYPE_SWITCH(itype, IType, { + RNNOp& op = state.get_state>(); + op.Forward(ctx, inputs, req, outputs); }); + }); } /* @@ -1620,38 +1550,38 @@ void RNNStatefulGradCompute(const OpStatePtr& state, // Hacky. This relies on fact that seq-len type is either the last input, // or we aren't using seq-len input and this type should be same as dtype. // Would prefer direct access to RNNParam object here but not sure how to get. - int itype = outputs[outputs.size()-1].type_flag_; + int itype = outputs[outputs.size() - 1].type_flag_; MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - MSHADOW_TYPE_SWITCH(itype, IType, { - RNNOp& op = state.get_state>(); - const RNNParam& param = op.param_; - int index = 5; - if (param.state_outputs) { - out_data.push_back(inputs[index++]); - out_grad.push_back(inputs[index++]); - } - - if (param.mode == rnn_enum::kLstm) { - in_data.push_back(inputs[index++]); - if (param.state_outputs) { - out_data.push_back(inputs[index++]); - out_grad.push_back(inputs[index]); - } - } - - - if (param.use_sequence_length) { - size_t seq_len_input_idx = rnn_enum::kSequenceLength; - if (param.mode != rnn_enum::kLstm) { - seq_len_input_idx -= 1; - } - in_data.push_back(outputs[seq_len_input_idx]); - } - - op.Backward(ctx, out_grad, in_data, out_data, req, in_grad); - }); + MSHADOW_TYPE_SWITCH(itype, IType, { + RNNOp& op = state.get_state>(); + const RNNParam& param = op.param_; + int index = 5; + if (param.state_outputs) { + out_data.push_back(inputs[index++]); + out_grad.push_back(inputs[index++]); + } + + if (param.mode == rnn_enum::kLstm) { + in_data.push_back(inputs[index++]); + if (param.state_outputs) { + out_data.push_back(inputs[index++]); + out_grad.push_back(inputs[index]); + } + } + + + if (param.use_sequence_length) { + size_t seq_len_input_idx = rnn_enum::kSequenceLength; + if (param.mode != rnn_enum::kLstm) { + seq_len_input_idx -= 1; + } + in_data.push_back(outputs[seq_len_input_idx]); + } + + op.Backward(ctx, out_grad, in_data, out_data, req, in_grad); }); + }); } } // namespace op diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index b2ac2f0cb615..78a3f04d3c7f 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -27,6 +27,9 @@ #include #include "./rnn-inl.h" +#if MXNET_USE_MKLDNN == 100 +#include "./nn/mkldnn/mkldnn_rnn-inl.h" +#endif // MXNET_USE_MKLDNN == 100 namespace mxnet { namespace op { @@ -190,9 +193,9 @@ inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs, std::vector *out_attrs) { DispatchMode wanted_mode = DispatchMode::kFCompute; - #if MXNET_USE_MKLDNN == 1 - wanted_mode = DispatchMode::kFComputeEx; - #endif +#if MXNET_USE_MKLDNN == 100 + wanted_mode = DispatchMode::kFComputeEx; +#endif // MXNET_USE_MKLDNN == 100 return storage_type_assign(out_attrs, mxnet::kDefaultStorage, dispatch_mode, wanted_mode); @@ -222,432 +225,73 @@ struct RNNGrad { } }; -#if MXNET_USE_MKLDNN == 1 -static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - std::vector in_blobs; - std::vector out_blobs; - std::vector temp_ndarrays_i; - std::vector temp_ndarrays_o; - for (const NDArray& in : inputs) { - if (in.storage_type() == kDefaultStorage) { - 
temp_ndarrays_i.push_back(in.Reorder2Default()); - in_blobs.emplace_back(temp_ndarrays_i.back().data()); - } else { - in_blobs.emplace_back(in.data()); +static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs, + const Context ctx, + const mxnet::ShapeVector &in_shapes, + const std::vector &in_types) { + const RNNParam& param = nnvm::get(attrs.parsed); + OpStatePtr state = OpStatePtr(); + int dtype = in_types[rnn_enum::kData]; + int itype = dtype; + if (param.use_sequence_length) { + size_t seq_len_input_idx = rnn_enum::kSequenceLength; + if (param.mode != rnn_enum::kLstm) { + seq_len_input_idx -= 1; } + itype = in_types[seq_len_input_idx]; } - for (const NDArray& out : outputs) { - if (out.storage_type() == kDefaultStorage) { - temp_ndarrays_o.push_back(out.Reorder2Default()); - out_blobs.emplace_back(temp_ndarrays_o.back().data()); - } else { - out_blobs.emplace_back(out.data()); - } +#if MXNET_USE_MKLDNN == 100 + if ((in_types[0] == mshadow::kFloat32 || in_types[0] == mshadow::kFloat16) + && in_shapes[0].ndim() == 3 && ctx.dev_type == kCPU) { + const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData]; + state = OpStatePtr::Create(param, data_shape[0], + data_shape[1], data_shape[2]); + return state; } - int dtype = in_blobs[rnn_enum::kData].type_flag_; - int itype = in_blobs[inputs.size()-1].type_flag_; - mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); - Stream *s = ctx.get_stream(); - auto cpu_engine = CpuEngine::Get()->get_engine(); +#endif // MXNET_USE_MKLDNN == 100 + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { MSHADOW_TYPE_SWITCH(itype, IType, { - RNNOp& op = state_ptr.get_state>(); - const RNNParam& param = op.param_; - int ngates = 0, nstates = 0; - GetMKLDNNRNNAlgo(param.mode, &ngates, &nstates); - int D = param.bidirectional ? 2 : 1; - Tensor x = in_blobs[rnn_enum::kData].get(s); - int T = x.shape_[0]; - int N = x.shape_[1]; - int I = x.shape_[2]; - int H = param.state_size; - int L = param.num_layers; - - const size_t r_size = GetMKLDNNRNNCacheMemorySize(L, D, T, N, I, H, param.mode); - if (op.init_mem_ && op.reserve_mem_size_ < r_size) { - op.init_mem_ = false; - } - const size_t weights_version = inputs[rnn_enum::kParams].version(); - if (!op.init_mem_) { - op.mem_space_ = NDArray(TShape({static_cast(r_size)}), op.ctx_, false, dtype); - op.reserve_mem_size_ = r_size; - op.init_mem_ = true; - op.has_cache = false; - // Assign weights_version - op.weights_version = weights_version; + if (ctx.dev_type == kGPU) { + state = OpStatePtr::Create>(param, ctx); + } else { + state = OpStatePtr::Create>(param, ctx); } - // Check if NDArray was changed. 
- if (op.weights_version != weights_version) { - op.has_cache = false; - op.weights_version = weights_version; - } - - DType* workptr = static_cast(op.mem_space_.data().dptr_); - mkldnn::memory::dims src_layer_tz_0 = {T, N, I}; - mkldnn::memory::dims src_layer_tz = {T, N, D * H}; - mkldnn::memory::dims dst_layer_tz = {T, N, D * H}; - auto dst_layer_md = mkldnn::memory::desc( - { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - if (op.x_memory.size() == 0) { - if (D == 1 && I == H) { - auto user_src_layer_md = mkldnn::memory::desc( - { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto user_src_layer_memory_n = mkldnn::memory({ user_src_layer_md, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory_n); - - mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz = {L, 1, ngates, H}; - auto user_weight_layer_md = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_weight_iter_md = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_bias_md = mkldnn::memory::desc({ bias_tz }, - mkldnn_dtype, mkldnn::memory::format::ldgo); - DType* weight_layer_n = workptr; // L * I * ngates * H - auto user_weight_layer_memory_n - = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); - op.wx_memory.push_back(user_weight_layer_memory_n); - - DType* weight_iter_n = weight_layer_n + L * I * ngates * H; // L * H * ngates * H - auto user_weight_iter_memory_n - = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); - op.wh_memory.push_back(user_weight_iter_memory_n); - - DType* bias_n = weight_iter_n + L * H * ngates * H; // L * ngates * H - auto user_bias_memory_n = - mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); - op.bias_memory.push_back(user_bias_memory_n); - - auto wx_md_n = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - DType* wx_n = bias_n + L * ngates * H; // L * ngates * I * H - auto wx_memory_n = - mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); - DType* wh_n = wx_n + L * ngates * I * H; // L * ngates * H * H - auto wh_md_n = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wh_memory_n = - mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); - - op.concat_weight_memory.push_back(wx_memory_n); - op.concat_weight_memory.push_back(wh_memory_n); - workptr = wh_n + L * ngates * H * H; - - mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc - auto src_iter_md_n1 = mkldnn::memory::desc( - { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - for (int l = 0; l < L; l++) { - DType* src_iter_n1 = workptr; // nstates * N * H - auto src_iter_memory_n1 = - mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1); - op.concat_iter_memory.push_back(src_iter_memory_n1); - workptr = src_iter_n1 + nstates * N * H; - } - mkldnn::memory::dims src_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc - auto src_iter_md_n = mkldnn::memory::desc( - { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_n = workptr; // L * nstates * N * H - auto src_iter_memory_n = - mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n); - op.concat_iter_memory.push_back(src_iter_memory_n); - op.hcx_memory.push_back(src_iter_memory_n); - DType* dst_layer_n = src_iter_n + L * nstates * N * H; // T * N * D * H - 
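/* Editor's sketch (added for clarity, not in the original source): this removed block
 * hand-carves one flat buffer (op.mem_space_) into consecutive MKL-DNN memories by bumping
 * a raw cursor, conceptually:
 *
 *   DType* cursor = static_cast<DType*>(op.mem_space_.data().dptr_);
 *   auto take = [&cursor](size_t elems) { DType* p = cursor; cursor += elems; return p; };
 *   // e.g. take(L * I * ngates * H) for the layer weights, then take(L * H * ngates * H)
 *   // for the iteration weights, and so on, in exactly the order workptr advances below.
 *
 * GetMKLDNNRNNCacheMemorySize() therefore has to bound the sum of all chunks taken this
 * way, which is what the reserve_mem_size_ < r_size check above re-validates. */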
auto dst_layer_memory_n - = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); - op.y_memory.push_back(dst_layer_memory_n); - - mkldnn::memory::dims dst_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc - auto dst_iter_md_n = mkldnn::memory::desc( - { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* dst_iter_n = dst_layer_n + T * N * D * H; // L * nstates * N * H - auto dst_iter_memory_n = - mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); - op.hcy_memory.push_back(dst_iter_memory_n); - workptr = dst_iter_n + L * nstates * N * H; - - } else { - auto user_src_layer_md_0 = mkldnn::memory::desc( - { src_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto user_src_layer_memory_0 = mkldnn::memory({ user_src_layer_md_0, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory_0); - - mkldnn::memory::dims weights_layer_tz_0 = {1, D, I, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz_0 = {1, D, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz_0 = {1, D, ngates, H}; - auto user_weight_layer_md_0 = mkldnn::memory::desc( - { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_weight_iter_md_0 = mkldnn::memory::desc( - { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_bias_md_0 = mkldnn::memory::desc({ bias_tz_0 }, - mkldnn_dtype, mkldnn::memory::format::ldgo); - - DType* weight_layer_0 = workptr; // D * I * ngates * H - auto user_weight_layer_memory_0 - = mkldnn::memory({ user_weight_layer_md_0, cpu_engine }, weight_layer_0); - op.wx_memory.push_back(user_weight_layer_memory_0); - - DType* weight_iter_0 = weight_layer_0 + D * I * ngates * H; // D * H * ngates * H - auto user_weight_iter_memory_0 - = mkldnn::memory({ user_weight_iter_md_0, cpu_engine }, weight_iter_0); - op.wh_memory.push_back(user_weight_iter_memory_0); - - DType* bias_0 = weight_iter_0 + D * H * ngates * H; // D * ngates * H - auto user_bias_memory_0 = - mkldnn::memory({ user_bias_md_0, cpu_engine }, bias_0); - op.bias_memory.push_back(user_bias_memory_0); - workptr = bias_0 + D * ngates * H; - - auto wx_md_0 = mkldnn::memory::desc( - { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wx_memory_0 = - mkldnn::memory({ wx_md_0, cpu_engine }); - auto wh_md_0 = mkldnn::memory::desc( - { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wh_memory_0 = - mkldnn::memory({ wh_md_0, cpu_engine }); - if (D == 2) { - DType* wx_0 = workptr; // D * ngates * I * H - wx_memory_0.set_data_handle(wx_0); - DType* wh_0 = wx_0 + D * ngates * I * H; // D * ngates * H * H - wh_memory_0.set_data_handle(wh_0); - workptr = wh_0 + D * ngates * H * H; - } - op.concat_weight_memory.push_back(wx_memory_0); - op.concat_weight_memory.push_back(wh_memory_0); - - mkldnn::memory::dims src_iter_undi_tz_0 = {1, 1, nstates, N, H}; // ldsnc - auto src_iter_undi_md_0 = mkldnn::memory::desc( - { src_iter_undi_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_undi_0 = workptr; // nstates * N * H - auto src_iter_undi_memory_0 = - mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi_0); - op.concat_iter_memory.push_back(src_iter_undi_memory_0); - workptr = src_iter_undi_0 + nstates * N * H; - if (D == 1) { - op.hcx_memory.push_back(src_iter_undi_memory_0); - } else { - DType* src_iter_undi2_0 = workptr; // nstates * N * H - auto src_iter_undi2_memory_0 = - mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi2_0); - 
op.concat_iter_memory.push_back(src_iter_undi2_memory_0); - - mkldnn::memory::dims src_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc - auto src_iter_md_0 = mkldnn::memory::desc( - { src_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_0 = src_iter_undi2_0 + nstates * N * H; // D * nstates * N * H - auto src_iter_memory_0 = - mkldnn::memory({ src_iter_md_0, cpu_engine }, src_iter_0); - op.concat_iter_memory.push_back(src_iter_memory_0); - op.hcx_memory.push_back(src_iter_memory_0); - workptr = src_iter_0 + D * nstates * N * H; - } - - DType* dst_layer_0 = workptr; // T * N * D * H - auto dst_layer_memory_0 - = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_0); - op.y_memory.push_back(dst_layer_memory_0); - - mkldnn::memory::dims dst_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc - auto dst_iter_md_0 = mkldnn::memory::desc( - { dst_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* dst_iter_0 = dst_layer_0 + T * N * D * H; // D * nstates * N * H - auto dst_iter_memory_0 = - mkldnn::memory({ dst_iter_md_0, cpu_engine }, dst_iter_0); - op.hcy_memory.push_back(dst_iter_memory_0); - workptr = dst_iter_0 + D * nstates * N * H; - - // next L - 1 layers - if (L > 1 && D == 1) { - auto user_src_layer_md = mkldnn::memory::desc( - { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory); - - mkldnn::memory::dims weights_layer_tz = {L - 1, 1, H, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz = {L - 1, 1, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz = {L - 1, 1, ngates, H}; - auto user_weight_layer_md = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_weight_iter_md = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_bias_md = mkldnn::memory::desc({ bias_tz }, - mkldnn_dtype, mkldnn::memory::format::ldgo); - - DType* weight_layer_n = workptr; // (L - 1) * H * ngates * H - auto user_weight_layer_memory_n - = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); - op.wx_memory.push_back(user_weight_layer_memory_n); - - DType* weight_iter_n = weight_layer_n + - (L - 1) * H * ngates * H; // (L - 1) * H * ngates * H - auto user_weight_iter_memory_n - = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); - op.wh_memory.push_back(user_weight_iter_memory_n); - - DType* bias_n = weight_iter_n + (L - 1) * H * ngates * H; // (L - 1) * ngates * H - auto user_bias_memory_n = - mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); - op.bias_memory.push_back(user_bias_memory_n); - - auto wx_md_n = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - DType* wx_n = bias_n + (L - 1) * ngates * H; // (L - 1) * ngates * H * H - auto wx_memory_n = - mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); - DType* wh_n = wx_n + (L - 1) * ngates * H * H; // (L - 1) * ngates * H * H - auto wh_md_n = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wh_memory_n = - mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); - - op.concat_weight_memory.push_back(wx_memory_n); - op.concat_weight_memory.push_back(wh_memory_n); - workptr = wh_n + (L - 1) * ngates * H * H; - - mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc - auto src_iter_md_n1 = mkldnn::memory::desc( - { src_iter_tz_n1 }, mkldnn_dtype, 
mkldnn::memory::format::ldsnc); - for (int l = 0; l < L - 1; l++) { - DType* src_iter_n1 = workptr; // nstates * N * H - auto src_iter_memory_n1 = - mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1); - op.concat_iter_memory.push_back(src_iter_memory_n1); - workptr = src_iter_n1 + nstates * N * H; - } - mkldnn::memory::dims src_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc - auto src_iter_md_n = mkldnn::memory::desc( - { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_n = workptr; // (L - 1) * nstates * N * H - auto src_iter_memory_n = - mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n); - op.concat_iter_memory.push_back(src_iter_memory_n); - op.hcx_memory.push_back(src_iter_memory_n); - - DType* dst_layer_n = src_iter_n + (L - 1) * nstates * N * H; // T * N * D * H - auto dst_layer_memory_n - = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); - op.y_memory.push_back(dst_layer_memory_n); - - mkldnn::memory::dims dst_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc - auto dst_iter_md_n = mkldnn::memory::desc( - { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* dst_iter_n = dst_layer_n + T * N * D * H; // (L - 1) * nstates * N * H - auto dst_iter_memory_n = - mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); - op.hcy_memory.push_back(dst_iter_memory_n); - } - - if (L > 1 && D == 2) { - mkldnn::memory::dims weights_layer_tz = {1, D, H * D, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz = {1, D, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz = {1, D, ngates, H}; - auto user_weight_layer_md = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_weight_iter_md = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_bias_md = mkldnn::memory::desc({ bias_tz }, - mkldnn_dtype, mkldnn::memory::format::ldgo); - - auto user_src_layer_md = mkldnn::memory::desc( - { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory); - - auto wx_md_n = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wh_md_n = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - - for (int l = 0; l < L; l++) { - DType* weight_layer_n = workptr; // D * (H * D) * ngates * H - auto user_weight_layer_memory_n - = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); - op.wx_memory.push_back(user_weight_layer_memory_n); - - DType* weight_iter_n = weight_layer_n + - D * (H * D) * ngates * H; // D * H * ngates * H - auto user_weight_iter_memory_n - = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); - op.wh_memory.push_back(user_weight_iter_memory_n); - - DType* bias_n = weight_iter_n + D * H * ngates * H; // D * ngates * H - auto user_bias_memory_n = - mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); - op.bias_memory.push_back(user_bias_memory_n); - workptr = bias_n + D * ngates * H; - } - - DType* wx_n = workptr; // D * ngates * (D * H) * H - DType* wh_n = wx_n + D * ngates * (D * H) * H; // D * ngates * H * H - auto wx_memory_n = - mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); - auto wh_memory_n = - mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); - op.concat_weight_memory.push_back(wx_memory_n); - op.concat_weight_memory.push_back(wh_memory_n); - - 
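/* Editor's aside (illustrative only): in this bidirectional multi-layer branch every layer
 * after the first consumes the concatenated forward/backward output of the previous layer,
 * hence weights_layer_tz = {1, D, H * D, ngates, H} here versus {1, D, I, ngates, H} for
 * the first layer. A quick size check with hypothetical values D = 2, I = 64, H = 128,
 * ngates = 4 (LSTM):
 *
 *   first layer  wx elements: D * I       * ngates * H = 2 *  64 * 4 * 128 =  65536
 *   later layers wx elements: D * (H * D) * ngates * H = 2 * 256 * 4 * 128 = 262144
 *
 * consistent with MKLDNNRNNForward advancing w_ptr by (H * D + H) * H * ngates * D
 * elements per additional bidirectional layer. */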
mkldnn::memory::dims src_iter_undi_tz = {1, 1, nstates, N, H}; // ldsnc
- auto src_iter_undi_md = mkldnn::memory::desc(
- { src_iter_undi_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
- DType* src_iter_undi = wh_n + D * ngates * H * H; // nstates * N * H
- auto src_iter_undi_memory =
- mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi);
- op.concat_iter_memory.push_back(src_iter_undi_memory_0);
-
- DType* src_iter_undi2 = src_iter_undi + nstates * N * H; // nstates * N * H
- auto src_iter_undi2_memory =
- mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi2);
- op.concat_iter_memory.push_back(src_iter_undi2_memory);
-
- mkldnn::memory::dims src_iter_tz = {1, D, nstates, N, H}; // ldsnc
- auto src_iter_md = mkldnn::memory::desc(
- { src_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
- DType* src_iter = src_iter_undi2 + nstates * N * H; // D * nstates * N * H
- auto src_iter_memory =
- mkldnn::memory({ src_iter_md, cpu_engine }, src_iter);
- op.concat_iter_memory.push_back(src_iter_memory);
- op.hcx_memory.push_back(src_iter_memory);
-
- DType* dst_layer_n = src_iter + D * nstates * N * H; // T * N * D * H
- auto dst_layer_memory_n
- = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n);
- op.y_memory.push_back(dst_layer_memory_n);
-
- mkldnn::memory::dims dst_iter_tz_n = {1, D, nstates, N, H}; // ldsnc
- auto dst_iter_md_n = mkldnn::memory::desc(
- { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
- DType* dst_iter_n = dst_layer_n + T * N * D * H; // D * nstates * N * H
- auto dst_iter_memory_n =
- mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n);
- op.hcy_memory.push_back(dst_iter_memory_n);
- }
- }
- }
- op.Forward(ctx, in_blobs, req, out_blobs);
});
});
+ return state;
}
-static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr, const OpContext& ctx,
+#if MXNET_USE_MKLDNN == 100
+static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr,
+                                    const OpContext& ctx,
                                     const std::vector<NDArray>& inputs,
                                     const std::vector<OpReqType>& req,
                                     const std::vector<NDArray>& outputs) {
-  if (SupportMKLDNNRNN(inputs[0])) {
-    RNNStatefulComputeCPU(state_ptr, ctx, inputs, req, outputs);
-    return;
+  // Use the fused MKL-DNN kernel only for float32/float16 data with 3-D input;
+  // otherwise fall back to the stock CPU implementation.
+  if ((inputs[0].dtype() == mshadow::kFloat32 || inputs[0].dtype() == mshadow::kFloat16) &&
+      inputs[0].shape().ndim() == 3) {
+    MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>();
+    op.Forward(ctx, inputs, req, outputs);
+  } else {
+    FallBackCompute(RNNStatefulCompute<cpu>, state_ptr, ctx, inputs, req, outputs);
  }
-  int use_mkldnn_rnn = dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1);
-  dmlc::SetEnv("MXNET_USE_MKLDNN_RNN", 0);
-  FallBackCompute(RNNStatefulCompute<cpu>, state_ptr, ctx, inputs, req, outputs);
-  dmlc::SetEnv("MXNET_USE_MKLDNN_RNN", use_mkldnn_rnn);
}
-#endif
+
+static void RNNStatefulGradComputeExCPU(const OpStatePtr& state_ptr,
+                                        const OpContext& ctx,
+                                        const std::vector<NDArray>& inputs,
+                                        const std::vector<OpReqType>& req,
+                                        const std::vector<NDArray>& outputs) {
+  // Same dispatch rule as the forward pass above.
+  if ((inputs[0].dtype() == mshadow::kFloat32 || inputs[0].dtype() == mshadow::kFloat16) &&
+      inputs[0].shape().ndim() == 3) {
+    MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>();
+    op.Backward(ctx, inputs, req, outputs);
+  } else {
+    FallBackCompute(RNNStatefulGradCompute<cpu>, state_ptr, ctx, inputs, req, outputs);
+  }
+}
+#endif  // MXNET_USE_MKLDNN == 100
NNVM_REGISTER_OP(RNN)
.add_alias("_npx_rnn")
@@ -726,12 +370,22 @@ The definition of GRU here is slightly different from paper but compatible with
  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
  return ListArguments(params);
})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+  std::vector<std::string> names{"output"};
+  if (params.state_outputs) {
+    names.emplace_back("state_output");
+    if (params.mode == rnn_enum::kLstm)
+      names.emplace_back("statecell_output");
+  }
+  return names;
+})
.set_attr<mxnet::FInferShape>("FInferShape", RNNShape)
.set_attr<nnvm::FInferType>("FInferType", RNNType)
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<FCreateOpState>("FCreateOpState", CreateRNNState)
.set_attr<FStatefulCompute>("FStatefulCompute", RNNStatefulCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx", RNNStatefulComputeExCPU)
#endif
@@ -756,7 +410,12 @@ NNVM_REGISTER_OP(_backward_RNN)
.set_attr_parser(ParamParser<RNNParam>)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<FStatefulCompute>("FStatefulCompute", RNNStatefulGradCompute<cpu>)
+#if MXNET_USE_MKLDNN == 100
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx", RNNStatefulGradComputeExCPU)
+#endif
.set_attr<FResourceRequestEx>("FResourceRequestEx", RNNResourceEx);
}  // namespace op
}  // namespace mxnet
diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h
index 425ea4a3c6ab..e1b4a2b79c0a 100644
--- a/src/operator/rnn_impl.h
+++ b/src/operator/rnn_impl.h
@@ -44,13 +44,6 @@ namespace mxnet {
namespace op {
-namespace rnn_enum {
-  enum RNNOpInputs {kData, kParams, kState, kStateCell, kSequenceLength};
-  enum RNNOpOutputs {kOut, kStateOut, kStateCellOut};
-  enum RNNModeType {kRnnRelu, kRnnTanh, kLstm, kGru};
-  enum RNNOpResource {kTempSpace, kCuDNNDropoutDescSpace};
-}
-
template<typename DType>
inline DType sigmoid(DType x) {
  return 1.0f / (1.0f + exp(-x));
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 35460676da28..713b11ead48b 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -36,6 +36,14 @@ import os
def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol=1e-2, atol=1e-4):
+    if default_context().device_type == 'cpu':
+        # NOTE(zixuanweeei): the fused MKL-DNN RNN operator does not currently support the `add` gradient request.
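+        # A grad_req of 'add' asks for gradients to be accumulated into existing arrays,
+        # which the fused MKL-DNN RNN kernel cannot do yet, so such cases are skipped below.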
+        if isinstance(grad_req, dict) and 'add' in grad_req.values():
+            print("Skipping test: the fused MKL-DNN RNN operator does not support the `add` gradient request on CPU.")
+            return
+        if isinstance(grad_req, str) and grad_req == 'add':
+            print("Skipping test: the fused MKL-DNN RNN operator does not support the `add` gradient request on CPU.")
+            return
    dshape = (N, T, I)
    data = mx.sym.Variable('data')
@@ -86,7 +94,7 @@ def test_rnn_with_new_param():
    for mode, ngates in zip(rnn_modes, ngates_):
        first_layer_size = (input_size * state_size + state_size * state_size + state_size * 2) * ngates
        rest_layer_size = (state_size * directions * state_size + state_size * state_size + state_size * 2) \
-            * ngates * (num_layers - 1)
+                          * ngates * (num_layers - 1)
        param_size = (first_layer_size + rest_layer_size) * directions
        sym = mx.sym.RNN(mode=mode, num_layers=num_layers, bidirectional=bidirectional, state_outputs=False, state_size=state_size, name='rnn')
@@ -118,149 +126,176 @@ def test_rnn_with_new_param():
@with_seed()
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_lstm_sym():
-    T, N, I, H = 5, 32, 800, 800
-    fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='lstm', get_next_state=True, prefix='')
-    stack = mx.rnn.SequentialRNNCell()
-    stack.add(mx.rnn.LSTMCell(H, prefix='l0_'))
-    stack.add(mx.rnn.LSTMCell(H, prefix='l1_'))
-    stack.add(mx.rnn.LSTMCell(H, prefix='l2_'))
-
-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    Ts = [1, 5]
+    Ns = [1, 32]
+    Is = [32, 128, 512]
+    Hs = [32, 128, 512]
+    for T, N, I, H in itertools.product(Ts, Ns, Is, Hs):
+        fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='lstm', get_next_state=True, prefix='')
+        stack = mx.rnn.SequentialRNNCell()
+        stack.add(mx.rnn.LSTMCell(H, prefix='l0_'))
+        stack.add(mx.rnn.LSTMCell(H, prefix='l1_'))
+        stack.add(mx.rnn.LSTMCell(H, prefix='l2_'))
+
+        check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'null')
@with_seed()
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_lstm_bidirectional():
-    T, N, I, H = 5, 20, 800, 800
-    fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='lstm',
-                                bidirectional=True, get_next_state=True, prefix='')
-
-    stack = mx.rnn.SequentialRNNCell()
-    stack.add(mx.rnn.BidirectionalCell(
-        mx.rnn.LSTMCell(H, prefix='l0_'),
-        mx.rnn.LSTMCell(H, prefix='r0_'),
-        output_prefix='bi_lstm_0_'))
-    stack.add(mx.rnn.BidirectionalCell(
-        mx.rnn.LSTMCell(H, prefix='l1_'),
-        mx.rnn.LSTMCell(H, prefix='r1_'),
-        output_prefix='bi_lstm_1_'))
-
-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
-    check_rnn_consistency(fused, stack, T, N, I, H, {'data': 'add', 'parameters': 'null'})
+    Ts = [1, 5]
+    Ns = [1, 32]
+    Is = [32, 128, 512]
+    Hs = [32, 128, 512]
+    for T, N, I, H in itertools.product(Ts, Ns, Is, Hs):
+        fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='lstm',
+                                    bidirectional=True, get_next_state=True, prefix='')
+
+        stack = mx.rnn.SequentialRNNCell()
+        stack.add(mx.rnn.BidirectionalCell(
+            mx.rnn.LSTMCell(H, prefix='l0_'),
+            mx.rnn.LSTMCell(H, prefix='r0_'),
+            output_prefix='bi_lstm_0_'))
+        stack.add(mx.rnn.BidirectionalCell(
+            mx.rnn.LSTMCell(H, prefix='l1_'),
+            mx.rnn.LSTMCell(H, prefix='r1_'),
+            output_prefix='bi_lstm_1_'))
+
+        check_rnn_consistency(fused, stack, T, N, I,
H, 'write') + check_rnn_consistency(fused, stack, T, N, I, H, 'add') + check_rnn_consistency(fused, stack, T, N, I, H, 'null') + check_rnn_consistency(fused, stack, T, N, I, H, {'data': 'add', 'parameters': 'null'}) @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_gru_sym(): - T, N, I, H = 5, 32, 800, 800 - fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='gru', get_next_state=True, prefix='') - stack = mx.rnn.SequentialRNNCell() - stack.add(mx.rnn.GRUCell(H, prefix='l0_')) - stack.add(mx.rnn.GRUCell(H, prefix='l1_')) - stack.add(mx.rnn.GRUCell(H, prefix='l2_')) - - check_rnn_consistency(fused, stack, T, N, I, H, 'write') - check_rnn_consistency(fused, stack, T, N, I, H, 'add') - check_rnn_consistency(fused, stack, T, N, I, H, 'null') + Ts = [1, 5] + Ns = [1, 32] + Is = [32, 128, 512] + Hs = [32, 128, 512] + for T, N, I, H in itertools.product(Ts, Ns, Is, Hs): + fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='gru', get_next_state=True, prefix='') + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.GRUCell(H, prefix='l0_')) + stack.add(mx.rnn.GRUCell(H, prefix='l1_')) + stack.add(mx.rnn.GRUCell(H, prefix='l2_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write', atol=2e-4) + check_rnn_consistency(fused, stack, T, N, I, H, 'add', atol=2e-4) + check_rnn_consistency(fused, stack, T, N, I, H, 'null', atol=2e-4) @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_gru_bidirectional(): - T, N, I, H = 5, 20, 800, 800 - - fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='gru', - bidirectional=True, get_next_state=True, prefix='') - - stack = mx.rnn.SequentialRNNCell() - stack.add(mx.rnn.BidirectionalCell( - mx.rnn.GRUCell(H, prefix='l0_'), - mx.rnn.GRUCell(H, prefix='r0_'), - output_prefix='bi_gru_0_')) - - stack.add(mx.rnn.BidirectionalCell( - mx.rnn.GRUCell(H, prefix='l1_'), - mx.rnn.GRUCell(H, prefix='r1_'), - output_prefix='bi_gru_1_')) - - check_rnn_consistency(fused, stack, T, N, I, H, 'write') - check_rnn_consistency(fused, stack, T, N, I, H, 'add') - check_rnn_consistency(fused, stack, T, N, I, H, 'null') + Ts = [1, 5] + Ns = [1, 32] + Is = [32, 128, 512] + Hs = [32, 128, 512] + for T, N, I, H in itertools.product(Ts, Ns, Is, Hs): + fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='gru', + bidirectional=True, get_next_state=True, prefix='') + + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.GRUCell(H, prefix='l0_'), + mx.rnn.GRUCell(H, prefix='r0_'), + output_prefix='bi_gru_0_')) + + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.GRUCell(H, prefix='l1_'), + mx.rnn.GRUCell(H, prefix='r1_'), + output_prefix='bi_gru_1_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write', atol=2e-4) + check_rnn_consistency(fused, stack, T, N, I, H, 'add', atol=2e-4) + check_rnn_consistency(fused, stack, T, N, I, H, 'null', atol=2e-4) @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_rnntanh_sym(): - T, N, I, H = 5, 32, 800, 800 - - fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_tanh', get_next_state=True, prefix='') - stack = mx.rnn.SequentialRNNCell() - stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l0_')) - stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l1_')) - stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l2_')) - - check_rnn_consistency(fused, stack, T, N, I, H, 'write') - check_rnn_consistency(fused, stack, T, N, I, H, 'add') - check_rnn_consistency(fused, stack, T, N, I, H, 'null') + Ts = [1, 5] + Ns = [1, 32] + Is = [32, 
128, 512] + Hs = [32, 128, 512] + for T, N, I, H in itertools.product(Ts, Ns, Is, Hs): + fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_tanh', get_next_state=True, prefix='') + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l0_')) + stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l1_')) + stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l2_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write') + check_rnn_consistency(fused, stack, T, N, I, H, 'add') + check_rnn_consistency(fused, stack, T, N, I, H, 'null') @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_rnntanh_bidirectional(): - T, N, I, H = 5, 20, 800, 800 - - fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_tanh', - bidirectional=True, get_next_state=True, prefix='') - - stack = mx.rnn.SequentialRNNCell() - stack.add(mx.rnn.BidirectionalCell( - mx.rnn.RNNCell(H, activation='tanh', prefix='l0_'), - mx.rnn.RNNCell(H, activation='tanh', prefix='r0_'), - output_prefix='bi_rnntanh_0_')) - stack.add(mx.rnn.BidirectionalCell( - mx.rnn.RNNCell(H, activation='tanh', prefix='l1_'), - mx.rnn.RNNCell(H, activation='tanh', prefix='r1_'), - output_prefix='bi_rnntanh_1_')) - - check_rnn_consistency(fused, stack, T, N, I, H, 'write') - check_rnn_consistency(fused, stack, T, N, I, H, 'add') - check_rnn_consistency(fused, stack, T, N, I, H, 'null') + Ts = [1, 5] + Ns = [1, 32] + Is = [32, 128, 512] + Hs = [32, 128, 512] + for T, N, I, H in itertools.product(Ts, Ns, Is, Hs): + fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_tanh', + bidirectional=True, get_next_state=True, prefix='') + + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.RNNCell(H, activation='tanh', prefix='l0_'), + mx.rnn.RNNCell(H, activation='tanh', prefix='r0_'), + output_prefix='bi_rnntanh_0_')) + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.RNNCell(H, activation='tanh', prefix='l1_'), + mx.rnn.RNNCell(H, activation='tanh', prefix='r1_'), + output_prefix='bi_rnntanh_1_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write') + check_rnn_consistency(fused, stack, T, N, I, H, 'add') + check_rnn_consistency(fused, stack, T, N, I, H, 'null') @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_rnnrelu_sym(): - T, N, I, H = 5, 32, 200, 200 - - fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_relu', get_next_state=True, prefix='') - stack = mx.rnn.SequentialRNNCell() - stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l0_')) - stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l1_')) - stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l2_')) - - check_rnn_consistency(fused, stack, T, N, I, H, 'write') - check_rnn_consistency(fused, stack, T, N, I, H, 'add') - check_rnn_consistency(fused, stack, T, N, I, H, 'null') + Ts = [1, 5] + Ns = [1, 32] + Is = [32, 128, 512] + Hs = [32, 128, 512] + for T, N, I, H in itertools.product(Ts, Ns, Is, Hs): + fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_relu', get_next_state=True, prefix='') + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l0_')) + stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l1_')) + stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l2_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write') + check_rnn_consistency(fused, stack, T, N, I, H, 'add') + check_rnn_consistency(fused, stack, T, N, I, H, 'null') @with_seed() 
@assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_rnnrelu_bidirectional(): - T, N, I, H = 5, 20, 200, 200 - - fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_relu', - bidirectional=True, get_next_state=True, prefix='') - - stack = mx.rnn.SequentialRNNCell() - stack.add(mx.rnn.BidirectionalCell( - mx.rnn.RNNCell(H, activation='relu', prefix='l0_'), - mx.rnn.RNNCell(H, activation='relu', prefix='r0_'), - output_prefix='bi_rnnrelu_0_')) - stack.add(mx.rnn.BidirectionalCell( - mx.rnn.RNNCell(H, activation='relu', prefix='l1_'), - mx.rnn.RNNCell(H, activation='relu', prefix='r1_'), - output_prefix='bi_rnnrelu_1_')) - - check_rnn_consistency(fused, stack, T, N, I, H, 'write', rtol=1e-2, atol=1e-2) - check_rnn_consistency(fused, stack, T, N, I, H, 'add', rtol=1e-2, atol=1e-2) - check_rnn_consistency(fused, stack, T, N, I, H, 'null', rtol=1e-2, atol=1e-2) + Ts = [1, 5] + Ns = [1, 32] + Is = [32, 128, 512] + Hs = [32, 128, 512] + for T, N, I, H in itertools.product(Ts, Ns, Is, Hs): + fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_relu', + bidirectional=True, get_next_state=True, prefix='') + + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.RNNCell(H, activation='relu', prefix='l0_'), + mx.rnn.RNNCell(H, activation='relu', prefix='r0_'), + output_prefix='bi_rnnrelu_0_')) + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.RNNCell(H, activation='relu', prefix='l1_'), + mx.rnn.RNNCell(H, activation='relu', prefix='r1_'), + output_prefix='bi_rnnrelu_1_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write', rtol=1e-2, atol=1e-2) + check_rnn_consistency(fused, stack, T, N, I, H, 'add', rtol=1e-2, atol=1e-2) + check_rnn_consistency(fused, stack, T, N, I, H, 'null', rtol=1e-2, atol=1e-2) @with_seed() def test_lstm_dropout():