diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index e85f6b4af3fd..8812c3d76166 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -167,7 +167,7 @@ struct RNNParam : public dmlc::Parameter { int mode; float p; int seq_length_, batch_size_, input_size_; - bool lstm_q_; // whether type is lstm + bool use_sequence_length; dmlc::optional projection_size; dmlc::optional lstm_state_clip_min, lstm_state_clip_max; @@ -225,6 +225,13 @@ struct RNNParam : public dmlc::Parameter { } }; +inline size_t GetNumInputArguments(RNNParam param_) { + size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4U : 3U; + if (param_.use_sequence_length) + num_inputs += 1U; + return num_inputs; +} + /** * @params: ws: Temp workspace for gemm's output storage. * rs: Reserve space of forward intermediate data used for training. @@ -389,7 +396,7 @@ void RNNBackward(DType* ws, } } -template +template class RNNOp { public: RNNParam param_; @@ -552,8 +559,9 @@ class RNNOp { using namespace mshadow; using namespace mshadow::expr; CHECK(param_.p >= 0.0f && param_.p < 1.0f) - << "unsupported dropout value, should be 0 <= dropout < 1"; - size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3; + << "unsupported dropout value, should be 0 <= dropout < 1"; + size_t num_inputs = GetNumInputArguments(param_); + // kOut size_t num_outputs = 1; if (param_.state_outputs) { @@ -561,12 +569,6 @@ class RNNOp { num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2; } - size_t in_expected = param_.lstm_q_ ? 4 : 3; - size_t out_expected = param_.lstm_q_ ? 3 : 2; - - if (param_.use_sequence_length) - num_inputs += 1; - CHECK_EQ(in_data.size(), num_inputs); CHECK_EQ(out_data.size(), num_outputs); Stream *s = ctx.get_stream(); @@ -592,276 +594,286 @@ class RNNOp { if (param_.state_outputs) { hy_ptr = out_data[rnn_enum::kStateOut].dptr(); - sequence_length_array_cpu.resize(param_.batch_size_); - if (param_.use_sequence_length) { - size_t seq_len_input_idx = rnn_enum::kSequenceLength; - if (!param_.lstm_q_) - --seq_len_input_idx; - IType *sequence_length_ptr_gpu = (in_data[seq_len_input_idx].get(s)).dptr_; - - // Need to copy from GPU -> CPU, becuase cuDNN API requires this array on CPU memory. - // TODO: In future, allow users to pass this array on the CPU so we don't have to do this copy - // For now it is required as several places in backend assume that all data arrays share - // the same context. - CUDA_CALL(cudaMemcpy(sequence_length_array_cpu.data(), sequence_length_ptr_gpu, sizeof(IType) * param_.batch_size_, - cudaMemcpyDeviceToHost)); - } - DType* cx_ptr = NULL; - DType* cy_ptr = NULL; - if (param_.mode == rnn_enum::kLstm) - cx_ptr = (in_data[rnn_enum::kStateCell].get(s)).dptr_; - if (param_.mode == rnn_enum::kLstm && param_.state_outputs) - cy_ptr = (out_data[rnn_enum::kStateCellOut].get(s)).dptr_; - - CHECK_EQ(x.CheckContiguous(), true); - CHECK_EQ(w.CheckContiguous(), true); - CHECK_EQ(hx.CheckContiguous(), true); - CHECK_EQ(y.CheckContiguous(), true); - - #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__) - if (!init_cudnn_) { - Init(ctx, s, in_data, out_data); - } + if (param_.use_sequence_length) { +#if MXNET_USE_CUDNN_RNN && CUDNN_VERSION >= 7200 + if (ctx_.dev_type == kCPU) { + LOG(FATAL) << "RNN use_sequence_length option is only available for cuDNN at the moment. Not supported on CPU"; + } + + // We can assume we are on GPU for now + size_t seq_len_input_idx = rnn_enum::kSequenceLength; + if (param_.mode != rnn_enum::kLstm) + seq_len_input_idx -= 1; + IType *sequence_length_ptr_gpu = (in_data[seq_len_input_idx].get(s)).dptr_; + + // Need to copy from GPU -> CPU, becuase cuDNN API requires this array on CPU memory. + // TODO: In future, allow users to pass this array on the CPU so we don't have to do this copy + // For now it is required as several places in backend assume that all data arrays share + // the same context. + sequence_length_array_cpu.resize(param_.batch_size_); + CUDA_CALL(cudaMemcpy(sequence_length_array_cpu.data(), sequence_length_ptr_gpu, sizeof(IType) * param_.batch_size_, + cudaMemcpyDeviceToHost)); +#else + LOG(FATAL) << "RNN use_sequence_length option is only available for cuDNN version >= 7.2"; +#endif + } + DType* cx_ptr = NULL; + DType* cy_ptr = NULL; + if (param_.mode == rnn_enum::kLstm) + cx_ptr = (in_data[rnn_enum::kStateCell].get(s)).dptr_; + if (param_.mode == rnn_enum::kLstm && param_.state_outputs) + cy_ptr = (out_data[rnn_enum::kStateCellOut].get(s)).dptr_; + + CHECK_EQ(x.CheckContiguous(), true); + CHECK_EQ(w.CheckContiguous(), true); + CHECK_EQ(hx.CheckContiguous(), true); + CHECK_EQ(y.CheckContiguous(), true); + +#if MXNET_USE_CUDNN_RNN && defined(__CUDACC__) + if (!init_cudnn_) { + Init(ctx, s, in_data, out_data); + } - #if USE_CUDNN_LSTM_PROJ - std::vector seqLengthArray; +#if USE_CUDNN_LSTM_PROJ + std::vector seqLengthArray; - cudnnRNNDataLayout_t layout_t; + cudnnRNNDataLayout_t layout_t; - if (param_.use_sequence_length) { - // sequence_length_ptr_gpu is fo type Itype, need to convert to vector - seqLengthArray = std::vector(sequence_length_array_cpu.begin(), sequence_length_array_cpu.end()); - layout_t = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED; - } - else { - seqLengthArray = std::vector(param_.batch_size_, param_.seq_length_); - layout_t = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED; - } + if (param_.use_sequence_length) { + // sequence_length_ptr_gpu is fo type Itype, need to convert to vector + seqLengthArray = std::vector(sequence_length_array_cpu.begin(), sequence_length_array_cpu.end()); + layout_t = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED; + } + else { + seqLengthArray = std::vector(param_.batch_size_, param_.seq_length_); + layout_t = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED; + } - CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_, - dtype_, - layout_t, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - seqLengthArray.data(), - (void*)&padding_fill_)); - int out_size = - (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size; - out_size = (param_.bidirectional) ? (out_size * 2) : out_size; - CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_, - dtype_, - layout_t, - param_.seq_length_, - param_.batch_size_, - out_size, - seqLengthArray.data(), - (void*)&padding_fill_)); - if (ctx.is_train) { - CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_, - dtype_, - layout_t, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - seqLengthArray.data(), - (void*)&padding_fill_)); - CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_, - dtype_, - layout_t, - param_.seq_length_, - param_.batch_size_, - out_size, - seqLengthArray.data(), - (void*)&padding_fill_)); - } - #endif + CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_, + dtype_, + layout_t, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + seqLengthArray.data(), + (void*)&padding_fill_)); + int out_size = + (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size; + out_size = (param_.bidirectional) ? (out_size * 2) : out_size; + CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_, + dtype_, + layout_t, + param_.seq_length_, + param_.batch_size_, + out_size, + seqLengthArray.data(), + (void*)&padding_fill_)); + if (ctx.is_train) { + CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_, + dtype_, + layout_t, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + seqLengthArray.data(), + (void*)&padding_fill_)); + CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_, + dtype_, + layout_t, + param_.seq_length_, + param_.batch_size_, + out_size, + seqLengthArray.data(), + (void*)&padding_fill_)); + } +#endif - #if USE_CUDNN_LSTM_PROJ - bool clip_state = param_.lstm_state_clip_min.has_value(); - bool clip_nan = param_.lstm_state_clip_nan; - CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_, - rnn_desc_, - clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE, - clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN, - clip_state ? param_.lstm_state_clip_min.value() : 0.0, - clip_state ? param_.lstm_state_clip_max.value() : 0.0)); - #endif +#if USE_CUDNN_LSTM_PROJ + bool clip_state = param_.lstm_state_clip_min.has_value(); + bool clip_nan = param_.lstm_state_clip_nan; + CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_, + rnn_desc_, + clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE, + clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN, + clip_state ? param_.lstm_state_clip_min.value() : 0.0, + clip_state ? param_.lstm_state_clip_max.value() : 0.0)); +#endif - if (ctx.is_train) { - #if USE_CUDNN_LSTM_PROJ - CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_, - rnn_desc_, - x_data_desc_, - x.dptr_, - hx_desc_, - hx.dptr_, - cx_desc_, - cx_ptr, - w_desc_, - w.dptr_, - y_data_desc_, - y.dptr_, - hy_desc_, - hy_ptr, - cy_desc_, - cy_ptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - temp_space_.dptr, - workspace_byte_, - reserve_space_.dptr, - reserve_space_byte_)); - #else - CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_, - rnn_desc_, - param_.seq_length_, - x_desc_vec_.data(), - x.dptr_, - hx_desc_, - hx.dptr_, - cx_desc_, - cx_ptr, - w_desc_, - w.dptr_, - y_desc_vec_.data(), - y.dptr_, - hy_desc_, - hy_ptr, - cy_desc_, - cy_ptr, - temp_space_.dptr, - workspace_byte_, - reserve_space_.dptr, - reserve_space_byte_)); - #endif - } else { - #if USE_CUDNN_LSTM_PROJ - CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_, - rnn_desc_, - x_data_desc_, - x.dptr_, - hx_desc_, - hx.dptr_, - cx_desc_, - cx_ptr, - w_desc_, - w.dptr_, - y_data_desc_, - y.dptr_, - hy_desc_, - hy_ptr, - cy_desc_, - cy_ptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - temp_space_.dptr, - workspace_byte_)); - #else - CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_, - rnn_desc_, - param_.seq_length_, - x_desc_vec_.data(), - x.dptr_, - hx_desc_, - hx.dptr_, - cx_desc_, - cx_ptr, - w_desc_, - w.dptr_, - y_desc_vec_.data(), - y.dptr_, - hy_desc_, - hy_ptr, - cy_desc_, - cy_ptr, - temp_space_.dptr, - workspace_byte_)); - #endif - } - #endif + if (ctx.is_train) { +#if USE_CUDNN_LSTM_PROJ + CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_, + rnn_desc_, + x_data_desc_, + x.dptr_, + hx_desc_, + hx.dptr_, + cx_desc_, + cx_ptr, + w_desc_, + w.dptr_, + y_data_desc_, + y.dptr_, + hy_desc_, + hy_ptr, + cy_desc_, + cy_ptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + temp_space_.dptr, + workspace_byte_, + reserve_space_.dptr, + reserve_space_byte_)); +#else + CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_, + rnn_desc_, + param_.seq_length_, + x_desc_vec_.data(), + x.dptr_, + hx_desc_, + hx.dptr_, + cx_desc_, + cx_ptr, + w_desc_, + w.dptr_, + y_desc_vec_.data(), + y.dptr_, + hy_desc_, + hy_ptr, + cy_desc_, + cy_ptr, + temp_space_.dptr, + workspace_byte_, + reserve_space_.dptr, + reserve_space_byte_)); +#endif + } else { +#if USE_CUDNN_LSTM_PROJ + CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_, + rnn_desc_, + x_data_desc_, + x.dptr_, + hx_desc_, + hx.dptr_, + cx_desc_, + cx_ptr, + w_desc_, + w.dptr_, + y_data_desc_, + y.dptr_, + hy_desc_, + hy_ptr, + cy_desc_, + cy_ptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + temp_space_.dptr, + workspace_byte_)); +#else + CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_, + rnn_desc_, + param_.seq_length_, + x_desc_vec_.data(), + x.dptr_, + hx_desc_, + hx.dptr_, + cx_desc_, + cx_ptr, + w_desc_, + w.dptr_, + y_desc_vec_.data(), + y.dptr_, + hy_desc_, + hy_ptr, + cy_desc_, + cy_ptr, + temp_space_.dptr, + workspace_byte_)); +#endif + } +#endif - if (ctx_.dev_type == kCPU) { - // allocate temp space - const size_t work_cpu_space_size = + if (ctx_.dev_type == kCPU) { + // allocate temp space + const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, param_.state_size, direction, param_.mode); - if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) { + if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) { Storage::Get()->Free(temp_cpu_space_); temp_init_space_ = false; - } - if (!temp_init_space_) { - temp_cpu_space_ = Storage::Get()->Alloc + } + if (!temp_init_space_) { + temp_cpu_space_ = Storage::Get()->Alloc (work_cpu_space_size * sizeof(DType), Context::CPU()); - temp_cpu_space_size_ = work_cpu_space_size; - temp_init_space_ = true; - } - DType* work_cpu_space = static_cast(temp_cpu_space_.dptr); - if (ctx.is_train) { - const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction, - param_.seq_length_, param_.batch_size_, - param_.state_size, param_.mode); - if (init_space_ && reserve_cpu_space_size_ < r_size) { - Storage::Get()->Free(reserve_cpu_space_); - init_space_ = false; - } - if (!init_space_) { - reserve_cpu_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU()); - reserve_cpu_space_size_ = r_size; - init_space_ = true; - } - - DType* reserve_space_ptr = static_cast(reserve_cpu_space_.dptr); - - RNNForwardTraining(work_cpu_space, - reserve_space_ptr, - param_.state_outputs, - param_.num_layers, - direction, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - param_.state_size, - x.dptr_, - hx.dptr_, - cx_ptr, - w.dptr_, - b_ptr, - y.dptr_, - hy_ptr, - cy_ptr, - param_.p, - param_.mode); - } else { - RNNForwardInference(work_cpu_space, - param_.state_outputs, - param_.num_layers, - direction, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - param_.state_size, - x.dptr_, - hx.dptr_, - cx_ptr, - w.dptr_, - b_ptr, - y.dptr_, - hy_ptr, - cy_ptr, - param_.mode); + temp_cpu_space_size_ = work_cpu_space_size; + temp_init_space_ = true; + } + DType* work_cpu_space = static_cast(temp_cpu_space_.dptr); + if (ctx.is_train) { + const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction, + param_.seq_length_, param_.batch_size_, + param_.state_size, param_.mode); + if (init_space_ && reserve_cpu_space_size_ < r_size) { + Storage::Get()->Free(reserve_cpu_space_); + init_space_ = false; + } + if (!init_space_) { + reserve_cpu_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU()); + reserve_cpu_space_size_ = r_size; + init_space_ = true; + } + + DType* reserve_space_ptr = static_cast(reserve_cpu_space_.dptr); + + RNNForwardTraining(work_cpu_space, + reserve_space_ptr, + param_.state_outputs, + param_.num_layers, + direction, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + param_.state_size, + x.dptr_, + hx.dptr_, + cx_ptr, + w.dptr_, + b_ptr, + y.dptr_, + hy_ptr, + cy_ptr, + param_.p, + param_.mode); + } else { + RNNForwardInference(work_cpu_space, + param_.state_outputs, + param_.num_layers, + direction, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + param_.state_size, + x.dptr_, + hx.dptr_, + cx_ptr, + w.dptr_, + b_ptr, + y.dptr_, + hy_ptr, + cy_ptr, + param_.mode); + } } } } @@ -877,7 +889,8 @@ class RNNOp { CHECK(param_.p >= 0.0f && param_.p < 1.0f) << "unsupported dropout value, should be 0 <= dropout < 1"; - size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3; + size_t num_inputs = GetNumInputArguments(param_); + // kOut size_t num_outputs = 1; if (param_.state_outputs) { @@ -1094,24 +1107,12 @@ class RNNOp { const std::vector &out_data) { using namespace mshadow; - size_t num_inputs; + size_t num_inputs = GetNumInputArguments(param_); // kOut size_t num_outputs = 1; if (param_.state_outputs) { // kOut, kStateOut, kStateCellOut - num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2; - } - if (param_.mode == rnn_enum::kLstm) { - if (param_.use_sequence_length) { - num_inputs = 5; // data, parameters, state, cell_state, sequence_length - } else { - num_inputs = 4; // data, parameters, state, cell_state - } - } else { - if (param_.use_sequence_length) - num_inputs = 4; // data, parameters, state, sequence_length - else - num_inputs = 3; // data, parameters, state + num_outputs = (param_.mode == rnn_enum::kLstm) ? 3U : 2U; } CHECK_EQ(in_data.size(), num_inputs); @@ -1427,8 +1428,8 @@ class RNNOp { #if USE_CUDNN_LSTM_PROJ cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, dy_data_desc_; DType padding_fill_ = 0; - std::vector sequence_length_array_cpu; #endif + std::vector sequence_length_array_cpu; cudnnTensorDescriptor_t hx_desc_, cx_desc_; cudnnTensorDescriptor_t hy_desc_, cy_desc_; cudnnTensorDescriptor_t dhx_desc_, dcx_desc_; @@ -1453,13 +1454,23 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs, const std::vector &in_types) { const RNNParam& param = nnvm::get(attrs.parsed); OpStatePtr state = OpStatePtr(); - MSHADOW_REAL_TYPE_SWITCH(in_types[rnn_enum::kData], DType, { - if (ctx.dev_type == kGPU) { - state = OpStatePtr::Create>(param, ctx); - } else { - state = OpStatePtr::Create>(param, ctx); - } - }); + int dtype = in_types[rnn_enum::kData]; + int itype = dtype; + if (param.use_sequence_length) { + itype = in_types[rnn_enum::kSequenceLength]; + if (param.mode == rnn_enum::kLstm) + itype -= 1; + } + + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + MSHADOW_TYPE_SWITCH(itype, IType, { + if (ctx.dev_type == kGPU) { + state = OpStatePtr::Create>(param, ctx); + } else { + state = OpStatePtr::Create>(param, ctx); + } + }); + }); return state; } @@ -1470,10 +1481,18 @@ void RNNStatefulCompute(const OpStatePtr& state, const std::vector& req, const std::vector& outputs) { int dtype = inputs[rnn_enum::kData].type_flag_; + + // Hacky. This relies on fact that seq-len type is either the last input, + // or we aren't using seq-len input and this type should be same as dtype. + // Would prefer direct access to RNNParam object here but not sure how to get. + int itype = inputs[inputs.size()-1].type_flag_; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - RNNOp& op = state.get_state>(); - op.Forward(ctx, inputs, req, outputs); - }); + MSHADOW_TYPE_SWITCH(itype, IType, { + RNNOp& op = state.get_state>(); + op.Forward(ctx, inputs, req, outputs); + }); + }); } /* @@ -1501,25 +1520,33 @@ void RNNStatefulGradCompute(const OpStatePtr& state, const std::vector &in_grad = outputs; int dtype = inputs[rnn_enum::kData].type_flag_; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - RNNOp& op = state.get_state>(); - const RNNParam& param = op.param_; - int index = 5; - if (param.state_outputs) { - out_data.push_back(inputs[index++]); - out_grad.push_back(inputs[index++]); - } - if (param.mode == rnn_enum::kLstm) { - in_data.push_back(inputs[index++]); - if (param.state_outputs) { - out_data.push_back(inputs[index++]); - out_grad.push_back(inputs[index]); - } - } + // Hacky. This relies on fact that seq-len type is either the last input, + // or we aren't using seq-len input and this type should be same as dtype. + // Would prefer direct access to RNNParam object here but not sure how to get. + int itype = inputs[inputs.size()-1].type_flag_; - op.Backward(ctx, out_grad, in_data, out_data, req, in_grad); - }); + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + MSHADOW_TYPE_SWITCH(itype, IType, { + RNNOp& op = state.get_state>(); + const RNNParam& param = op.param_; + int index = 5; + if (param.state_outputs) { + out_data.push_back(inputs[index++]); + out_grad.push_back(inputs[index++]); + } + + if (param.mode == rnn_enum::kLstm) { + in_data.push_back(inputs[index++]); + if (param.state_outputs) { + out_data.push_back(inputs[index++]); + out_grad.push_back(inputs[index]); + } + } + + op.Backward(ctx, out_grad, in_data, out_data, req, in_grad); + }); + }); } } // namespace op diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 6a3f3555a831..9e65e54069a1 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -23,6 +23,10 @@ * \brief * \author Sebastian Bodenstein */ + +//#include +#include + #include "./rnn-inl.h" namespace mxnet { @@ -30,11 +34,19 @@ namespace op { DMLC_REGISTER_PARAMETER(RNNParam); static inline std::vector ListArguments(const RNNParam& param_) { - if (param_.mode == rnn_enum::kLstm) { - return {"data", "parameters", "state", "state_cell"}; - } else { - return {"data", "parameters", "state"}; - } + + // All RNNs start off with same 3 input arguments + std::vector arguments{"data", "parameters", "state"}; + + // LSTMs also have an additional state_cell argument + if (param_.mode == rnn_enum::kLstm) + arguments.push_back("state_cell"); + + // All RNNs have option of additional sequence_length argument + if (param_.use_sequence_length) + arguments.push_back("sequence_length"); + + return arguments; } static bool RNNShape(const nnvm::NodeAttrs& attrs, @@ -42,13 +54,19 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, std::vector *out_shape) { const RNNParam& param_ = nnvm::get(attrs.parsed); using namespace mshadow; - if (param_.mode == rnn_enum::kLstm) { - CHECK_EQ(in_shape->size(), 4U) << "Needed input:[data, parameters, state, cell_state]," - << " got in_shape->size(): " << in_shape->size(); - } else { - CHECK_EQ(in_shape->size(), 3U) << - "Needed input:[data, parameters, state], got in_shape->size(): " << in_shape->size(); - } + + // Query param_ object to figure out what the expectd input arguments are + std::vector expected_arguments = ListArguments(param_); + + /* + std::stringstream expected_arguments_string_stream; + std::copy(expected_arguments.begin(), expected_arguments.end(), std::ostream_iterator(expected_arguments_string_stream, " ")); + + CHECK_EQ(in_shape->size(), expected_arguments.size()) << "Needed input:[" << expected_arguments_string_stream.str() << "]," + << " got in_shape->size(): " << in_shape->size(); + */ + CHECK_EQ(in_shape->size(), expected_arguments.size()) << "shape mismatch!!"; + const TShape &dshape = (*in_shape)[rnn_enum::kData]; if (!mxnet::ndim_is_known(dshape)) return false; CHECK_EQ(dshape.ndim(), 3U) \ @@ -77,6 +95,16 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, param_.mode, param_.projection_size); SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size)); + + // Check on sequence_length shape if using + if (param_.use_sequence_length) { + size_t seq_len_input_idx = rnn_enum::kSequenceLength; + if (param_.mode != rnn_enum::kLstm) + --seq_len_input_idx; + + SHAPE_ASSIGN_CHECK(*in_shape, seq_len_input_idx, Shape1(batch_size)); + } + out_shape->clear(); // output: [sequence len, batch, output size] TShape oshape = dshape; @@ -106,6 +134,7 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, out_shape->push_back(cellStateShape); } } + return true; } @@ -113,18 +142,25 @@ static bool RNNType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { const RNNParam& param_ = nnvm::get(attrs.parsed); - if (param_.mode == rnn_enum::kLstm) { - CHECK_EQ(in_type->size(), 4U); - } else { - CHECK_EQ(in_type->size(), 3U); - } + + CHECK_EQ(in_type->size(), GetNumInputArguments(param_)); + + size_t seq_len_input_idx = rnn_enum::kSequenceLength; + if (param_.mode != rnn_enum::kLstm) + --seq_len_input_idx; + int dtype = (*in_type)[0]; CHECK_NE(dtype, -1) << "First input must have specified type"; + std::vector arguments = ListArguments(param_); for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { TYPE_ASSIGN_CHECK(*in_type, i, dtype); } else { - UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]); + // If using sequence length argument, it has its own indexing type + // All other input arguments must match the main data type + if (!(param_.use_sequence_length && i == seq_len_input_idx)) { + UNIFORM_TYPE_CHECK((*in_type)[i], dtype, arguments[i]); + } } } out_type->clear(); @@ -220,7 +256,7 @@ The definition of GRU here is slightly different from paper but compatible with .set_attr_parser(ParamParser) .set_num_inputs([](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); - return params.mode == rnn_enum::kLstm ? 4 : 3; + return GetNumInputArguments(params); }) .set_num_outputs([](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); @@ -271,7 +307,7 @@ The definition of GRU here is slightly different from paper but compatible with NNVM_REGISTER_OP(_backward_RNN) .set_num_outputs([](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); - return params.mode == rnn_enum::kLstm ? 4 : 3; + return GetNumInputArguments(params); }) .set_attr_parser(ParamParser) .set_attr("TIsLayerOpBackward", true)