This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Backport #17702 and #17872 to v1.x branch #18038

Merged: 2 commits, Apr 15, 2020
1 change: 1 addition & 0 deletions python/mxnet/gluon/rnn/rnn_layer.py
@@ -60,6 +60,7 @@ def __init__(self, hidden_size, num_layers, layout,
self._lstm_state_clip_nan = lstm_state_clip_nan
self._dtype = dtype
self._use_sequence_length = use_sequence_length
self.skip_states = None

self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

8 changes: 0 additions & 8 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -158,14 +158,6 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
(ndim == 1 || ndim == 2 || ndim == 4);
}

static inline bool SupportMKLDNNRnn(const NDArray &input) {
if (input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 3
&& dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
return true;
}
return false;
}

static inline bool SupportMKLDNNQuantize(int dtype) {
return dtype == mshadow::kFloat32 || dtype == mshadow::kInt8 ||
dtype == mshadow::kUint8 || dtype == mshadow::kBfloat16;
16 changes: 14 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_rnn-inl.h
@@ -60,8 +60,8 @@ struct MKLDNNRnnLayerParam {
size_t reserve_size; // used for the reserved cached memory in Backward
size_t single_w_size; // weights size of a single cell
size_t single_b_size; // bias size of a single cell from mkl-dnn
size_t naive_single_b_size; // bias size of a single cell from framework
size_t single_state_size; // state size of a single cell, hy, cy
size_t native_single_b_size; // bias size of a single cell from framework
size_t single_state_size; // state size of a single cell, hy, cy

MKLDNNRnnLayerParam(int num_layer, int batch_size, int seq_len,
int input_size, int state_size,
@@ -441,6 +441,18 @@ class MKLDNNRnnOp {
const std::vector<NDArray> &outputs);
};

inline bool SupportMKLDNNRnn(const int input_dtype) {
if (input_dtype == mshadow::kFloat32 && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
return true;
}
return false;
}

inline bool SupportMKLDNNRnn(const RNNParam &param, const int input_dtype) {
if (param.projection_size.has_value()) return false;
return SupportMKLDNNRnn(input_dtype);
}

} // namespace op
} // namespace mxnet

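The two overloads above gate the MKLDNN dispatch: the input must be float32, the MXNET_USE_MKLDNN_RNN environment variable must not be set to 0, and any projection_size sends the operator back to the native implementation. A standalone sketch of that predicate, using hypothetical stand-ins (FakeRnnParam, GetEnvInt) rather than the real RNNParam and dmlc::GetEnv:

#include <cstdlib>
#include <iostream>
#include <optional>

// Illustrative stand-ins; the real code uses RNNParam, mshadow::kFloat32,
// and dmlc::GetEnv from MXNet.
constexpr int kFloat32 = 0;

struct FakeRnnParam {
  std::optional<int> projection_size;
};

// Mirrors dmlc::GetEnv(name, default_value) for integer variables.
int GetEnvInt(const char* name, int def) {
  const char* v = std::getenv(name);
  return v ? std::atoi(v) : def;
}

// dtype-only check: fp32 input and MKLDNN RNN not disabled by the user.
bool SupportMKLDNNRnnSketch(int input_dtype) {
  return input_dtype == kFloat32 && GetEnvInt("MXNET_USE_MKLDNN_RNN", 1) != 0;
}

// param-aware check: LSTM projection is not implemented in the MKLDNN path,
// so a set projection_size forces the native fallback.
bool SupportMKLDNNRnnSketch(const FakeRnnParam& param, int input_dtype) {
  if (param.projection_size.has_value()) return false;
  return SupportMKLDNNRnnSketch(input_dtype);
}

int main() {
  FakeRnnParam plain{std::nullopt};
  FakeRnnParam projected{256};
  std::cout << SupportMKLDNNRnnSketch(plain, kFloat32) << "\n";      // 1 unless MXNET_USE_MKLDNN_RNN=0
  std::cout << SupportMKLDNNRnnSketch(projected, kFloat32) << "\n";  // 0
}

Setting MXNET_USE_MKLDNN_RNN=0 therefore keeps the native CPU kernels for every mode, matching the fallback branches in RNNStatefulComputeExCPU and RNNStatefulGradComputeExCPU later in this diff.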
43 changes: 21 additions & 22 deletions src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -63,7 +63,7 @@ void MKLDNNRnnLayerParam::SetDims() {
// unidirectional size of a single cell
single_w_size = (input_size + state_size) * ngates * state_size;
single_b_size = nbias * state_size;
naive_single_b_size = ngates * state_size * 2; // naive RNN variants have double bias
native_single_b_size = ngates * state_size * 2; // native RNN variants have double bias
single_state_size = batch_size * state_size;

// Get workspace size for cached weights memory
@@ -265,7 +265,7 @@ RnnBwdPrimitive GetRnnBwdPrim(const MKLDNNRnnForwardTraining &fwd,
}

/*
* Naive weights layout is:
* Native weights layout is:
* | l0_l2r_wx | l0_l2r_wh | l0_r2l_wx | l0_r2l_wh |
* | l1_l2r_wx | l1_l2r_wh | l1_r2l_wx | l1_r2l_wh |
* ...
@@ -339,7 +339,6 @@ FUNC(MKLDNN_ARG_DIFF_##NAME, ARGS.at(MKLDNN_ARG_##NAME).get_desc(), HANDLE)
void MKLDNNRnnForward::SetNewDataMem(void* x, void* hx, void* cx,
void* y, void* hy, void* cy,
const int dtype) {
using dims = mkldnn::memory::dims;
using desc = mkldnn::memory::desc;
using format_tag = mkldnn::memory::format_tag;
auto& cpu_engine = CpuEngine::Get()->get_engine();
@@ -462,12 +461,12 @@ inline void EmplaceNetArgs(mkldnn_args_map_t* net_args, const int arg_name,
}

/*
* Copy naive memory to mkldnn-format memory. It will initialize the memory
* when first invoked. Then, the naive weight_layer and weight_iter are
* Copy native memory to mkldnn-format memory. It will initialize the memory
* when first invoked. Then, the native weight_layer and weight_iter are
* concatenated to xxx_xx_r memory. Per the different gates order of GRU,
* it will swap the memory blocks of gates among concatenated memory
* inplace. From then on, the xxx_xx_r memory is reordered to target
* memory with preferred format_tag. Finally, naive bias is fused to MKLDNN
* memory with preferred format_tag. Finally, native bias is fused to MKLDNN
* bias memory.
*/
void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_ptr,
@@ -551,13 +550,13 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_

// Process bias
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
DType* naive_b_ptr = static_cast<DType*>(b_ptr);
DType* native_b_ptr = static_cast<DType*>(b_ptr);
DType* fused_bias = static_cast<DType*>(bias_->get_data_handle());
for (int lyr = 0; lyr < param_.num_layer; ++lyr) {
for (int d = 0; d < param_.bidirectional + 1; ++d) {
FuseBias<DType>(fused_bias, naive_b_ptr, param_.mode, param_.state_size);
FuseBias<DType>(fused_bias, native_b_ptr, param_.mode, param_.state_size);
fused_bias += param_.single_b_size;
naive_b_ptr += param_.naive_single_b_size;
native_b_ptr += param_.native_single_b_size;
}
}
});
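The bias loop above steps through the native blob in units of native_single_b_size (ngates * state_size * 2, i.e. the framework's b_x followed by b_h) while the fused MKLDNN bias advances by single_b_size. For the non-GRU modes the fusion plausibly reduces to element-wise addition of those two native halves; the sketch below shows only that idea, with a hypothetical FuseBiasSketch in place of the real FuseBias, which also deals with gate ordering.

#include <cstdio>
#include <vector>

// Hedged sketch: fuse the framework's double bias (b_x then b_h, each
// ngates * state_size long) into a single bias by element-wise addition.
void FuseBiasSketch(float* fused, const float* native_b, int ngates, int state_size) {
  const int half = ngates * state_size;  // length of b_x (same as b_h)
  for (int i = 0; i < half; ++i)
    fused[i] = native_b[i] + native_b[half + i];
}

int main() {
  const int ngates = 4, state_size = 3;  // tiny LSTM-shaped example
  std::vector<float> native_b(2 * ngates * state_size, 0.5f);  // b_x and b_h
  std::vector<float> fused(ngates * state_size);
  FuseBiasSketch(fused.data(), native_b.data(), ngates, state_size);
  std::printf("fused[0] = %.1f\n", fused[0]);  // 1.0
}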
@@ -632,7 +631,6 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
using memory = mkldnn::memory;
using format_tag = mkldnn::memory::format_tag;

// In the `autograd.record()` context, RNNOp is required to run into
@@ -674,10 +672,10 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
default_param.bidirectional + 1, default_param.mode)) * dtype_bytes;
for (auto& fwd_layer : fwd_inf_vec_) {
size_t single_w_bytes = fwd_layer.GetParam().single_w_size * dtype_bytes;
size_t single_b_bytes = fwd_layer.GetParam().naive_single_b_size * dtype_bytes;
size_t single_b_bytes = fwd_layer.GetParam().native_single_b_size * dtype_bytes;
size_t directions = fwd_layer.GetParam().bidirectional ? 2 : 1;
size_t layer_weights_bytes = single_w_bytes * directions;
size_t layer_bias_bytes = single_b_bytes * directions; // Naive MXNet has double bias
size_t layer_bias_bytes = single_b_bytes * directions; // Native MXNet has double bias

if (!fwd_layer.IsInitialized() || is_training)
fwd_layer.SetWeightsMem(&(this->mgr_), weights_ptr, bias_ptr, is_training, dtype);
@@ -857,7 +855,7 @@ void MKLDNNRnnBackward::CommitWeightsGrads(void* diff_weights, void* diff_bias,
const size_t wx_size = param.input_size * param.state_size * ngates;
const size_t wh_size = param.state_size * param.state_size * ngates;

/* naive weights layout is:
/* native weights layout is:
1st-layer: | wx_lr | wh_lr | wx_rl | wh_rl |
2nd-layer: | wx_lr | wh_lr | wx_rl | wh_rl |
size: | wxh_bytes |
@@ -903,33 +901,33 @@ void MKLDNNRnnBackward::CommitWeightsGrads(void* diff_weights, void* diff_bias,
});

const size_t bias_size = param.single_b_size;
const size_t naive_bias_size = param.naive_single_b_size;
const size_t native_bias_size = param.native_single_b_size;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
DType* native_bias = static_cast<DType *>(diff_bias);
DType* diff_bias_ptr = static_cast<DType *>(this->diff_bias_->get_data_handle());
OPREQTYPE_SWITCH(req, DType, FAccGrad, {
if (param.mode != rnn_enum::kGru) {
for (int shift = 0; shift < num_layer * direction; ++shift) {
FAccGrad(native_bias + shift * naive_bias_size,
FAccGrad(native_bias + shift * native_bias_size,
diff_bias_ptr + shift * bias_size, bias_size);
FAccGrad(native_bias + shift * naive_bias_size + bias_size,
FAccGrad(native_bias + shift * native_bias_size + bias_size,
diff_bias_ptr + shift * bias_size, bias_size);
}
} else {
const size_t bias_size_per_gate = param.state_size;
for (int shift = 0; shift < num_layer * direction; ++shift) {
DType* native_reset = native_bias + shift * naive_bias_size;
DType* native_reset = native_bias + shift * native_bias_size;
DType* native_update = native_reset + bias_size_per_gate;
DType* update = diff_bias_ptr + shift * bias_size;
DType* reset = update + bias_size_per_gate;

FAccGrad(native_update, update, bias_size_per_gate);
FAccGrad(native_reset, reset, bias_size_per_gate);
FAccGrad(native_update + naive_bias_size / 2, update, bias_size_per_gate);
FAccGrad(native_reset + naive_bias_size / 2, reset, bias_size_per_gate);
FAccGrad(native_update + native_bias_size / 2, update, bias_size_per_gate);
FAccGrad(native_reset + native_bias_size / 2, reset, bias_size_per_gate);

DType* native_new_bx = native_update + bias_size_per_gate;
DType* native_new_bh = native_new_bx + naive_bias_size / 2;
DType* native_new_bh = native_new_bx + native_bias_size / 2;
DType* new_bx = reset + bias_size_per_gate;
DType* new_bh = new_bx + bias_size_per_gate;
FAccGrad(native_new_bx, new_bx, bias_size_per_gate);
@@ -1186,10 +1184,11 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx,

// Commit weights diff
if (req[rnn_enum::kParams] != kNullOp) {
const int directions = default_param.bidirectional ? 2 : 1;
for (size_t lyr = 0; lyr < bwd_vec_.size(); ++lyr) {
bwd_vec_.at(lyr).CommitWeightsGrads(dw, db, req[rnn_enum::kParams], w_dtype);
dw += full_param_.layer_params.at(lyr).single_w_size * w_bytes;
db += full_param_.layer_params.at(lyr).single_b_size * w_bytes;
dw += full_param_.layer_params.at(lyr).single_w_size * directions * w_bytes;
db += full_param_.layer_params.at(lyr).native_single_b_size * directions * w_bytes;
}
}
}
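The stride change in the hunk above matters for bidirectional models: each entry of layer_params covers one (possibly bidirectional) layer, so committing gradients must advance dw by directions copies of single_w_size and db by directions copies of the native double bias, not by the MKLDNN-sized bias. A small arithmetic sketch with illustrative sizes:

#include <cstddef>
#include <cstdio>

int main() {
  const int input_size = 128, state_size = 256;
  const int ngates = 4;                // LSTM
  const int directions = 2;            // bidirectional
  const std::size_t dtype_bytes = 4;   // float32

  // Formulas follow MKLDNNRnnLayerParam::SetDims() earlier in this diff
  // (assuming nbias == ngates for LSTM).
  const std::size_t single_w_size =
      static_cast<std::size_t>(input_size + state_size) * ngates * state_size;
  const std::size_t native_single_b_size = static_cast<std::size_t>(ngates) * state_size * 2;
  const std::size_t single_b_size = static_cast<std::size_t>(ngates) * state_size;

  std::printf("new dw stride: %zu bytes\n", single_w_size * directions * dtype_bytes);
  std::printf("new db stride: %zu bytes\n", native_single_b_size * directions * dtype_bytes);
  // The pre-fix strides covered only half of the weights and a quarter of the
  // native bias for this bidirectional example.
  std::printf("old dw stride: %zu bytes\n", single_w_size * dtype_bytes);
  std::printf("old db stride: %zu bytes\n", single_b_size * dtype_bytes);
}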
33 changes: 24 additions & 9 deletions src/operator/rnn-inl.h
@@ -185,6 +185,7 @@ inline int GetRnnBiasSize(int num_layer,
inline size_t GetRNNWorkspaceSize(int seq_length,
int batch_size,
int hidden_size,
int projection_size,
int direction,
int mode) {
size_t size = 0;
@@ -324,6 +325,7 @@ void RNNForwardInference(DType* ws,
const int batch_size,
const int input_size,
const int state_size,
const int projection_size,
DType* x_ptr,
DType* hx_ptr,
DType* cx_ptr,
@@ -336,8 +338,8 @@
switch (mode) {
case rnn_enum::kLstm:
LstmForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
batch_size, input_size, state_size, projection_size,
x_ptr, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
break;
case rnn_enum::kGru:
GruForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
@@ -511,10 +513,7 @@ class RNNOp {
this->temp_init_space_ = false;
this->reserve_cpu_space_size_ = 0;
this->temp_cpu_space_size_ = 0;
if (param_.projection_size.has_value()) {
LOG(FATAL) <<
"hidden layer projection is only supported for GPU with CuDNN later than 7.1.1";
}

if (param_.lstm_state_clip_min.has_value()
|| param_.lstm_state_clip_max.has_value()) {
LOG(FATAL) << "LSTM state clipping is only supported for GPU with CuDNN later than 7.2.1";
@@ -843,9 +842,14 @@ class RNNOp {
#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)

if (ctx_.dev_type == kCPU) {
int projection_size = 0;
if (param_.projection_size.has_value()) {
projection_size = param_.projection_size.value();
}

// allocate temp space
const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
param_.state_size, direction, param_.mode);
param_.state_size, projection_size, direction, param_.mode);
if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) {
temp_cpu_space_size_ = work_cpu_space_size;
temp_cpu_space_ = NDArray(TShape({static_cast<dim_t>(temp_cpu_space_size_)}), ctx_,
@@ -856,6 +860,9 @@

if (ctx.is_train || ctx.need_grad) {
// allocate reserve space
if (param_.projection_size.has_value()) {
LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
}

const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
param_.seq_length_, param_.batch_size_,
@@ -896,6 +903,7 @@ class RNNOp {
param_.batch_size_,
param_.input_size_,
param_.state_size,
projection_size,
x.dptr_,
hx.dptr_,
cx_ptr,
@@ -1096,10 +1104,17 @@ class RNNOp {
#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)

if (ctx_.dev_type == kCPU) {
int projection_size = 0;
if (param_.projection_size.has_value()) {
// TODO(zixuanweeei): Add training support for LSTM with projection on CPU.
// projection_size = param_.projection_size.value();
LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
}

// allocate temp space
const size_t work_cpu_space_size =
GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
param_.state_size, direction, param_.mode);
GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, param_.state_size,
projection_size, direction, param_.mode);
if (!temp_init_space_ || temp_cpu_space_size_ != work_cpu_space_size) {
LOG(FATAL) << "Check temp init error";
}
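For reference, the projection_size now threaded through GetRNNWorkspaceSize and RNNForwardInference means that the recurrent state fed back into the gates has projection_size elements, produced by an extra r_t = W_hr * h_t step after the usual LSTM cell update. The sketch below is a naive single-time-step illustration with a combined bias, under those assumptions, and is not the MXNet kernel:

#include <cmath>
#include <cstdio>
#include <vector>

// One LSTMP (LSTM with projection) time step, illustrative only.
struct LstmpStep {
  int input_size, state_size, projection_size;
  // Gate weights: 4 * state_size rows by (input_size + projection_size) cols,
  // row-major, gate order i, f, g, o. Projection: projection_size x state_size.
  std::vector<float> w_gates, bias, w_proj;

  static float Sigmoid(float x) { return 1.f / (1.f + std::exp(-x)); }

  void Step(const std::vector<float>& x, std::vector<float>* r,
            std::vector<float>* c) const {
    const int cols = input_size + projection_size;
    std::vector<float> gates(4 * state_size);
    for (int row = 0; row < 4 * state_size; ++row) {
      float acc = bias[row];
      for (int k = 0; k < input_size; ++k)
        acc += w_gates[row * cols + k] * x[k];
      for (int k = 0; k < projection_size; ++k)
        acc += w_gates[row * cols + input_size + k] * (*r)[k];  // uses r_{t-1}
      gates[row] = acc;
    }
    std::vector<float> h(state_size);
    for (int j = 0; j < state_size; ++j) {
      const float i = Sigmoid(gates[j]);
      const float f = Sigmoid(gates[state_size + j]);
      const float g = std::tanh(gates[2 * state_size + j]);
      const float o = Sigmoid(gates[3 * state_size + j]);
      (*c)[j] = f * (*c)[j] + i * g;
      h[j] = o * std::tanh((*c)[j]);
    }
    // r_t = W_hr h_t: the fed-back state shrinks to projection_size elements.
    for (int p = 0; p < projection_size; ++p) {
      float acc = 0.f;
      for (int j = 0; j < state_size; ++j)
        acc += w_proj[p * state_size + j] * h[j];
      (*r)[p] = acc;
    }
  }
};

int main() {
  LstmpStep cell{};
  cell.input_size = 4;
  cell.state_size = 8;
  cell.projection_size = 2;
  cell.w_gates.assign(4 * cell.state_size * (cell.input_size + cell.projection_size), 0.01f);
  cell.bias.assign(4 * cell.state_size, 0.f);
  cell.w_proj.assign(cell.projection_size * cell.state_size, 0.1f);

  std::vector<float> x(cell.input_size, 1.f);
  std::vector<float> r(cell.projection_size, 0.f);  // projected recurrent state
  std::vector<float> c(cell.state_size, 0.f);       // cell state keeps full width
  cell.Step(x, &r, &c);
  std::printf("r_t has %zu elements, c_t has %zu elements\n", r.size(), c.size());
}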
44 changes: 29 additions & 15 deletions src/operator/rnn.cc
@@ -190,20 +190,19 @@ static std::vector<ResourceRequest> RNNResourceEx(const NodeAttrs& attrs, const
return request;
}

#if MXNET_USE_MKLDNN == 1
inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
DispatchMode wanted_mode = DispatchMode::kFCompute;

#if MXNET_USE_MKLDNN == 1
wanted_mode = DispatchMode::kFComputeEx;
#endif // MXNET_USE_MKLDNN == 1

return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, wanted_mode);
const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
const bool support_mkldnn_rnn =
!param.projection_size.has_value() && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1);
return MKLDNNStorageType(attrs, dev_mask, support_mkldnn_rnn,
dispatch_mode, in_attrs, out_attrs);
}
#endif // MXNET_USE_MKLDNN == 1

struct RNNGrad {
const char *op_name;
@@ -246,9 +245,7 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
}

#if MXNET_USE_MKLDNN == 1
if ((in_types[0] == mshadow::kFloat32 || in_types[0] == mshadow::kFloat16)
&& in_shapes[0].ndim() == 3 && ctx.dev_type == kCPU
&& dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
if (ctx.dev_type == kCPU && SupportMKLDNNRnn(param, in_types[rnn_enum::kData])) {
const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData];
state = OpStatePtr::Create<MKLDNNRnnOp>(param, data_shape[0],
data_shape[1], data_shape[2]);
@@ -274,7 +271,7 @@ static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNRnn(inputs[0])) {
if (SupportMKLDNNRnn(inputs[rnn_enum::kData].dtype())) {
MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>();
op.Forward(ctx, inputs, req, outputs);
} else {
Expand All @@ -287,7 +284,7 @@ static void RNNStatefulGradComputeExCPU(const OpStatePtr& state_ptr,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNRnn(inputs[0])) {
if (SupportMKLDNNRnn(inputs[rnn_enum::kData].dtype())) {
MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>();
op.Backward(ctx, inputs, req, outputs);
} else {
@@ -338,6 +335,23 @@ Long Short-Term Memory - Hochreiter, 1997. http://www.bioinf.jku.at/publications
h_t = o_t * \tanh(c_t)
\end{array}

With the projection size set, LSTM uses a projection layer to reduce the parameter count and gain some speedup without a significant loss of accuracy.

Long Short-Term Memory Based Recurrent Neural Network Architectures for Large Vocabulary Speech
Recognition - Sak et al. 2014. https://arxiv.org/abs/1402.1128

.. math::
\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{ri} r_{(t-1)} + b_{ri}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{rf} r_{(t-1)} + b_{rf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rg} r_{(t-1)} + b_{rg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ro} r_{(t-1)} + b_{ro}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t) \\
r_t = W_{hr} h_t
\end{array}

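As a rough illustration of the size reduction (the numbers are illustrative, not taken from this patch): with state_size = 1024 and projection_size = 256, the recurrent weights of a single LSTM layer shrink from about 4.19M to about 1.31M parameters.

.. math::
  \begin{array}{ll}
  \text{no projection:}\quad 4 \times 1024 \times 1024 \approx 4.19\text{M} \\
  \text{with projection:}\quad 4 \times 1024 \times 256 + 256 \times 1024 \approx 1.31\text{M}
  \end{array}
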
**GRU**

Gated Recurrent Unit - Cho et al. 2014. http://arxiv.org/abs/1406.1078
@@ -385,10 +399,10 @@ The definition of GRU here is slightly different from paper but compatible with
})
.set_attr<mxnet::FInferShape>("FInferShape", RNNShape)
.set_attr<nnvm::FInferType>("FInferType", RNNType)
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<FCreateOpState>("FCreateOpState", CreateRNNState)
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeExCPU)
#endif
@@ -427,9 +441,9 @@ NNVM_REGISTER_OP(_backward_RNN)
.set_attr_parser(ParamParser<RNNParam>)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulGradCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulGradComputeExCPU)
#endif