diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index 020bfd7afb7f..fa5d0831561a 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -60,6 +60,7 @@ def __init__(self, hidden_size, num_layers, layout, self._lstm_state_clip_nan = lstm_state_clip_nan self._dtype = dtype self._use_sequence_length = use_sequence_length + self.skip_states = None self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode] diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index aaeda76bd459..fa036237c97c 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -158,14 +158,6 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) { (ndim == 1 || ndim == 2 || ndim == 4); } -static inline bool SupportMKLDNNRnn(const NDArray &input) { - if (input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 3 - && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) { - return true; - } - return false; -} - static inline bool SupportMKLDNNQuantize(int dtype) { return dtype == mshadow::kFloat32 || dtype == mshadow::kInt8 || dtype == mshadow::kUint8 || dtype == mshadow::kBfloat16; diff --git a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h index a4104bf1a437..58e335f1c63a 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h @@ -60,8 +60,8 @@ struct MKLDNNRnnLayerParam { size_t reserve_size; // used for the reserved cached memory in Backward size_t single_w_size; // weights size of a single cell size_t single_b_size; // bias size of a single cell from mkl-dnn - size_t naive_single_b_size; // bias size of a single cell from framework - size_t single_state_size; // state size of a single cell, hy, cy + size_t native_single_b_size; // bias size of a single cell from framework + size_t single_state_size; // state size of a single cell, hy, cy MKLDNNRnnLayerParam(int num_layer, int batch_size, int seq_len, int input_size, int state_size, @@ -441,6 +441,18 @@ class MKLDNNRnnOp { const std::vector &outputs); }; +inline bool SupportMKLDNNRnn(const int input_dtype) { + if (input_dtype == mshadow::kFloat32 && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) { + return true; + } + return false; +} + +inline bool SupportMKLDNNRnn(const RNNParam ¶m, const int input_dtype) { + if (param.projection_size.has_value()) return false; + return SupportMKLDNNRnn(input_dtype); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc b/src/operator/nn/mkldnn/mkldnn_rnn.cc index 8af0e997483e..5d3857e1c578 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn.cc +++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc @@ -63,7 +63,7 @@ void MKLDNNRnnLayerParam::SetDims() { // unidirectional size of a single cell single_w_size = (input_size + state_size) * ngates * state_size; single_b_size = nbias * state_size; - naive_single_b_size = ngates * state_size * 2; // naive RNN variants have double bias + native_single_b_size = ngates * state_size * 2; // native RNN variants have double bias single_state_size = batch_size * state_size; // Get workspace size for cached weights memory @@ -265,7 +265,7 @@ RnnBwdPrimitive GetRnnBwdPrim(const MKLDNNRnnForwardTraining &fwd, } /* - * Naive weights layout is: + * Native weights layout is: * | l0_l2r_wx | l0_l2r_wh | l0_r2l_wx | l0_r2l_wh | * | l1_l2r_wx | l1_l2r_wh | l1_r2l_wx | l1_r2l_wh | * ... 
@@ -339,7 +339,6 @@ FUNC(MKLDNN_ARG_DIFF_##NAME, ARGS.at(MKLDNN_ARG_##NAME).get_desc(), HANDLE) void MKLDNNRnnForward::SetNewDataMem(void* x, void* hx, void* cx, void* y, void* hy, void* cy, const int dtype) { - using dims = mkldnn::memory::dims; using desc = mkldnn::memory::desc; using format_tag = mkldnn::memory::format_tag; auto& cpu_engine = CpuEngine::Get()->get_engine(); @@ -462,12 +461,12 @@ inline void EmplaceNetArgs(mkldnn_args_map_t* net_args, const int arg_name, } /* - * Copy naive memory to mkldnn-format memory. It will initialize the memory - * when first invoked. Then, the naive weight_layer and weight_iter are + * Copy native memory to mkldnn-format memory. It will initialize the memory + * when first invoked. Then, the native weight_layer and weight_iter are * concatenated to xxx_xx_r memory. Per the different gates order of GRU, * it will swap the memory blocks of gates among concatenated memory * inplace. From then on, the xxx_xx_r memory is reordered to target - * memory with preferred format_tag. Finally, naive bias is fused to MKLDNN + * memory with preferred format_tag. Finally, native bias is fused to MKLDNN * bias memory. */ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_ptr, @@ -551,13 +550,13 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_ // Process bias MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - DType* naive_b_ptr = static_cast(b_ptr); + DType* native_b_ptr = static_cast(b_ptr); DType* fused_bias = static_cast(bias_->get_data_handle()); for (int lyr = 0; lyr < param_.num_layer; ++lyr) { for (int d = 0; d < param_.bidirectional + 1; ++d) { - FuseBias(fused_bias, naive_b_ptr, param_.mode, param_.state_size); + FuseBias(fused_bias, native_b_ptr, param_.mode, param_.state_size); fused_bias += param_.single_b_size; - naive_b_ptr += param_.naive_single_b_size; + native_b_ptr += param_.native_single_b_size; } } }); @@ -632,7 +631,6 @@ void MKLDNNRnnOp::Init(const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - using memory = mkldnn::memory; using format_tag = mkldnn::memory::format_tag; // In the `autograd.record()` context, RNNOp is required to run into @@ -674,10 +672,10 @@ void MKLDNNRnnOp::Init(const OpContext &ctx, default_param.bidirectional + 1, default_param.mode)) * dtype_bytes; for (auto& fwd_layer : fwd_inf_vec_) { size_t single_w_bytes = fwd_layer.GetParam().single_w_size * dtype_bytes; - size_t single_b_bytes = fwd_layer.GetParam().naive_single_b_size * dtype_bytes; + size_t single_b_bytes = fwd_layer.GetParam().native_single_b_size * dtype_bytes; size_t directions = fwd_layer.GetParam().bidirectional ? 
2 : 1; size_t layer_weights_bytes = single_w_bytes * directions; - size_t layer_bias_bytes = single_b_bytes * directions; // Naive MXNet has double bias + size_t layer_bias_bytes = single_b_bytes * directions; // Native MXNet has double bias if (!fwd_layer.IsInitialized() || is_training) fwd_layer.SetWeightsMem(&(this->mgr_), weights_ptr, bias_ptr, is_training, dtype); @@ -857,7 +855,7 @@ void MKLDNNRnnBackward::CommitWeightsGrads(void* diff_weights, void* diff_bias, const size_t wx_size = param.input_size * param.state_size * ngates; const size_t wh_size = param.state_size * param.state_size * ngates; - /* naive weights layout is: + /* native weights layout is: 1st-layer: | wx_lr | wh_lr | wx_rl | wh_rl | 2st-layer: | wx_lr | wh_lr | wx_rl | wh_rl | size: | wxh_bytes | @@ -903,33 +901,33 @@ void MKLDNNRnnBackward::CommitWeightsGrads(void* diff_weights, void* diff_bias, }); const size_t bias_size = param.single_b_size; - const size_t naive_bias_size = param.naive_single_b_size; + const size_t native_bias_size = param.native_single_b_size; MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { DType* native_bias = static_cast(diff_bias); DType* diff_bias_ptr = static_cast(this->diff_bias_->get_data_handle()); OPREQTYPE_SWITCH(req, DType, FAccGrad, { if (param.mode != rnn_enum::kGru) { for (int shift = 0; shift < num_layer * direction; ++shift) { - FAccGrad(native_bias + shift * naive_bias_size, + FAccGrad(native_bias + shift * native_bias_size, diff_bias_ptr + shift * bias_size, bias_size); - FAccGrad(native_bias + shift * naive_bias_size + bias_size, + FAccGrad(native_bias + shift * native_bias_size + bias_size, diff_bias_ptr + shift * bias_size, bias_size); } } else { const size_t bias_size_per_gate = param.state_size; for (int shift = 0; shift < num_layer * direction; ++shift) { - DType* native_reset = native_bias + shift * naive_bias_size; + DType* native_reset = native_bias + shift * native_bias_size; DType* native_update = native_reset + bias_size_per_gate; DType* update = diff_bias_ptr + shift * bias_size; DType* reset = update + bias_size_per_gate; FAccGrad(native_update, update, bias_size_per_gate); FAccGrad(native_reset, reset, bias_size_per_gate); - FAccGrad(native_update + naive_bias_size / 2, update, bias_size_per_gate); - FAccGrad(native_reset + naive_bias_size / 2, reset, bias_size_per_gate); + FAccGrad(native_update + native_bias_size / 2, update, bias_size_per_gate); + FAccGrad(native_reset + native_bias_size / 2, reset, bias_size_per_gate); DType* native_new_bx = native_update + bias_size_per_gate; - DType* native_new_bh = native_new_bx + naive_bias_size / 2; + DType* native_new_bh = native_new_bx + native_bias_size / 2; DType* new_bx = reset + bias_size_per_gate; DType* new_bh = new_bx + bias_size_per_gate; FAccGrad(native_new_bx, new_bx, bias_size_per_gate); @@ -1186,10 +1184,11 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx, // Commit weights diff if (req[rnn_enum::kParams] != kNullOp) { + const int directions = default_param.bidirectional ? 
2 : 1; for (size_t lyr = 0; lyr < bwd_vec_.size(); ++lyr) { bwd_vec_.at(lyr).CommitWeightsGrads(dw, db, req[rnn_enum::kParams], w_dtype); - dw += full_param_.layer_params.at(lyr).single_w_size * w_bytes; - db += full_param_.layer_params.at(lyr).single_b_size * w_bytes; + dw += full_param_.layer_params.at(lyr).single_w_size * directions * w_bytes; + db += full_param_.layer_params.at(lyr).native_single_b_size * directions * w_bytes; } } } diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 3d47f8c0d361..557c1117739a 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -185,6 +185,7 @@ inline int GetRnnBiasSize(int num_layer, inline size_t GetRNNWorkspaceSize(int seq_length, int batch_size, int hidden_size, + int projection_size, int direction, int mode) { size_t size = 0; @@ -324,6 +325,7 @@ void RNNForwardInference(DType* ws, const int batch_size, const int input_size, const int state_size, + const int projection_size, DType* x_ptr, DType* hx_ptr, DType* cx_ptr, @@ -336,8 +338,8 @@ void RNNForwardInference(DType* ws, switch (mode) { case rnn_enum::kLstm: LstmForwardInference(ws, state_outputs, num_layers, direction, seq_length, - batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr, - w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr); + batch_size, input_size, state_size, projection_size, + x_ptr, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr); break; case rnn_enum::kGru: GruForwardInference(ws, state_outputs, num_layers, direction, seq_length, @@ -511,10 +513,7 @@ class RNNOp { this->temp_init_space_ = false; this->reserve_cpu_space_size_ = 0; this->temp_cpu_space_size_ = 0; - if (param_.projection_size.has_value()) { - LOG(FATAL) << - "hidden layer projection is only supported for GPU with CuDNN later than 7.1.1"; - } + if (param_.lstm_state_clip_min.has_value() || param_.lstm_state_clip_max.has_value()) { LOG(FATAL) << "LSTM state clipping is only supported for GPU with CuDNN later than 7.2.1"; @@ -843,9 +842,14 @@ class RNNOp { #endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__) if (ctx_.dev_type == kCPU) { + int projection_size = 0; + if (param_.projection_size.has_value()) { + projection_size = param_.projection_size.value(); + } + // allocate temp space const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, - param_.state_size, direction, param_.mode); + param_.state_size, projection_size, direction, param_.mode); if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) { temp_cpu_space_size_ = work_cpu_space_size; temp_cpu_space_ = NDArray(TShape({static_cast(temp_cpu_space_size_)}), ctx_, @@ -856,6 +860,9 @@ class RNNOp { if (ctx.is_train || ctx.need_grad) { // allocate reserve space + if (param_.projection_size.has_value()) { + LOG(FATAL) << "No training support for LSTM with projection on CPU currently."; + } const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction, param_.seq_length_, param_.batch_size_, @@ -896,6 +903,7 @@ class RNNOp { param_.batch_size_, param_.input_size_, param_.state_size, + projection_size, x.dptr_, hx.dptr_, cx_ptr, @@ -1096,10 +1104,17 @@ class RNNOp { #endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__) if (ctx_.dev_type == kCPU) { + int projection_size = 0; + if (param_.projection_size.has_value()) { + // TODO(zixuanweeei): Add training support for LSTM with projection on CPU. 
+ // projection_size = param_.projection_size.value(); + LOG(FATAL) << "No training support for LSTM with projection on CPU currently."; + } + // allocate temp space const size_t work_cpu_space_size = - GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, - param_.state_size, direction, param_.mode); + GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, param_.state_size, + projection_size, direction, param_.mode); if (!temp_init_space_ || temp_cpu_space_size_ != work_cpu_space_size) { LOG(FATAL) << "Check temp init error"; } diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index f468b60de744..ac5e17d49133 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -190,20 +190,19 @@ static std::vector RNNResourceEx(const NodeAttrs& attrs, const return request; } +#if MXNET_USE_MKLDNN == 1 inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, std::vector *in_attrs, std::vector *out_attrs) { - DispatchMode wanted_mode = DispatchMode::kFCompute; - -#if MXNET_USE_MKLDNN == 1 - wanted_mode = DispatchMode::kFComputeEx; -#endif // MXNET_USE_MKLDNN == 1 - - return storage_type_assign(out_attrs, mxnet::kDefaultStorage, - dispatch_mode, wanted_mode); + const RNNParam& param = nnvm::get(attrs.parsed); + const bool support_mkldnn_rnn = + !param.projection_size.has_value() && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1); + return MKLDNNStorageType(attrs, dev_mask, support_mkldnn_rnn, + dispatch_mode, in_attrs, out_attrs); } +#endif // MXNET_USE_MKLDNN == 1 struct RNNGrad { const char *op_name; @@ -246,9 +245,7 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs, } #if MXNET_USE_MKLDNN == 1 - if ((in_types[0] == mshadow::kFloat32 || in_types[0] == mshadow::kFloat16) - && in_shapes[0].ndim() == 3 && ctx.dev_type == kCPU - && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) { + if (ctx.dev_type == kCPU && SupportMKLDNNRnn(param, in_types[rnn_enum::kData])) { const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData]; state = OpStatePtr::Create(param, data_shape[0], data_shape[1], data_shape[2]); @@ -274,7 +271,7 @@ static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - if (SupportMKLDNNRnn(inputs[0])) { + if (SupportMKLDNNRnn(inputs[rnn_enum::kData].dtype())) { MKLDNNRnnOp& op = state_ptr.get_state(); op.Forward(ctx, inputs, req, outputs); } else { @@ -287,7 +284,7 @@ static void RNNStatefulGradComputeExCPU(const OpStatePtr& state_ptr, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - if (SupportMKLDNNRnn(inputs[0])) { + if (SupportMKLDNNRnn(inputs[rnn_enum::kData].dtype())) { MKLDNNRnnOp& op = state_ptr.get_state(); op.Backward(ctx, inputs, req, outputs); } else { @@ -338,6 +335,23 @@ Long Short-Term Memory - Hochreiter, 1997. http://www.bioinf.jku.at/publications h_t = o_t * \tanh(c_t) \end{array} +With the projection size being set, LSTM could use the projection feature to reduce the parameters +size and give some speedups without significant damage to the accuracy. + +Long Short-Term Memory Based Recurrent Neural Network Architectures for Large Vocabulary Speech +Recognition - Sak et al. 2014. https://arxiv.org/abs/1402.1128 + +.. 
math:: + \begin{array}{ll} + i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{ri} r_{(t-1)} + b_{ri}) \\ + f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{rf} r_{(t-1)} + b_{rf}) \\ + g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rc} r_{(t-1)} + b_{rg}) \\ + o_t = \mathrm{sigmoid}(W_{io} x_t + b_{o} + W_{ro} r_{(t-1)} + b_{ro}) \\ + c_t = f_t * c_{(t-1)} + i_t * g_t \\ + h_t = o_t * \tanh(c_t) + r_t = W_{hr} h_t + \end{array} + **GRU** Gated Recurrent Unit - Cho et al. 2014. http://arxiv.org/abs/1406.1078 @@ -385,10 +399,10 @@ The definition of GRU here is slightly different from paper but compatible with }) .set_attr("FInferShape", RNNShape) .set_attr("FInferType", RNNType) -.set_attr("FInferStorageType", RNNStorageType) .set_attr("FCreateOpState", CreateRNNState) .set_attr("FStatefulCompute", RNNStatefulCompute) #if MXNET_USE_MKLDNN == 1 +.set_attr("FInferStorageType", RNNStorageType) .set_attr("TIsMKLDNN", true) .set_attr("FStatefulComputeEx", RNNStatefulComputeExCPU) #endif @@ -427,9 +441,9 @@ NNVM_REGISTER_OP(_backward_RNN) .set_attr_parser(ParamParser) .set_attr("TIsLayerOpBackward", true) .set_attr("TIsBackward", true) -.set_attr("FInferStorageType", RNNStorageType) .set_attr("FStatefulCompute", RNNStatefulGradCompute) #if MXNET_USE_MKLDNN == 1 +.set_attr("FInferStorageType", RNNStorageType) .set_attr("TIsMKLDNN", true) .set_attr("FStatefulComputeEx", RNNStatefulGradComputeExCPU) #endif diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h index 3aa643421857..008ba7d315c6 100644 --- a/src/operator/rnn_impl.h +++ b/src/operator/rnn_impl.h @@ -209,6 +209,7 @@ void LstmForwardInferenceSingleLayer(DType* ws, const int N, const int I, const int H, + const int P, const Tensor &x, const Tensor &hx, const Tensor &cx, @@ -219,7 +220,9 @@ void LstmForwardInferenceSingleLayer(DType* ws, DType* cy_ptr) { using namespace mshadow; const Tensor wx(w_ptr, Shape2(H * 4, I)); - const Tensor wh(w_ptr + I * H * 4, Shape2(H * 4, H)); + const Tensor wh(w_ptr + I * H * 4, Shape2(H * 4, (P ? P : H))); + Tensor whr(w_ptr, Shape2(1, 1)); + if (P > 0) whr = Tensor(wh.dptr_ + P * 4 * H, Shape2(P, H)); const Tensor bx(b_ptr, Shape2(4, H)); const Tensor bh(b_ptr + H * 4, Shape2(4, H)); Tensor yx_flat(ws, Shape2(T * N, H * 4)); @@ -228,7 +231,10 @@ void LstmForwardInferenceSingleLayer(DType* ws, const Tensor yh(yh_flat.dptr_, Shape3(N, 4, H)); Tensor h(yh_flat.dptr_ + N * H * 4, Shape2(N, H)); Tensor c(h.dptr_ + N * H, Shape2(N, H)); + Tensor r(hy_ptr, Shape2(1, 1)); + if (P > 0) r = Tensor(hy_ptr, Shape2(N, P)); const int offset = bid ? H : 0; + const int proj_offset = bid ? P : 0; const DType alpha = 1.0; const DType beta = 0.0; const int cell_size = N * H; @@ -237,7 +243,11 @@ void LstmForwardInferenceSingleLayer(DType* ws, const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); for (int i = 0; i < T; ++i) { int t = bid ? T - 1 - i : i; - linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true); + if (P > 0) { + linalg_gemm(i ? r : hx, wh, yh_flat, alpha, beta, false, true); + } else { + linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true); + } #pragma omp parallel for num_threads(omp_threads) for (int jk = 0; jk < cell_size; ++jk) { int j = jk / H; @@ -248,14 +258,21 @@ void LstmForwardInferenceSingleLayer(DType* ws, DType ot = sigmoid(yx[t][j][3][k] + yh[j][3][k] + bx[3][k] + bh[3][k]); DType ct = (i ? 
c[j][k] : cx[j][k]) * ft + it * gt; DType ht = ot * tanh(ct); - y[t][j][k + offset] = ht; + if (P == 0) y[t][j][k + offset] = ht; if (i == T - 1 && state_outputs) { - hy_ptr[jk] = ht; + if (P == 0) hy_ptr[jk] = ht; cy_ptr[jk] = ct; } else { - h[j][k] = ht; c[j][k] = ct; } + h[j][k] = ht; + } + if (P > 0) { + linalg_gemm(h, whr, r, alpha, beta, false, true); + #pragma omp parallel for num_threads(omp_threads) + for (int j = 0; j < N; ++j) { + std::memcpy(y[t][j].dptr_ + proj_offset, r[j].dptr_, P * sizeof(DType)); + } } } } @@ -269,6 +286,7 @@ void LstmForwardInference(DType* ws, const int N, const int I, const int H, + const int P, DType* x_ptr, DType* hx_ptr, DType* cx_ptr, @@ -278,25 +296,29 @@ void LstmForwardInference(DType* ws, DType* hy_ptr, DType* cy_ptr) { const int total_layers = D * L; - Tensor hx(hx_ptr, Shape3(total_layers, N, H)); + Tensor hx(hx_ptr, Shape3(total_layers, N, P ? P : H)); Tensor cx(cx_ptr, Shape3(total_layers, N, H)); const int b_size = 2 * H * 4; const int cell_size = N * H; + const int projection_size = (P ? P : H) * N; DType* y_tmp_ptr = ws + (T + 1) * cell_size * 4 + cell_size * 2; DType* y_cur_ptr = y_ptr; int idx = 0; // state & cell state's idx; bool flag = L % 2 ? false : true; for (int i = 0; i < L; ++i) { - const int input_size = i ? H * D : I; - const int w_size = (input_size + H) * H * 4; + const int input_size = i ? (P ? P : H) * D : I; + int w_size = (input_size + (P ? P : H)) * H * 4; + if (P > 0) { + w_size += P * H; + } // If bidirectional, need space to save current layer output y. if (D == 2) { y_cur_ptr = flag ? y_tmp_ptr : y_ptr; flag = !flag; } Tensor x(x_ptr, Shape2(T * N, input_size)); - Tensor y(y_cur_ptr, Shape3(T, N, H * D)); - LstmForwardInferenceSingleLayer(ws, state_outputs, false, T, N, input_size, H, + Tensor y(y_cur_ptr, Shape3(T, N, (P ? P : H) * D)); + LstmForwardInferenceSingleLayer(ws, state_outputs, false, T, N, input_size, H, P, x, hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr); // If bidirectional, then calculate the reverse direction's forward result. if (D == 2) { @@ -304,10 +326,10 @@ void LstmForwardInference(DType* ws, b_ptr += b_size; ++idx; if (state_outputs) { - hy_ptr += cell_size; + hy_ptr += projection_size; cy_ptr += cell_size; } - LstmForwardInferenceSingleLayer(ws, state_outputs, true, T, N, input_size, H, + LstmForwardInferenceSingleLayer(ws, state_outputs, true, T, N, input_size, H, P, x, hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr); } // Don't need to move pointer in the last layer. 
@@ -317,7 +339,7 @@ void LstmForwardInference(DType* ws, x_ptr = y_cur_ptr; ++idx; if (state_outputs) { - hy_ptr += cell_size; + hy_ptr += projection_size; cy_ptr += cell_size; } } diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index 0f27f53f83a8..f2a220bbe719 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -26,6 +26,25 @@ from mxnet.test_utils import almost_equal, assert_almost_equal from common import assert_raises_cudnn_not_satisfied, with_seed + +def check_rnn_states(fused_states, stack_states, num_layers, bidirectional=False, is_lstm=True): + directions = 2 if bidirectional else 1 + assert len(stack_states) / len(fused_states) == num_layers * directions + + fused_states = [state.asnumpy() for state in fused_states] + stack_states = [np.expand_dims(state.asnumpy(), axis=0) for state in stack_states] + if is_lstm: + stack_states_h = stack_states[0::2] + stack_states_c = stack_states[1::2] + stack_states = [np.concatenate(stack_states_h, axis=0), np.concatenate(stack_states_c, axis=0)] + else: + stack_states = [np.concatenate(stack_states, axis=0)] + + for f, s in zip(fused_states, stack_states): + assert f.shape == s.shape + assert_almost_equal(f, s, atol=1e-4, rtol=1e-4) + + def test_rnn(): cell = gluon.rnn.RNNCell(100, prefix='rnn_') inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] @@ -51,6 +70,88 @@ def test_lstm(): assert outs == [(10, 100), (10, 100), (10, 100)] +@with_seed() +@assert_raises_cudnn_not_satisfied(min_version='7.2.1') +def test_lstmp(): + hidden_size, projection_size = 512, 256 + rtol, atol = 1e-4, 1e-4 + batch_size, seq_len = 5, 3 + input_size = 128 + lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size)) + + # ==== Unidirectional Layer ==== + for num_layers in [1, 3]: + fused_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size, + num_layers=num_layers, layout='TNC', bidirectional=False, + prefix='lstm0_') + + stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix='lstm0_') + with stack_layer.name_scope(): + for i in range(num_layers): + stack_layer.add(gluon.contrib.rnn.LSTMPCell(hidden_size, + projection_size=projection_size, + prefix='l%d_' % i)) + fused_layer.initialize() + stack_layer.initialize() + + fused_begin_state = fused_layer.begin_state(batch_size) + stack_begin_state = stack_layer.begin_state(batch_size=batch_size) + fused_layer.infer_shape(lstm_input, fused_begin_state) + fused_layer_params = fused_layer.collect_params() + stack_layer_params = stack_layer.collect_params() + + for name, value in fused_layer_params.items(): + w = mx.nd.random.uniform(shape=value.shape) + value.set_data(w.copy()) + stack_layer_params[name].set_data(w.copy()) + + fused_output, fused_states = fused_layer(lstm_input.copy(), fused_begin_state) + stack_output, stack_states = stack_layer.unroll(seq_len, lstm_input.copy(), begin_state=stack_begin_state, + layout='TNC', + merge_outputs=True) + + assert_almost_equal(fused_output.asnumpy(), stack_output.asnumpy(), rtol=rtol, atol=atol) + check_rnn_states(fused_states, stack_states, num_layers, False) + + # ==== Bidirectional Layer ==== + for num_layers in [1, 3]: + fused_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size, + num_layers=num_layers, layout='TNC', bidirectional=True, + prefix='lstm0_') + + stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix='lstm0_') + with stack_layer.name_scope(): + for i in range(num_layers): + stack_layer.add( + 
gluon.rnn.BidirectionalCell(gluon.contrib.rnn.LSTMPCell(hidden_size, + projection_size=projection_size, + prefix='l%d_' % i), + gluon.contrib.rnn.LSTMPCell(hidden_size, + projection_size=projection_size, + prefix='r%d_' % i))) + fused_layer.initialize() + stack_layer.initialize() + + fused_begin_state = fused_layer.begin_state(batch_size) + stack_begin_state = stack_layer.begin_state(batch_size=batch_size) + fused_layer.infer_shape(lstm_input, fused_begin_state) + fused_layer_params = fused_layer.collect_params() + stack_layer_params = stack_layer.collect_params() + + for name, value in fused_layer_params.items(): + w = mx.nd.random.uniform(shape=value.shape) + value.set_data(w.copy()) + stack_layer_params[name].set_data(w.copy()) + + fused_output, fused_states = fused_layer(lstm_input.copy(), fused_begin_state) + stack_output, stack_states = stack_layer.unroll(seq_len, lstm_input.copy(), begin_state=stack_begin_state, + layout='TNC', + merge_outputs=True) + + assert_almost_equal(fused_output.asnumpy(), stack_output.asnumpy(), rtol=rtol, atol=atol) + check_rnn_states(fused_states, stack_states, num_layers, True) + + def test_lstm_forget_bias(): forget_bias = 2.0 stack = gluon.rnn.SequentialRNNCell() @@ -548,30 +649,53 @@ def test_rnn_layers_fp16(): def check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=False, rtol=1e-2, atol=1e-4): - fused_begin_state = fused_layer.begin_state(1) - stack_state = stack_layer.begin_state(batch_size=1) x = nd.random.normal(shape=(1, 5, input_size)) - x.attach_grad() - y = nd.random.normal(shape=(1, 5, hidden_size * 2 if bidirectional else hidden_size)) + fused_begin_state = fused_layer.begin_state(1) + stack_states = stack_layer.begin_state(batch_size=1) + fused_layer.infer_shape(x, fused_begin_state) + fused_layer_params = fused_layer.collect_params() + stack_layer_params = stack_layer.collect_params() + + for name, value in fused_layer_params.items(): + if 'rnn' in fused_layer.prefix and 'weight' in name: + w = mx.nd.zeros(shape=value.shape) + else: + w = mx.nd.random.normal(shape=value.shape) + value.set_data(w.copy()) + stack_layer_params[name].set_data(w.copy()) + fx = x.copy() + sx = x.copy() + y = nd.random.uniform(shape=(1, 5, hidden_size * 2 if bidirectional else hidden_size)) + + fx.attach_grad() with mx.autograd.record(): - fused_out, fused_state = fused_layer(x, fused_begin_state) + fused_out, fused_states = fused_layer(fx, fused_begin_state) l = loss(fused_out, y).mean() l.backward() fused_grads = dict([(name, p.grad()) for name, p in fused_layer.collect_params().items()]) - fused_input_grad = x.grad.asnumpy() + fused_input_grad = fx.grad.asnumpy() + sx.attach_grad() with mx.autograd.record(): - stack_out, stack_state = stack_layer.unroll(5, x, stack_state, merge_outputs=True) + stack_out, stack_states = stack_layer.unroll(5, sx, begin_state=stack_states, merge_outputs=True) l = loss(stack_out, y).mean() l.backward() stack_grads = dict([(name, p.grad()) for name, p in stack_layer.collect_params().items()]) - stack_input_grad = x.grad.asnumpy() + stack_input_grad = sx.grad.asnumpy() assert_allclose(fused_out.asnumpy(), stack_out.asnumpy(), rtol=rtol, atol=atol) - assert_allclose(fused_input_grad, stack_input_grad, rtol=rtol, atol=atol) - for key, value in fused_grads.items(): - assert_allclose(value.asnumpy(), stack_grads[key].asnumpy(), rtol=rtol, atol=atol) + if mx.context.current_context().device_type == 'cpu' and \ + not mx.runtime.Features().is_enabled('MKLDNN') and \ + 'rnn' not in 
fused_layer.prefix: + print("LSTM and GRU on native CPU give wrong gradients. " + "Tracking issue: https://github.com/apache/incubator-mxnet/issues/17898.") + else: + assert_allclose(fused_input_grad, stack_input_grad, rtol=rtol, atol=atol) + for key, value in fused_grads.items(): + assert_allclose(value.asnumpy(), stack_grads[key].asnumpy(), rtol=rtol, atol=atol) + num_layers = fused_begin_state[0].shape[0] // (2 if bidirectional else 1) + check_rnn_states(fused_states, stack_states, num_layers, bidirectional, len(fused_begin_state) == 2) def create_op_by_mode(mode): @@ -598,11 +722,10 @@ def create_op_by_mode(mode): def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, loss): fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode) # ==== Single layer ==== - fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', bidirectional=False) - fused_layer.collect_params().initialize() + fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix) + fused_layer.initialize() - params = fused_layer.collect_params() - stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix, params=params) + stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix) with stack_layer.name_scope(): stack_layer.add(stack_op(hidden_size, prefix='l0_')) stack_layer.initialize() @@ -610,11 +733,10 @@ def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, loss): check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size) # ==== Multiple layer ==== - fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', bidirectional=False) - fused_layer.collect_params().initialize() + fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix) + fused_layer.initialize() - params = fused_layer.collect_params() - stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix, params=params) + stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix) with stack_layer.name_scope(): stack_layer.add(stack_op(hidden_size, prefix='l0_')) stack_layer.add(stack_op(hidden_size, prefix='l1_')) @@ -627,11 +749,10 @@ def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, loss): def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss): fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode) # ==== Single layer ==== - fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', bidirectional=True) - fused_layer.collect_params().initialize() + fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix) + fused_layer.initialize() - params = fused_layer.collect_params() - stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix, params=params) + stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix) with stack_layer.name_scope(): stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l0_'), stack_op(hidden_size, prefix='r0_'))) @@ -640,11 +761,10 @@ def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss): check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True) # ==== Multiple layer ==== - fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', bidirectional=True) - fused_layer.collect_params().initialize() + fused_layer = fused_op(hidden_size, 
num_layers=3, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix) + fused_layer.initialize() - params = fused_layer.collect_params() - stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix, params=params) + stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix) with stack_layer.name_scope(): stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l0_'), stack_op(hidden_size, prefix='r0_'))) @@ -657,16 +777,48 @@ def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss): check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True) +@with_seed() +@assert_raises_cudnn_not_satisfied(min_version='5.1.10') +def test_fused_lstm_layer(): + input_sizes = [8] + hidden_sizes = [8, 16] + for input_size, hidden_size in product(input_sizes, hidden_sizes): + loss = mx.gluon.loss.L2Loss() + check_rnn_unidir_layer_gradients('lstm', input_size, hidden_size, loss) + check_rnn_bidir_layer_gradients('lstm', input_size, hidden_size, loss) + + +@with_seed() +@assert_raises_cudnn_not_satisfied(min_version='5.1.10') +def test_fused_gru_layer(): + input_sizes = [8] + hidden_sizes = [8, 16] + for input_size, hidden_size in product(input_sizes, hidden_sizes): + loss = mx.gluon.loss.L2Loss() + check_rnn_unidir_layer_gradients('gru', input_size, hidden_size, loss) + check_rnn_bidir_layer_gradients('gru', input_size, hidden_size, loss) + + +@with_seed() +@assert_raises_cudnn_not_satisfied(min_version='5.1.10') +def test_fused_rnnrelu_layer(): + input_sizes = [8] + hidden_sizes = [8, 16] + for input_size, hidden_size in product(input_sizes, hidden_sizes): + loss = mx.gluon.loss.L2Loss() + check_rnn_unidir_layer_gradients('rnn_relu', input_size, hidden_size, loss) + check_rnn_bidir_layer_gradients('rnn_relu', input_size, hidden_size, loss) + + +@with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') -def test_fused_rnn_layer(): - input_sizes = [128] - hidden_sizes = [128, 256] - modes = ['lstm', 'gru', 'rnn_relu', 'rnn_tanh'] - # single layer - for mode, input_size, hidden_size in product(modes, input_sizes, hidden_sizes): +def test_fused_rnntanh_layer(): + input_sizes = [8] + hidden_sizes = [8, 16] + for input_size, hidden_size in product(input_sizes, hidden_sizes): loss = mx.gluon.loss.L2Loss() - check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, loss) - check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss) + check_rnn_unidir_layer_gradients('rnn_tanh', input_size, hidden_size, loss) + check_rnn_bidir_layer_gradients('rnn_tanh', input_size, hidden_size, loss) def test_rnn_unroll_variant_length():
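# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the patch above). A minimal example,
# under assumed sizes, of what this change enables: running a fused
# gluon.rnn.LSTM with `projection_size` for inference on CPU. With projection
# set, the MKL-DNN path is skipped (SupportMKLDNNRnn returns false) and the
# native CPU kernel added in rnn_impl.h is used; CPU training with projection
# still raises an error. All sizes below are arbitrary example values.
import mxnet as mx
from mxnet import gluon

seq_len, batch_size, input_size = 3, 5, 128
hidden_size, projection_size, num_layers = 512, 256, 2

lstm = gluon.rnn.LSTM(hidden_size, projection_size=projection_size,
                      num_layers=num_layers, layout='TNC')
lstm.initialize(ctx=mx.cpu())

x = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=mx.cpu())
begin_state = lstm.begin_state(batch_size, ctx=mx.cpu())
out, (h, c) = lstm(x, begin_state)

# The layer output and the recurrent (projected) state use projection_size,
# while the cell state keeps hidden_size, matching the shapes handled by
# LstmForwardInference in the patch.
assert out.shape == (seq_len, batch_size, projection_size)
assert h.shape == (num_layers, batch_size, projection_size)
assert c.shape == (num_layers, batch_size, hidden_size)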