diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index aaeda76bd459..fa036237c97c 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -158,14 +158,6 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
          (ndim == 1 || ndim == 2 || ndim == 4);
 }
 
-static inline bool SupportMKLDNNRnn(const NDArray &input) {
-  if (input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 3
-      && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
-    return true;
-  }
-  return false;
-}
-
 static inline bool SupportMKLDNNQuantize(int dtype) {
   return dtype == mshadow::kFloat32 || dtype == mshadow::kInt8 ||
          dtype == mshadow::kUint8 || dtype == mshadow::kBfloat16;
diff --git a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
index a4104bf1a437..1d914876506f 100644
--- a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
@@ -441,6 +441,18 @@ class MKLDNNRnnOp {
                 const std::vector<NDArray> &outputs);
 };
 
+inline bool SupportMKLDNNRnn(const int input_dtype) {
+  if (input_dtype == mshadow::kFloat32 && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
+    return true;
+  }
+  return false;
+}
+
+inline bool SupportMKLDNNRnn(const RNNParam &param, const int input_dtype) {
+  if (param.projection_size.has_value()) return false;
+  return SupportMKLDNNRnn(input_dtype);
+}
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc b/src/operator/nn/mkldnn/mkldnn_rnn.cc
index 8af0e997483e..6180e3d8ceb2 100644
--- a/src/operator/nn/mkldnn/mkldnn_rnn.cc
+++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -339,7 +339,6 @@ FUNC(MKLDNN_ARG_DIFF_##NAME, ARGS.at(MKLDNN_ARG_##NAME).get_desc(), HANDLE)
 void MKLDNNRnnForward::SetNewDataMem(void* x, void* hx, void* cx,
                                      void* y, void* hy, void* cy,
                                      const int dtype) {
-  using dims = mkldnn::memory::dims;
   using desc = mkldnn::memory::desc;
   using format_tag = mkldnn::memory::format_tag;
   auto& cpu_engine = CpuEngine::Get()->get_engine();
@@ -632,7 +631,6 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
                        const std::vector<NDArray> &inputs,
                        const std::vector<OpReqType> &req,
                        const std::vector<NDArray> &outputs) {
-  using memory = mkldnn::memory;
   using format_tag = mkldnn::memory::format_tag;
 
   // In the `autograd.record()` context, RNNOp is required to run into
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 3d47f8c0d361..557c1117739a 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -185,6 +185,7 @@ inline int GetRnnBiasSize(int num_layer,
 inline size_t GetRNNWorkspaceSize(int seq_length,
                                   int batch_size,
                                   int hidden_size,
+                                  int projection_size,
                                   int direction,
                                   int mode) {
   size_t size = 0;
@@ -324,6 +325,7 @@ void RNNForwardInference(DType* ws,
                          const int batch_size,
                          const int input_size,
                          const int state_size,
+                         const int projection_size,
                          DType* x_ptr,
                          DType* hx_ptr,
                          DType* cx_ptr,
@@ -336,8 +338,8 @@ void RNNForwardInference(DType* ws,
   switch (mode) {
     case rnn_enum::kLstm:
       LstmForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
-                                  batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
-                                  w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
+                                  batch_size, input_size, state_size, projection_size,
+                                  x_ptr, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
       break;
     case rnn_enum::kGru:
       GruForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
@@ -511,10 +513,7 @@ class RNNOp {
       this->temp_init_space_ = false;
       this->reserve_cpu_space_size_ = 0;
      this->temp_cpu_space_size_ = 0;
-      if (param_.projection_size.has_value()) {
-        LOG(FATAL) <<
-            "hidden layer projection is only supported for GPU with CuDNN later than 7.1.1";
-      }
+
       if (param_.lstm_state_clip_min.has_value()
           || param_.lstm_state_clip_max.has_value()) {
         LOG(FATAL) << "LSTM state clipping is only supported for GPU with CuDNN later than 7.2.1";
@@ -843,9 +842,14 @@ class RNNOp {
 #endif  // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
 
     if (ctx_.dev_type == kCPU) {
+      int projection_size = 0;
+      if (param_.projection_size.has_value()) {
+        projection_size = param_.projection_size.value();
+      }
+
       // allocate temp space
       const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-                                                             param_.state_size, direction, param_.mode);
+                                                             param_.state_size, projection_size, direction, param_.mode);
       if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) {
         temp_cpu_space_size_ = work_cpu_space_size;
         temp_cpu_space_ = NDArray(TShape({static_cast<dim_t>(temp_cpu_space_size_)}), ctx_,
@@ -856,6 +860,9 @@ class RNNOp {
 
       if (ctx.is_train || ctx.need_grad) {
         // allocate reserve space
+        if (param_.projection_size.has_value()) {
+          LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
+        }
         const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
                                                      param_.seq_length_, param_.batch_size_,
@@ -896,6 +903,7 @@ class RNNOp {
                                    param_.batch_size_,
                                    param_.input_size_,
                                    param_.state_size,
+                                   projection_size,
                                    x.dptr_,
                                    hx.dptr_,
                                    cx_ptr,
@@ -1096,10 +1104,17 @@ class RNNOp {
 #endif  // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
 
     if (ctx_.dev_type == kCPU) {
+      int projection_size = 0;
+      if (param_.projection_size.has_value()) {
+        // TODO(zixuanweeei): Add training support for LSTM with projection on CPU.
+        // projection_size = param_.projection_size.value();
+        LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
+      }
+
       // allocate temp space
       const size_t work_cpu_space_size =
-          GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-                              param_.state_size, direction, param_.mode);
+          GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, param_.state_size,
+                              projection_size, direction, param_.mode);
       if (!temp_init_space_ || temp_cpu_space_size_ != work_cpu_space_size) {
         LOG(FATAL) << "Check temp init error";
       }
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index f468b60de744..ac5e17d49133 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -190,20 +190,19 @@ static std::vector<ResourceRequest> RNNResourceEx(const NodeAttrs& attrs, const
   return request;
 }
 
+#if MXNET_USE_MKLDNN == 1
 inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs,
                                   const int dev_mask,
                                   DispatchMode* dispatch_mode,
                                   std::vector<int> *in_attrs,
                                   std::vector<int> *out_attrs) {
-  DispatchMode wanted_mode = DispatchMode::kFCompute;
-
-#if MXNET_USE_MKLDNN == 1
-  wanted_mode = DispatchMode::kFComputeEx;
-#endif  // MXNET_USE_MKLDNN == 1
-
-  return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
-                             dispatch_mode, wanted_mode);
+  const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
+  const bool support_mkldnn_rnn =
+      !param.projection_size.has_value() && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1);
+  return MKLDNNStorageType(attrs, dev_mask, support_mkldnn_rnn,
+                           dispatch_mode, in_attrs, out_attrs);
 }
+#endif  // MXNET_USE_MKLDNN == 1
 
 struct RNNGrad {
   const char *op_name;
@@ -246,9 +245,7 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
   }
 
 #if MXNET_USE_MKLDNN == 1
-  if ((in_types[0] == mshadow::kFloat32 || in_types[0] == mshadow::kFloat16)
-      && in_shapes[0].ndim() == 3 && ctx.dev_type == kCPU
-      && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
+  if (ctx.dev_type == kCPU && SupportMKLDNNRnn(param, in_types[rnn_enum::kData])) {
     const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData];
     state = OpStatePtr::Create<MKLDNNRnnOp>(param, data_shape[0],
                                             data_shape[1], data_shape[2]);
@@ -274,7 +271,7 @@ static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr,
                                     const std::vector<NDArray>& inputs,
                                     const std::vector<OpReqType>& req,
                                     const std::vector<NDArray>& outputs) {
-  if (SupportMKLDNNRnn(inputs[0])) {
+  if (SupportMKLDNNRnn(inputs[rnn_enum::kData].dtype())) {
     MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>();
     op.Forward(ctx, inputs, req, outputs);
   } else {
@@ -287,7 +284,7 @@ static void RNNStatefulGradComputeExCPU(const OpStatePtr& state_ptr,
                                         const std::vector<NDArray>& inputs,
                                         const std::vector<OpReqType>& req,
                                         const std::vector<NDArray>& outputs) {
-  if (SupportMKLDNNRnn(inputs[0])) {
+  if (SupportMKLDNNRnn(inputs[rnn_enum::kData].dtype())) {
     MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>();
     op.Backward(ctx, inputs, req, outputs);
   } else {
@@ -338,6 +335,23 @@ Long Short-Term Memory - Hochreiter, 1997. http://www.bioinf.jku.at/publications
             h_t = o_t * \tanh(c_t)
   \end{array}
 
+When the projection size is set, LSTM uses a projection layer that reduces the parameter count and
+speeds up computation without a significant loss of accuracy.
+
+Long Short-Term Memory Based Recurrent Neural Network Architectures for Large Vocabulary Speech
+Recognition - Sak et al. 2014. https://arxiv.org/abs/1402.1128
+
+.. math::
+  \begin{array}{ll}
+            i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{ri} r_{(t-1)} + b_{ri}) \\
+            f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{rf} r_{(t-1)} + b_{rf}) \\
+            g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rg} r_{(t-1)} + b_{rg}) \\
+            o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ro} r_{(t-1)} + b_{ro}) \\
+            c_t = f_t * c_{(t-1)} + i_t * g_t \\
+            h_t = o_t * \tanh(c_t) \\
+            r_t = W_{hr} h_t
+  \end{array}
+
 **GRU**
 
 Gated Recurrent Unit - Cho et al. 2014. http://arxiv.org/abs/1406.1078
@@ -385,10 +399,10 @@ The definition of GRU here is slightly different from paper but compatible with
 })
 .set_attr<mxnet::FInferShape>("FInferShape", RNNShape)
 .set_attr<nnvm::FInferType>("FInferType", RNNType)
-.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
 .set_attr<FCreateOpState>("FCreateOpState", CreateRNNState)
 .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeExCPU)
 #endif
@@ -427,9 +441,9 @@ NNVM_REGISTER_OP(_backward_RNN)
 .set_attr_parser(ParamParser<RNNParam>)
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
 .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulGradCompute<cpu>)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulGradComputeExCPU)
 #endif
diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h
index 3aa643421857..008ba7d315c6 100644
--- a/src/operator/rnn_impl.h
+++ b/src/operator/rnn_impl.h
@@ -209,6 +209,7 @@ void LstmForwardInferenceSingleLayer(DType* ws,
                                      const int N,
                                      const int I,
                                      const int H,
+                                     const int P,
                                      const Tensor<cpu, 2, DType> &x,
                                      const Tensor<cpu, 2, DType> &hx,
                                      const Tensor<cpu, 2, DType> &cx,
@@ -219,7 +220,9 @@ void LstmForwardInferenceSingleLayer(DType* ws,
                                      DType* cy_ptr) {
   using namespace mshadow;
   const Tensor<cpu, 2, DType> wx(w_ptr, Shape2(H * 4, I));
-  const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, H));
+  const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, (P ? P : H)));
+  Tensor<cpu, 2, DType> whr(w_ptr, Shape2(1, 1));
+  if (P > 0) whr = Tensor<cpu, 2, DType>(wh.dptr_ + P * 4 * H, Shape2(P, H));
   const Tensor<cpu, 2, DType> bx(b_ptr, Shape2(4, H));
   const Tensor<cpu, 2, DType> bh(b_ptr + H * 4, Shape2(4, H));
   Tensor<cpu, 2, DType> yx_flat(ws, Shape2(T * N, H * 4));
@@ -228,7 +231,10 @@ void LstmForwardInferenceSingleLayer(DType* ws,
   const Tensor<cpu, 3, DType> yh(yh_flat.dptr_, Shape3(N, 4, H));
   Tensor<cpu, 2, DType> h(yh_flat.dptr_ + N * H * 4, Shape2(N, H));
   Tensor<cpu, 2, DType> c(h.dptr_ + N * H, Shape2(N, H));
+  Tensor<cpu, 2, DType> r(hy_ptr, Shape2(1, 1));
+  if (P > 0) r = Tensor<cpu, 2, DType>(hy_ptr, Shape2(N, P));
   const int offset = bid ? H : 0;
+  const int proj_offset = bid ? P : 0;
   const DType alpha = 1.0;
   const DType beta = 0.0;
   const int cell_size = N * H;
@@ -237,7 +243,11 @@ void LstmForwardInferenceSingleLayer(DType* ws,
   const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
   for (int i = 0; i < T; ++i) {
     int t = bid ? T - 1 - i : i;
-    linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true);
+    if (P > 0) {
+      linalg_gemm(i ? r : hx, wh, yh_flat, alpha, beta, false, true);
+    } else {
+      linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true);
+    }
     #pragma omp parallel for num_threads(omp_threads)
     for (int jk = 0; jk < cell_size; ++jk) {
       int j = jk / H;
@@ -248,14 +258,21 @@ void LstmForwardInferenceSingleLayer(DType* ws,
       DType ot = sigmoid<DType>(yx[t][j][3][k] + yh[j][3][k] + bx[3][k] + bh[3][k]);
       DType ct = (i ? c[j][k] : cx[j][k]) * ft + it * gt;
       DType ht = ot * tanh(ct);
-      y[t][j][k + offset] = ht;
+      if (P == 0) y[t][j][k + offset] = ht;
       if (i == T - 1 && state_outputs) {
-        hy_ptr[jk] = ht;
+        if (P == 0) hy_ptr[jk] = ht;
         cy_ptr[jk] = ct;
       } else {
-        h[j][k] = ht;
         c[j][k] = ct;
       }
+      h[j][k] = ht;
     }
+    if (P > 0) {
+      linalg_gemm(h, whr, r, alpha, beta, false, true);
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int j = 0; j < N; ++j) {
+        std::memcpy(y[t][j].dptr_ + proj_offset, r[j].dptr_, P * sizeof(DType));
+      }
+    }
   }
 }
@@ -269,6 +286,7 @@ void LstmForwardInference(DType* ws,
                           const int N,
                           const int I,
                           const int H,
+                          const int P,
                           DType* x_ptr,
                           DType* hx_ptr,
                           DType* cx_ptr,
@@ -278,25 +296,29 @@ void LstmForwardInference(DType* ws,
                           DType* hy_ptr,
                           DType* cy_ptr) {
   const int total_layers = D * L;
-  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, H));
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, P ? P : H));
   Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H));
   const int b_size = 2 * H * 4;
   const int cell_size = N * H;
+  const int projection_size = (P ? P : H) * N;
   DType* y_tmp_ptr = ws + (T + 1) * cell_size * 4 + cell_size * 2;
   DType* y_cur_ptr = y_ptr;
   int idx = 0;  // state & cell state's idx;
   bool flag = L % 2 ? false : true;
   for (int i = 0; i < L; ++i) {
-    const int input_size = i ? H * D : I;
-    const int w_size = (input_size + H) * H * 4;
+    const int input_size = i ? (P ? P : H) * D : I;
+    int w_size = (input_size + (P ? P : H)) * H * 4;
+    if (P > 0) {
+      w_size += P * H;
+    }
     // If bidirectional, need space to save current layer output y.
     if (D == 2) {
      y_cur_ptr = flag ? y_tmp_ptr : y_ptr;
      flag = !flag;
    }
     Tensor<cpu, 2, DType> x(x_ptr, Shape2(T * N, input_size));
-    Tensor<cpu, 3, DType> y(y_cur_ptr, Shape3(T, N, H * D));
-    LstmForwardInferenceSingleLayer<DType>(ws, state_outputs, false, T, N, input_size, H,
+    Tensor<cpu, 3, DType> y(y_cur_ptr, Shape3(T, N, (P ? P : H) * D));
+    LstmForwardInferenceSingleLayer<DType>(ws, state_outputs, false, T, N, input_size, H, P,
                                            x, hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr);
     // If bidirectional, then calculate the reverse direction's forward result.
     if (D == 2) {
@@ -304,10 +326,10 @@ void LstmForwardInference(DType* ws,
       b_ptr += b_size;
       ++idx;
       if (state_outputs) {
-        hy_ptr += cell_size;
+        hy_ptr += projection_size;
         cy_ptr += cell_size;
       }
-      LstmForwardInferenceSingleLayer<DType>(ws, state_outputs, true, T, N, input_size, H,
+      LstmForwardInferenceSingleLayer<DType>(ws, state_outputs, true, T, N, input_size, H, P,
                                              x, hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr);
     }
     // Don't need to move pointer in the last layer.
@@ -317,7 +339,7 @@ void LstmForwardInference(DType* ws,
       x_ptr = y_cur_ptr;
       ++idx;
       if (state_outputs) {
-        hy_ptr += cell_size;
+        hy_ptr += projection_size;
         cy_ptr += cell_size;
       }
     }
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 0f27f53f83a8..c18f6fa5f4e7 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -51,6 +51,62 @@ def test_lstm():
     assert outs == [(10, 100), (10, 100), (10, 100)]
 
 
+@with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='7.2.1')
+def test_lstmp():
+    hidden_size, projection_size = 512, 256
+    rtol, atol = 1e-4, 1e-4
+    batch_size, seq_len = 5, 3
+    input_size = 128
+    lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size))
+
+    # ==== Unidirectional Layer ====
+    for num_layers in [1, 3]:
+        fused_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size,
+                                     num_layers=num_layers, layout='TNC', bidirectional=False)
+        fused_layer.collect_params().initialize()
+
+        params = fused_layer.collect_params()
+        stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix='lstm0_', params=params)
+        with stack_layer.name_scope():
+            for i in range(num_layers):
+                stack_layer.add(gluon.contrib.rnn.LSTMPCell(hidden_size,
+                                                            projection_size=projection_size,
+                                                            prefix='l%d_' % i))
+        stack_layer.initialize()
+
+        fused_output = fused_layer(lstm_input.copy())
+        stack_output = stack_layer.unroll(seq_len, lstm_input.copy(), layout='TNC',
+                                          merge_outputs=True)[0]
+
+        assert_almost_equal(fused_output.asnumpy(), stack_output.asnumpy(), rtol=rtol, atol=atol)
+
+    # ==== Bidirectional Layer ====
+    for num_layers in [1, 3]:
+        fused_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size,
+                                     num_layers=num_layers, layout='TNC', bidirectional=True)
+        fused_layer.collect_params().initialize()
+
+        params = fused_layer.collect_params()
+        stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix='lstm0_', params=params)
+        with stack_layer.name_scope():
+            for i in range(num_layers):
+                stack_layer.add(
+                    gluon.rnn.BidirectionalCell(gluon.contrib.rnn.LSTMPCell(hidden_size,
+                                                                            projection_size=projection_size,
+                                                                            prefix='l%d_' % i),
+                                                gluon.contrib.rnn.LSTMPCell(hidden_size,
+                                                                            projection_size=projection_size,
+                                                                            prefix='r%d_' % i)))
+        stack_layer.initialize()
+
+        fused_output = fused_layer(lstm_input.copy())
+        stack_output = stack_layer.unroll(seq_len, lstm_input.copy(), layout='TNC',
+                                          merge_outputs=True)[0]
+
+        assert_almost_equal(fused_output.asnumpy(), stack_output.asnumpy(), rtol=rtol, atol=atol)
+
+
 def test_lstm_forget_bias():
     forget_bias = 2.0
     stack = gluon.rnn.SequentialRNNCell()
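For reference, a minimal usage sketch of the path this patch enables (not part of the patch itself; it assumes a build containing this change, and the layer sizes are arbitrary). With `projection_size` set, `RNNStorageType`/`SupportMKLDNNRnn` skip the MKL-DNN path and the operator dispatches to the native kernel in `rnn_impl.h`, which handles inference only; training with projection on CPU raises a fatal error. The recurrent state carries the projected width `P`, while the cell state keeps the full hidden width `H`:

    # Hypothetical sketch: exercising LSTM-with-projection inference on CPU.
    import mxnet as mx
    from mxnet import gluon

    T, N, C = 3, 5, 128        # seq_len, batch_size, input_size
    H, P, L = 512, 256, 2      # hidden_size, projection_size, num_layers

    layer = gluon.rnn.LSTM(H, projection_size=P, num_layers=L, layout='TNC')
    layer.initialize(ctx=mx.cpu())

    x = mx.nd.uniform(shape=(T, N, C), ctx=mx.cpu())
    h0 = mx.nd.zeros((L, N, P), ctx=mx.cpu())   # recurrent state uses the projected size P
    c0 = mx.nd.zeros((L, N, H), ctx=mx.cpu())   # cell state keeps the full hidden size H

    # projection_size is set, so the MKL-DNN RNN path is skipped and the native
    # LstmForwardInference kernel added above handles the call (inference only on CPU).
    y, (hy, cy) = layer(x, [h0, c0])
    assert y.shape == (T, N, P)
    assert hy.shape == (L, N, P) and cy.shape == (L, N, H)

Setting `MXNET_USE_MKLDNN_RNN=0` forces the same native kernel for the non-projection case as well, which is how the fused output can be compared against the stacked `LSTMPCell` reference in `test_lstmp` above.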