diff --git a/cpp-package/example/charRNN.cpp b/cpp-package/example/charRNN.cpp
index ac5faa47b58c..94e9455c5941 100644
--- a/cpp-package/example/charRNN.cpp
+++ b/cpp-package/example/charRNN.cpp
@@ -164,8 +164,9 @@ Symbol LSTMWithBuiltInRNNOp(int num_lstm_layer, int sequence_length, int input_d
   auto rnn_h_init = Symbol::Variable("LSTM_init_h");
   auto rnn_c_init = Symbol::Variable("LSTM_init_c");
   auto rnn_params = Symbol::Variable("LSTM_parameters");  // See explanations near RNNXavier class
-  auto rnn = RNN(embed, rnn_params, rnn_h_init, rnn_c_init, num_hidden, num_lstm_layer,
-                 RNNMode::kLstm, false, dropout, !isTrain);
+  auto variable_sequence_length = Symbol::Variable("sequence_length");
+  auto rnn = RNN(embed, rnn_params, rnn_h_init, rnn_c_init, variable_sequence_length, num_hidden,
+                 num_lstm_layer, RNNMode::kLstm, false, dropout, !isTrain);
   auto hidden = Reshape(rnn[0], Shape(), false, Shape(0, num_hidden), false);

   auto cls_weight = Symbol::Variable("cls_weight");
diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 6dfec43a8b5f..b3cc596282a7 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -37,7 +37,7 @@ def __init__(self, hidden_size, num_layers, layout,
                  i2h_bias_initializer, h2h_bias_initializer,
                  mode, projection_size, h2r_weight_initializer,
                  lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan,
-                 dtype, **kwargs):
+                 dtype, use_sequence_length=False, **kwargs):
         super(_RNNLayer, self).__init__(**kwargs)
         assert layout in ('TNC', 'NTC'), \
             "Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout
@@ -58,6 +58,7 @@ def __init__(self, hidden_size, num_layers, layout,
         self._lstm_state_clip_max = lstm_state_clip_max
         self._lstm_state_clip_nan = lstm_state_clip_nan
         self._dtype = dtype
+        self._use_sequence_length = use_sequence_length

         self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

@@ -219,29 +220,39 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
                 states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
         return states

-    def hybrid_forward(self, F, inputs, states=None, **kwargs):
-        if F is ndarray:
-            batch_size = inputs.shape[self._layout.find('N')]
-        skip_states = states is None
-        if skip_states:
-            if F is ndarray:
+    def __call__(self, inputs, states=None, sequence_length=None, **kwargs):
+        self.skip_states = states is None
+        if states is None:
+            if isinstance(inputs, ndarray.NDArray):
+                batch_size = inputs.shape[self._layout.find('N')]
                 states = self.begin_state(batch_size, ctx=inputs.context, dtype=inputs.dtype)
             else:
                 states = self.begin_state(0, func=symbol.zeros)
         if isinstance(states, tensor_types):
             states = [states]
+
+        if self._use_sequence_length:
+            return super(_RNNLayer, self).__call__(inputs, states, sequence_length, **kwargs)
+        else:
+            return super(_RNNLayer, self).__call__(inputs, states, **kwargs)
+
+
+    def hybrid_forward(self, F, inputs, states, sequence_length=None, **kwargs):
+        if F is ndarray:
+            batch_size = inputs.shape[self._layout.find('N')]
+
+        if F is ndarray:
             for state, info in zip(states, self.state_info(batch_size)):
                 if state.shape != info['shape']:
                     raise ValueError(
                         "Invalid recurrent state shape.
Expecting %s, got %s."%( str(info['shape']), str(state.shape))) - out = self._forward_kernel(F, inputs, states, **kwargs) + out = self._forward_kernel(F, inputs, states, sequence_length, **kwargs) # out is (output, state) - return out[0] if skip_states else out + return out[0] if self.skip_states else out - def _forward_kernel(self, F, inputs, states, **kwargs): + def _forward_kernel(self, F, inputs, states, sequence_length, **kwargs): """ forward using CUDNN or CPU kenrel""" if self._layout == 'NTC': inputs = F.swapaxes(inputs, dim1=0, dim2=1) @@ -261,14 +272,20 @@ def _forward_kernel(self, F, inputs, states, **kwargs): params = F._internal._rnn_param_concat(*params, dim=0) - rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size, - projection_size=self._projection_size, + if self._use_sequence_length: + rnn_args = states + [sequence_length] + else: + rnn_args = states + + rnn = F.RNN(inputs, params, *rnn_args, use_sequence_length=self._use_sequence_length, + state_size=self._hidden_size, projection_size=self._projection_size, num_layers=self._num_layers, bidirectional=self._dir == 2, p=self._dropout, state_outputs=True, mode=self._mode, lstm_state_clip_min=self._lstm_state_clip_min, lstm_state_clip_max=self._lstm_state_clip_max, lstm_state_clip_nan=self._lstm_state_clip_nan) + if self._mode == 'lstm': outputs, states = rnn[0], [rnn[1], rnn[2]] else: diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 37f21ce6d126..d164333953f2 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -27,7 +27,7 @@ #define MXNET_OPERATOR_RNN_INL_H_ #define MXNET_USE_CUDNN_RNN MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 -#define USE_CUDNN_LSTM_PROJ MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200 +#define MXNET_USE_CUDNN_GE_7200 MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200 #include #include @@ -39,6 +39,7 @@ #include #include #include + #include "./math.h" #include "./math_functions-inl.h" #include "./operator_common.h" @@ -48,10 +49,10 @@ namespace mxnet { namespace op { namespace rnn_enum { - enum RNNOpInputs {kData, kParams, kState, kStateCell}; + enum RNNOpInputs {kData, kParams, kState, kStateCell, kSequenceLength}; enum RNNOpOutputs {kOut, kStateOut, kStateCellOut}; enum RNNModeType {kRnnRelu, kRnnTanh, kLstm, kGru}; - enum RNNOpResource {kCuDNNDropoutDescSpace}; + enum RNNOpResource {kTempSpace, kCuDNNDropoutDescSpace}; } inline int GetRnnParamSize(int num_layer, @@ -166,6 +167,8 @@ struct RNNParam : public dmlc::Parameter { int mode; float p; int seq_length_, batch_size_, input_size_; + + bool use_sequence_length; dmlc::optional projection_size; dmlc::optional lstm_state_clip_min, lstm_state_clip_max; bool lstm_state_clip_nan; @@ -212,9 +215,22 @@ struct RNNParam : public dmlc::Parameter { .set_default(false) .describe("Whether to stop NaN from propagating in state by clipping it to min/max. " "If clipping range is not specified, this option is ignored."); + + DMLC_DECLARE_FIELD(use_sequence_length) + .set_default(false) + .describe( + "If set to true, this layer takes in an extra input parameter " + "`sequence_length` " + "to specify variable length sequence"); } }; +inline size_t GetNumInputArguments(RNNParam param_) { + size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4U : 3U; + if (param_.use_sequence_length) num_inputs += 1U; + return num_inputs; +} + /** * @params: ws: Temp workspace for gemm's output storage. * rs: Reserve space of forward intermediate data used for training. 
@@ -379,7 +395,7 @@ void RNNBackward(DType* ws, } } -template +template class RNNOp { public: RNNParam param_; @@ -415,7 +431,7 @@ class RNNOp { default: LOG(FATAL) << "Not implmented"; } -#if USE_CUDNN_LSTM_PROJ +#if MXNET_USE_CUDNN_GE_7200 if (param_.projection_size.has_value()) { CHECK_EQ(param_.mode, rnn_enum::kLstm) << "Projection is only supported for LSTM."; @@ -426,7 +442,7 @@ class RNNOp { CHECK(!param_.projection_size.has_value()) << "Projection is only supported for LSTM with CuDNN version later than 7.1.1."; #endif -#if USE_CUDNN_LSTM_PROJ +#if MXNET_USE_CUDNN_GE_7200 if (param_.lstm_state_clip_min.has_value() || param_.lstm_state_clip_max.has_value()) { CHECK_EQ(param_.mode, rnn_enum::kLstm) @@ -459,7 +475,7 @@ class RNNOp { CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_)); CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_)); - #if USE_CUDNN_LSTM_PROJ + #if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_)); CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_)); CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_)); @@ -515,7 +531,7 @@ class RNNOp { Storage::Get()->Free(temp_space_); Storage::Get()->Free(reserve_space_); } - #if USE_CUDNN_LSTM_PROJ + #if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnDestroyRNNDataDescriptor(x_data_desc_)); CUDNN_CALL(cudnnDestroyRNNDataDescriptor(y_data_desc_)); CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dx_data_desc_)); @@ -541,8 +557,9 @@ class RNNOp { using namespace mshadow; using namespace mshadow::expr; CHECK(param_.p >= 0.0f && param_.p < 1.0f) - << "unsupported dropout value, should be 0 <= dropout < 1"; - size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3; + << "unsupported dropout value, should be 0 <= dropout < 1"; + size_t num_inputs = GetNumInputArguments(param_); + // kOut size_t num_outputs = 1; if (param_.state_outputs) { @@ -553,6 +570,7 @@ class RNNOp { CHECK_EQ(in_data.size(), num_inputs); CHECK_EQ(out_data.size(), num_outputs); Stream *s = ctx.get_stream(); + // get input + output tensors Tensor x = in_data[rnn_enum::kData].get(s); Tensor w = in_data[rnn_enum::kParams].get(s); @@ -562,6 +580,7 @@ class RNNOp { param_.seq_length_ = x.shape_[0]; param_.batch_size_ = x.shape_[1]; param_.input_size_ = x.shape_[2]; + const int direction = param_.bidirectional ? 2 : 1; const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode); DType* b_ptr = w.dptr_ + w.shape_[0] - bsize; @@ -570,65 +589,130 @@ class RNNOp { if (param_.state_outputs) { hy_ptr = out_data[rnn_enum::kStateOut].dptr(); } + + +#if MXNET_USE_CUDNN_GE_7200 + Tensor host_workspace; + int *sequence_length_cpu_int = NULL; + IType *sequence_length_cpu_itype = NULL; + + if (ctx_.dev_type == kGPU) { + int host_workspace_bytes = + param_.batch_size_ * sizeof(IType) + param_.batch_size_ * sizeof(int); + + host_workspace = + ctx.requested[rnn_enum::kTempSpace].get_host_space_typed<1, char>( + Shape1(host_workspace_bytes)); + + sequence_length_cpu_int = reinterpret_cast(host_workspace.dptr_); + sequence_length_cpu_itype = + reinterpret_cast(host_workspace.dptr_ + sizeof(int) * param_.batch_size_); + + (void)sequence_length_cpu_int; + (void)sequence_length_cpu_itype; + } +#endif + + + if (param_.use_sequence_length) { +#if MXNET_USE_CUDNN_GE_7200 + if (ctx_.dev_type == kCPU) { + LOG(FATAL) << "RNN use_sequence_length option is only available for cuDNN at the moment." 
+ << " Not supported on CPU"; + } + + // We can assume we are on GPU for now + size_t seq_len_input_idx = rnn_enum::kSequenceLength; + if (param_.mode != rnn_enum::kLstm) { + seq_len_input_idx -= 1; + } + IType *sequence_length_ptr_gpu = (in_data[seq_len_input_idx].get(s)).dptr_; + + // Need to copy from GPU -> CPU, becuase cuDNN API requires this array on CPU memory. + // TODO(stephenrawls): In future, allow users to pass this array on the CPU so we don't have + // to do this copy For now however it is required as several places in backend assume that + // all data arrays share the same context. + CUDA_CALL(cudaMemcpy(sequence_length_cpu_itype, sequence_length_ptr_gpu, + sizeof(IType) * param_.batch_size_, cudaMemcpyDeviceToHost)); +#else + LOG(FATAL) << "RNN use_sequence_length option is only available for cuDNN version >= 7.2"; +#endif + } DType* cx_ptr = NULL; DType* cy_ptr = NULL; - if (param_.mode == rnn_enum::kLstm) + if (param_.mode == rnn_enum::kLstm) { cx_ptr = (in_data[rnn_enum::kStateCell].get(s)).dptr_; - if (param_.mode == rnn_enum::kLstm && param_.state_outputs) + } + if (param_.mode == rnn_enum::kLstm && param_.state_outputs) { cy_ptr = (out_data[rnn_enum::kStateCellOut].get(s)).dptr_; - + } CHECK_EQ(x.CheckContiguous(), true); CHECK_EQ(w.CheckContiguous(), true); CHECK_EQ(hx.CheckContiguous(), true); CHECK_EQ(y.CheckContiguous(), true); - #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__) +#if MXNET_USE_CUDNN_RNN && defined(__CUDACC__) if (!init_cudnn_) { Init(ctx, s, in_data, out_data); } - #if USE_CUDNN_LSTM_PROJ - std::vector seqLengthArray(param_.batch_size_, param_.seq_length_); +#if MXNET_USE_CUDNN_GE_7200 + + cudnnRNNDataLayout_t layout_t; + + if (param_.use_sequence_length) { + // Note: Can't mempcy, sequence_length_ptr_cpu is of type Itype, not nescesarily int + for (int i = 0; i < param_.batch_size_; ++i) { + sequence_length_cpu_int[i] = sequence_length_cpu_itype[i]; + } + layout_t = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED; + } else { + for (int i = 0; i < param_.batch_size_; ++i) { + sequence_length_cpu_int[i] = param_.seq_length_; + } + layout_t = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED; + } + CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_, dtype_, - CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED, + layout_t, param_.seq_length_, param_.batch_size_, param_.input_size_, - seqLengthArray.data(), - nullptr)); + sequence_length_cpu_int, + reinterpret_cast(&padding_fill_))); int out_size = (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size; out_size = (param_.bidirectional) ? 
(out_size * 2) : out_size; CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_, dtype_, - CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED, + layout_t, param_.seq_length_, param_.batch_size_, out_size, - seqLengthArray.data(), - nullptr)); + sequence_length_cpu_int, + reinterpret_cast(&padding_fill_))); if (ctx.is_train) { CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_, dtype_, - CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED, + layout_t, param_.seq_length_, param_.batch_size_, param_.input_size_, - seqLengthArray.data(), - nullptr)); + sequence_length_cpu_int, + reinterpret_cast(&padding_fill_))); CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_, dtype_, - CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED, + layout_t, param_.seq_length_, param_.batch_size_, out_size, - seqLengthArray.data(), - nullptr)); + sequence_length_cpu_int, + reinterpret_cast(&padding_fill_))); } - #endif +#endif - #if USE_CUDNN_LSTM_PROJ +#if MXNET_USE_CUDNN_GE_7200 bool clip_state = param_.lstm_state_clip_min.has_value(); bool clip_nan = param_.lstm_state_clip_nan; CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_, @@ -637,10 +721,10 @@ class RNNOp { clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN, clip_state ? param_.lstm_state_clip_min.value() : 0.0, clip_state ? param_.lstm_state_clip_max.value() : 0.0)); - #endif +#endif if (ctx.is_train) { - #if USE_CUDNN_LSTM_PROJ +#if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_, rnn_desc_, x_data_desc_, @@ -669,7 +753,7 @@ class RNNOp { workspace_byte_, reserve_space_.dptr, reserve_space_byte_)); - #else +#else CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_, rnn_desc_, param_.seq_length_, @@ -691,9 +775,9 @@ class RNNOp { workspace_byte_, reserve_space_.dptr, reserve_space_byte_)); - #endif +#endif } else { - #if USE_CUDNN_LSTM_PROJ +#if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_, rnn_desc_, x_data_desc_, @@ -720,7 +804,7 @@ class RNNOp { nullptr, temp_space_.dptr, workspace_byte_)); - #else +#else CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_, rnn_desc_, param_.seq_length_, @@ -740,22 +824,22 @@ class RNNOp { cy_ptr, temp_space_.dptr, workspace_byte_)); - #endif +#endif } - #endif +#endif if (ctx_.dev_type == kCPU) { // allocate temp space const size_t work_cpu_space_size = - GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, - param_.state_size, direction, param_.mode); + GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, + param_.state_size, direction, param_.mode); if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) { - Storage::Get()->Free(temp_cpu_space_); - temp_init_space_ = false; + Storage::Get()->Free(temp_cpu_space_); + temp_init_space_ = false; } if (!temp_init_space_) { temp_cpu_space_ = Storage::Get()->Alloc - (work_cpu_space_size * sizeof(DType), Context::CPU()); + (work_cpu_space_size * sizeof(DType), Context::CPU()); temp_cpu_space_size_ = work_cpu_space_size; temp_init_space_ = true; } @@ -828,7 +912,8 @@ class RNNOp { CHECK(param_.p >= 0.0f && param_.p < 1.0f) << "unsupported dropout value, should be 0 <= dropout < 1"; - size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 
4 : 3; + size_t num_inputs = GetNumInputArguments(param_); + // kOut size_t num_outputs = 1; if (param_.state_outputs) { @@ -890,15 +975,16 @@ class RNNOp { cx_ptr = (in_data[rnn_enum::kStateCell].get(s)).dptr_; dcx_ptr = (in_grad[rnn_enum::kStateCell].get(s)).dptr_; } - if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs) + if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs) { dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get(s)).dptr_; + } #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__) if (!init_cudnn_) { Init(ctx, s, in_data, out_data); } - #if USE_CUDNN_LSTM_PROJ + #if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_, rnn_desc_, y_data_desc_, @@ -1038,19 +1124,19 @@ class RNNOp { } } - private: inline void Init(const OpContext &ctx, mshadow::Stream *s, const std::vector &in_data, const std::vector &out_data) { using namespace mshadow; - size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3; + + size_t num_inputs = GetNumInputArguments(param_); // kOut size_t num_outputs = 1; if (param_.state_outputs) { // kOut, kStateOut, kStateCellOut - num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2; + num_outputs = (param_.mode == rnn_enum::kLstm) ? 3U : 2U; } CHECK_EQ(in_data.size(), num_inputs); @@ -1130,7 +1216,7 @@ class RNNOp { strideA[0] = dimA[2] * dimA[1]; strideA[1] = dimA[2]; strideA[2] = 1; - #if USE_CUDNN_LSTM_PROJ + #if MXNET_USE_CUDNN_GE_7200 int dimB[3]; int strideB[3]; dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1); @@ -1141,7 +1227,7 @@ class RNNOp { strideB[1] = dimB[2]; strideB[2] = 1; #endif - #if USE_CUDNN_LSTM_PROJ + #if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_, dtype_, 3, @@ -1159,7 +1245,7 @@ class RNNOp { 3, dimA, strideA)); - #if USE_CUDNN_LSTM_PROJ + #if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_, dtype_, 3, @@ -1177,7 +1263,7 @@ class RNNOp { 3, dimA, strideA)); - #if USE_CUDNN_LSTM_PROJ + #if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_, dtype_, 3, @@ -1195,7 +1281,7 @@ class RNNOp { 3, dimA, strideA)); - #if USE_CUDNN_LSTM_PROJ + #if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_, dtype_, 3, @@ -1258,12 +1344,13 @@ class RNNOp { } #if CUDNN_VERSION >= 7200 if (GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion() && - (DataType::kFlag != kFloat16)) + (DataType::kFlag != kFloat16)) { math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION; + } #endif CUDNN_CALL(cudnnSetRNNMatrixMathType(rnn_desc_, math_type)); #endif - #if USE_CUDNN_LSTM_PROJ + #if MXNET_USE_CUDNN_GE_7200 if (param_.projection_size.has_value()) { CUDNN_CALL(cudnnSetRNNProjectionLayers(s->dnn_handle_, rnn_desc_, @@ -1272,6 +1359,13 @@ class RNNOp { } #endif // Get temp space sizes + + #if MXNET_USE_CUDNN_GE_7200 + if (param_.use_sequence_length) { + CUDNN_CALL(cudnnSetRNNPaddingMode(rnn_desc_, CUDNN_RNN_PADDED_IO_ENABLED)); + } + #endif + CUDNN_CALL(cudnnGetRNNWorkspaceSize(s->dnn_handle_, rnn_desc_, param_.seq_length_, @@ -1360,8 +1454,9 @@ class RNNOp { size_t workspace_byte_, reserve_space_byte_, dropout_byte_; int workspace_size_; std::vector x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_; - #if USE_CUDNN_LSTM_PROJ + #if MXNET_USE_CUDNN_GE_7200 cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, dy_data_desc_; + DType padding_fill_ = 0; #endif cudnnTensorDescriptor_t hx_desc_, cx_desc_; cudnnTensorDescriptor_t hy_desc_, cy_desc_; @@ -1387,13 +1482,22 @@ static OpStatePtr CreateRNNState(const 
nnvm::NodeAttrs &attrs, const std::vector &in_types) { const RNNParam& param = nnvm::get(attrs.parsed); OpStatePtr state = OpStatePtr(); - MSHADOW_REAL_TYPE_SWITCH(in_types[rnn_enum::kData], DType, { + int dtype = in_types[rnn_enum::kData]; + int itype = dtype; + if (param.use_sequence_length) { + itype = in_types[rnn_enum::kSequenceLength]; + if (param.mode == rnn_enum::kLstm) itype -= 1; + } + + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + MSHADOW_TYPE_SWITCH(itype, IType, { if (ctx.dev_type == kGPU) { - state = OpStatePtr::Create>(param, ctx); + state = OpStatePtr::Create>(param, ctx); } else { - state = OpStatePtr::Create>(param, ctx); + state = OpStatePtr::Create>(param, ctx); } }); + }); return state; } @@ -1404,10 +1508,18 @@ void RNNStatefulCompute(const OpStatePtr& state, const std::vector& req, const std::vector& outputs) { int dtype = inputs[rnn_enum::kData].type_flag_; + + // Hacky. This relies on fact that seq-len type is either the last input, + // or we aren't using seq-len input and this type should be same as dtype. + // Would prefer direct access to RNNParam object here but not sure how to get. + int itype = inputs[inputs.size()-1].type_flag_; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - RNNOp& op = state.get_state>(); - op.Forward(ctx, inputs, req, outputs); - }); + MSHADOW_TYPE_SWITCH(itype, IType, { + RNNOp& op = state.get_state>(); + op.Forward(ctx, inputs, req, outputs); + }); + }); } /* @@ -1435,25 +1547,33 @@ void RNNStatefulGradCompute(const OpStatePtr& state, const std::vector &in_grad = outputs; int dtype = inputs[rnn_enum::kData].type_flag_; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - RNNOp& op = state.get_state>(); - const RNNParam& param = op.param_; - int index = 5; - if (param.state_outputs) { - out_data.push_back(inputs[index++]); - out_grad.push_back(inputs[index++]); - } - if (param.mode == rnn_enum::kLstm) { - in_data.push_back(inputs[index++]); - if (param.state_outputs) { - out_data.push_back(inputs[index++]); - out_grad.push_back(inputs[index]); - } - } + // Hacky. This relies on fact that seq-len type is either the last input, + // or we aren't using seq-len input and this type should be same as dtype. + // Would prefer direct access to RNNParam object here but not sure how to get. 
+ int itype = inputs[inputs.size()-1].type_flag_; - op.Backward(ctx, out_grad, in_data, out_data, req, in_grad); - }); + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + MSHADOW_TYPE_SWITCH(itype, IType, { + RNNOp& op = state.get_state>(); + const RNNParam& param = op.param_; + int index = 5; + if (param.state_outputs) { + out_data.push_back(inputs[index++]); + out_grad.push_back(inputs[index++]); + } + + if (param.mode == rnn_enum::kLstm) { + in_data.push_back(inputs[index++]); + if (param.state_outputs) { + out_data.push_back(inputs[index++]); + out_grad.push_back(inputs[index]); + } + } + + op.Backward(ctx, out_grad, in_data, out_data, req, in_grad); + }); + }); } } // namespace op diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 7012a3c22f50..296d57eb4713 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -23,6 +23,9 @@ * \brief * \author Sebastian Bodenstein */ + +#include + #include "./rnn-inl.h" namespace mxnet { @@ -30,11 +33,20 @@ namespace op { DMLC_REGISTER_PARAMETER(RNNParam); static inline std::vector ListArguments(const RNNParam& param_) { + // All RNNs start off with same 3 input arguments + std::vector arguments{"data", "parameters", "state"}; + + // LSTMs also have an additional state_cell argument if (param_.mode == rnn_enum::kLstm) { - return {"data", "parameters", "state", "state_cell"}; - } else { - return {"data", "parameters", "state"}; + arguments.emplace_back("state_cell"); } + + // All RNNs have option of additional sequence_length argument + if (param_.use_sequence_length) { + arguments.emplace_back("sequence_length"); + } + + return arguments; } static bool RNNShape(const nnvm::NodeAttrs& attrs, @@ -42,13 +54,13 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, std::vector *out_shape) { const RNNParam& param_ = nnvm::get(attrs.parsed); using namespace mshadow; - if (param_.mode == rnn_enum::kLstm) { - CHECK_EQ(in_shape->size(), 4U) << "Needed input:[data, parameters, state, cell_state]," - << " got in_shape->size(): " << in_shape->size(); - } else { - CHECK_EQ(in_shape->size(), 3U) << - "Needed input:[data, parameters, state], got in_shape->size(): " << in_shape->size(); - } + + // Query param_ object to figure out what the expectd input arguments are + std::vector expected_arguments = ListArguments(param_); + + CHECK_EQ(in_shape->size(), expected_arguments.size()) << "Input shape mismatch. 
Expected " << + expected_arguments.size() << " input parameters but got " << in_shape->size() << "."; + const TShape &dshape = (*in_shape)[rnn_enum::kData]; if (!mxnet::ndim_is_known(dshape)) return false; CHECK_EQ(dshape.ndim(), 3U) \ @@ -77,6 +89,15 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, param_.mode, param_.projection_size); SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size)); + + // Check on sequence_length shape if using + if (param_.use_sequence_length) { + size_t seq_len_input_idx = rnn_enum::kSequenceLength; + if (param_.mode != rnn_enum::kLstm) --seq_len_input_idx; + + SHAPE_ASSIGN_CHECK(*in_shape, seq_len_input_idx, Shape1(batch_size)); + } + out_shape->clear(); // output: [sequence len, batch, output size] TShape oshape = dshape; @@ -106,6 +127,7 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, out_shape->push_back(cellStateShape); } } + return true; } @@ -113,18 +135,24 @@ static bool RNNType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { const RNNParam& param_ = nnvm::get(attrs.parsed); - if (param_.mode == rnn_enum::kLstm) { - CHECK_EQ(in_type->size(), 4U); - } else { - CHECK_EQ(in_type->size(), 3U); - } + + CHECK_EQ(in_type->size(), GetNumInputArguments(param_)); + + size_t seq_len_input_idx = rnn_enum::kSequenceLength; + if (param_.mode != rnn_enum::kLstm) --seq_len_input_idx; + int dtype = (*in_type)[0]; CHECK_NE(dtype, -1) << "First input must have specified type"; + std::vector arguments = ListArguments(param_); for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { TYPE_ASSIGN_CHECK(*in_type, i, dtype); } else { - UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]); + // If using sequence length argument, it has its own indexing type + // All other input arguments must match the main data type + if (!(param_.use_sequence_length && i == seq_len_input_idx)) { + UNIFORM_TYPE_CHECK((*in_type)[i], dtype, arguments[i]); + } } } out_type->clear(); @@ -132,8 +160,9 @@ static bool RNNType(const nnvm::NodeAttrs& attrs, if (param_.state_outputs) { out_type->push_back(dtype); // Deal with lstm cell state - if (param_.mode == rnn_enum::kLstm) + if (param_.mode == rnn_enum::kLstm) { out_type->push_back(dtype); + } } return true; } @@ -220,7 +249,7 @@ The definition of GRU here is slightly different from paper but compatible with .set_attr_parser(ParamParser) .set_num_inputs([](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); - return params.mode == rnn_enum::kLstm ? 
4 : 3; + return GetNumInputArguments(params); }) .set_num_outputs([](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); @@ -246,13 +275,13 @@ The definition of GRU here is slightly different from paper but compatible with .set_attr("FResourceRequestEx", [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) { std::vector request; - const RNNParam& param = nnvm::get(attrs.parsed); - if (param.p == 0) return request; if (dev_mask == kGPU) { #if MXNET_USE_CUDNN_RNN - if (1.0f - param.p > 0) { + request.emplace_back(ResourceRequest::kTempSpace); + + const RNNParam& param = nnvm::get(attrs.parsed); + if (param.p != 0 && 1.0f - param.p > 0) { request.emplace_back(ResourceRequest::kCuDNNDropoutDesc); - return request; } #endif } @@ -264,12 +293,15 @@ The definition of GRU here is slightly different from paper but compatible with .add_argument("state", "NDArray-or-Symbol", "initial hidden state of the RNN") .add_argument("state_cell", "NDArray-or-Symbol", "initial cell state for LSTM networks (only for LSTM)") +.add_argument("sequence_length", "NDArray-or-Symbol", + "Vector of valid sequence lengths for each element in batch. (Only used if" + " use_sequence_length kwarg is True)") .add_arguments(RNNParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_RNN) .set_num_outputs([](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); - return params.mode == rnn_enum::kLstm ? 4 : 3; + return GetNumInputArguments(params); }) .set_attr_parser(ParamParser) .set_attr("TIsLayerOpBackward", true) diff --git a/src/operator/rnn.cu b/src/operator/rnn.cu index 77bb95522711..093a64a0623a 100644 --- a/src/operator/rnn.cu +++ b/src/operator/rnn.cu @@ -30,6 +30,7 @@ namespace mxnet { namespace op { + NNVM_REGISTER_OP(RNN) .set_attr("FStatefulCompute", RNNStatefulCompute); diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 1c5a5835e6f9..95835fd77e9e 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -24,6 +24,7 @@ import unittest import random import mxnet as mx +import mxnet.ndarray as nd import numpy as np import unittest import math @@ -225,6 +226,55 @@ def forward(self, inpt): assert_allclose(net(data).asnumpy(), ref_net(data).asnumpy()) +def check_layer_bidirectional_varseqlen(size, in_size): + class RefBiLSTMVarSeqLen(gluon.Block): + def __init__(self, size, **kwargs): + super(RefBiLSTMVarSeqLen, self).__init__(**kwargs) + with self.name_scope(): + self._lstm_fwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='l0') + self._lstm_bwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='r0') + + def forward(self, inpt, sequence_length): + fwd = self._lstm_fwd(inpt) + bwd_inpt = nd.SequenceReverse(inpt, sequence_length=sequence_length, use_sequence_length=True) + bwd = self._lstm_bwd(bwd_inpt) + bwd = nd.SequenceReverse(bwd, sequence_length=sequence_length, use_sequence_length=True) + return nd.concat(fwd, bwd, dim=2) + weights = {} + for d in ['l', 'r']: + weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size)) + weights['lstm_{}0_h2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, size)) + weights['lstm_{}0_i2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,)) + weights['lstm_{}0_h2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,)) + + net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=True, prefix='lstm_') + ref_net = RefBiLSTMVarSeqLen(size, prefix='lstm_') + net.initialize() + ref_net.initialize() + 
    net_params = net.collect_params()
+    ref_net_params = ref_net.collect_params()
+    for k in weights:
+        net_params[k].set_data(weights[k])
+        ref_net_params[k.replace('l0', 'l0l0').replace('r0', 'r0l0')].set_data(weights[k])
+
+
+    batch_size = 10
+    num_timesteps = 11
+    data = mx.random.uniform(shape=(num_timesteps, batch_size, in_size))
+
+    # TODO: figure out why int32 doesn't work here
+    sequence_length = nd.random.randint(1, num_timesteps+1, shape=(batch_size)).astype("float")
+
+    net_output = net(data, sequence_length=sequence_length).asnumpy()
+    ref_net_output = ref_net(data, sequence_length).asnumpy()
+    sequence_length_np = sequence_length.asnumpy().astype("int32")
+
+    # TODO: test state return value as well as output
+    # Only compare the valid sections for each batch entry
+    for b in range(batch_size):
+        assert_allclose(net_output[:sequence_length_np[b], b], ref_net_output[:sequence_length_np[b], b])
+
+
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_layer_bidirectional():
@@ -236,6 +286,11 @@ def test_layer_bidirectional():
 def test_layer_bidirectional_proj():
     check_layer_bidirectional(7, 5, 3)

+@with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='7.2.1')
+def test_layer_bidirectional_varseqlength():
+    check_layer_bidirectional_varseqlen(7, 5)
+
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
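
For readers of this patch, here is a minimal usage sketch of the new `use_sequence_length` option from the Gluon API. It simply mirrors the test above: the sizes and data are illustrative, a GPU context with cuDNN >= 7.2 is assumed (the only backend this change supports), and `sequence_length` is cast to float following the int32 workaround noted in the test.

import mxnet as mx
import mxnet.ndarray as nd
from mxnet import gluon

ctx = mx.gpu(0)  # use_sequence_length currently requires cuDNN >= 7.2 on a GPU
seq_len, batch_size, in_size, hidden = 11, 10, 5, 7

# LSTM that accepts a per-example valid length; layout is the default 'TNC'
lstm = gluon.rnn.LSTM(hidden, bidirectional=True, use_sequence_length=True)
lstm.initialize(ctx=ctx)

data = mx.random.uniform(shape=(seq_len, batch_size, in_size), ctx=ctx)
# One valid length per batch element (float dtype, per the TODO in the test above)
valid_len = nd.random.randint(1, seq_len + 1, shape=(batch_size,), ctx=ctx).astype("float")

# Steps beyond each example's valid length are padding as far as the RNN is concerned
out = lstm(data, sequence_length=valid_len)
print(out.shape)  # (seq_len, batch_size, 2 * hidden) for a bidirectional layer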