diff --git a/src/operator/cudnn_rnn-inl.h b/src/operator/cudnn_rnn-inl.h deleted file mode 100644 index cc8e4db404da..000000000000 --- a/src/operator/cudnn_rnn-inl.h +++ /dev/null @@ -1,863 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2016 by Contributors - * \file cudnn_rnn-inl.h - * \brief - * \author Sebastian Bodenstein -*/ -#ifndef MXNET_OPERATOR_CUDNN_RNN_INL_H_ -#define MXNET_OPERATOR_CUDNN_RNN_INL_H_ - -#define USE_CUDNN_LSTM_PROJ MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200 - -#include -#include -#include -#include -#include -#include -#include "./rnn-inl.h" - -namespace mxnet { -namespace op { -#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 -template -class CuDNNRNNOp : public Operator { - public: - explicit CuDNNRNNOp(RNNParam param) { - this->param_ = param; - init_cudnn_ = false; - dtype_ = mshadow::DataType::kCudnnFlag; - // TensorCore algos only allowed on fp16-I/O convolutions if permitted by the global policy. - // No tests in place for fp16 RNNs, so leave TensorCore disabled for now. 
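// The "global policy" mentioned in the comment above is the
// MXNET_CUDA_ALLOW_TENSOR_CORE environment switch queried through
// GetEnvAllowTensorCore() in MXNet's CUDA utilities. A minimal sketch of such a
// policy check (hypothetical stand-in name, not part of this file; needs <cstdlib>):
inline bool AllowTensorCorePolicySketch() {
  const char* flag = std::getenv("MXNET_CUDA_ALLOW_TENSOR_CORE");
  return flag == nullptr ? true : (std::atoi(flag) != 0);
}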
- cudnn_tensor_core_ = false; - // When fp16 RNN tests are introduced, we can enable TensorCore as follows: -// cudnn_tensor_core = -// mshadow::DataType::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore(); - // Defaults - input_mode_ = CUDNN_LINEAR_INPUT; // Don't support this yet - // RNN Mode - switch (param_.mode) { - case rnn_enum::kRnnRelu: - mode_ = CUDNN_RNN_RELU; - break; - case rnn_enum::kRnnTanh: - mode_ = CUDNN_RNN_TANH; - break; - case rnn_enum::kLstm: - mode_ = CUDNN_LSTM; - break; - case rnn_enum::kGru: - mode_ = CUDNN_GRU; - break; - default: - LOG(FATAL) << "Not implmented"; - } -#if USE_CUDNN_LSTM_PROJ - if (param_.projection_size.has_value()) { - CHECK_EQ(param_.mode, rnn_enum::kLstm) - << "Projection is only supported for LSTM."; - CHECK_GE(param_.state_size, param_.projection_size.value()) - << "State size must be larger than projection size."; - } -#else - CHECK(!param_.projection_size.has_value()) - << "Projection is only supported for LSTM with CuDNN version later than 7.1.1."; -#endif -#if USE_CUDNN_LSTM_PROJ - if (param_.lstm_state_clip_min.has_value() - || param_.lstm_state_clip_max.has_value()) { - CHECK_EQ(param_.mode, rnn_enum::kLstm) - << "State clipping is only supported for LSTM."; - CHECK(param_.lstm_state_clip_min.has_value() && param_.lstm_state_clip_max.has_value()) - << "lstm_state_clip_min and lstm_state_clip_max must be specified together."; - CHECK_GE(param_.lstm_state_clip_max.value(), param_.lstm_state_clip_min.value()) - << "lstm_state_clip_max must be greater or equal to lstm_state_clip_min"; - } -#else - CHECK(!param_.lstm_state_clip_min.has_value() - && !param_.lstm_state_clip_max.has_value()) - << "State clipping is only supported for LSTM with CuDNN version later than 7.2.1."; -#endif - // RNN Direction - direction_ = param_.bidirectional ? 
CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; - // Other - if (param_.mode == rnn_enum::kLstm) - param_.lstm_q_ = true; - else - param_.lstm_q_ = false; - - // Create descriptors - CUDNN_CALL(cudnnCreateTensorDescriptor(&hx_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&cx_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&hy_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&cy_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&dhx_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&dcx_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&dhy_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&dcy_desc_)); - - CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc_)); - CUDNN_CALL(cudnnCreateFilterDescriptor(&dw_desc_)); - - CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_)); - CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_)); - - #if USE_CUDNN_LSTM_PROJ - CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_)); - CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_)); - CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_)); - CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dy_data_desc_)); - #endif - } - - ~CuDNNRNNOp() { - CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_)); - CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_)); - CUDNN_CALL(cudnnDestroyTensorDescriptor(hy_desc_)); - CUDNN_CALL(cudnnDestroyTensorDescriptor(cy_desc_)); - CUDNN_CALL(cudnnDestroyTensorDescriptor(dhx_desc_)); - CUDNN_CALL(cudnnDestroyTensorDescriptor(dcx_desc_)); - CUDNN_CALL(cudnnDestroyTensorDescriptor(dhy_desc_)); - CUDNN_CALL(cudnnDestroyTensorDescriptor(dcy_desc_)); - - CUDNN_CALL(cudnnDestroyFilterDescriptor(w_desc_)); - CUDNN_CALL(cudnnDestroyFilterDescriptor(dw_desc_)); - CUDNN_CALL(cudnnDestroyRNNDescriptor(rnn_desc_)); - CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_)); - - if (init_cudnn_) { - for (size_t i = 0; i < x_desc_vec_.size(); ++i) { - CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc_vec_[i])); - CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc_vec_[i])); - CUDNN_CALL(cudnnDestroyTensorDescriptor(dx_desc_vec_[i])); - CUDNN_CALL(cudnnDestroyTensorDescriptor(dy_desc_vec_[i])); - } - init_cudnn_ = false; - - Storage::Get()->Free(reserve_space_); - if (param_.p > 0) { - Storage::Get()->Free(dropout_states_); - } - } - #if USE_CUDNN_LSTM_PROJ - CUDNN_CALL(cudnnDestroyRNNDataDescriptor(x_data_desc_)); - CUDNN_CALL(cudnnDestroyRNNDataDescriptor(y_data_desc_)); - CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dx_data_desc_)); - CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dy_data_desc_)); - #endif - } - - virtual void Forward(const OpContext &ctx, const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { - using namespace mshadow; - size_t in_expected = param_.lstm_q_ ? 4 : 3; - size_t out_expected = param_.lstm_q_ ? 
3 : 2; - if (!param_.state_outputs) - out_expected = 1; - - CHECK_EQ(in_data.size(), in_expected); - CHECK_EQ(out_data.size(), out_expected); - Stream *s = ctx.get_stream(); - // get input + output tensors - Tensor x = in_data[rnn_enum::kData].get(s); - Tensor w = in_data[rnn_enum::kParams].get(s); - Tensor hx = in_data[rnn_enum::kState].get(s); - Tensor y = out_data[rnn_enum::kOut].get(s); - - void * hy_ptr = NULL; - if (param_.state_outputs) - hy_ptr = out_data[rnn_enum::kStateOut].get(s).dptr_; - - DType * cx_ptr = NULL; - DType * cy_ptr = NULL; - - if (param_.lstm_q_) - cx_ptr = (in_data[rnn_enum::kStateCell].get(s)).dptr_; - if (param_.lstm_q_ && param_.state_outputs) - cy_ptr = (out_data[rnn_enum::kStateCellOut].get(s)).dptr_; - - CHECK_EQ(x.CheckContiguous(), true); - CHECK_EQ(w.CheckContiguous(), true); - CHECK_EQ(hx.CheckContiguous(), true); - CHECK_EQ(y.CheckContiguous(), true); - - if (!init_cudnn_) { - Init(s, in_data, out_data); - } - // Get temp space - int temp_size = workspace_size_; - Tensor temp_space = - ctx.requested[rnn_enum::kTempSpace].get_space_typed( - mshadow::Shape1(temp_size), s); - #if USE_CUDNN_LSTM_PROJ - std::vector seqLengthArray(param_.batch_size_, param_.seq_length_); - CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_, - dtype_, - CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - seqLengthArray.data(), - nullptr)); - int out_size = - (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size; - out_size = (param_.bidirectional) ? (out_size * 2) : out_size; - CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_, - dtype_, - CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED, - param_.seq_length_, - param_.batch_size_, - out_size, - seqLengthArray.data(), - nullptr)); - if (ctx.is_train) { - CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_, - dtype_, - CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - seqLengthArray.data(), - nullptr)); - CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_, - dtype_, - CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED, - param_.seq_length_, - param_.batch_size_, - out_size, - seqLengthArray.data(), - nullptr)); - } - #endif - - #if USE_CUDNN_LSTM_PROJ - bool clip_state = param_.lstm_state_clip_min.has_value(); - bool clip_nan = param_.lstm_state_clip_nan; - CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_, - rnn_desc_, - clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE, - clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN, - clip_state ? param_.lstm_state_clip_min.value() : 0.0, - clip_state ? 
param_.lstm_state_clip_max.value() : 0.0)); - #endif - - if (ctx.is_train) { - #if USE_CUDNN_LSTM_PROJ - CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_, - rnn_desc_, - x_data_desc_, - x.dptr_, - hx_desc_, - hx.dptr_, - cx_desc_, - cx_ptr, - w_desc_, - w.dptr_, - y_data_desc_, - y.dptr_, - hy_desc_, - hy_ptr, - cy_desc_, - cy_ptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - temp_space.dptr_, - workspace_byte_, - reserve_space_.dptr, - reserve_space_byte_)); - #else - CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_, - rnn_desc_, - param_.seq_length_, - x_desc_vec_.data(), - x.dptr_, - hx_desc_, - hx.dptr_, - cx_desc_, - cx_ptr, - w_desc_, - w.dptr_, - y_desc_vec_.data(), - y.dptr_, - hy_desc_, - hy_ptr, - cy_desc_, - cy_ptr, - temp_space.dptr_, - workspace_byte_, - reserve_space_.dptr, - reserve_space_byte_)); - #endif - } else { - #if USE_CUDNN_LSTM_PROJ - CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_, - rnn_desc_, - x_data_desc_, - x.dptr_, - hx_desc_, - hx.dptr_, - cx_desc_, - cx_ptr, - w_desc_, - w.dptr_, - y_data_desc_, - y.dptr_, - hy_desc_, - hy_ptr, - cy_desc_, - cy_ptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - temp_space.dptr_, - workspace_byte_)); - #else - CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_, - rnn_desc_, - param_.seq_length_, - x_desc_vec_.data(), - x.dptr_, - hx_desc_, - hx.dptr_, - cx_desc_, - cx_ptr, - w_desc_, - w.dptr_, - y_desc_vec_.data(), - y.dptr_, - hy_desc_, - hy_ptr, - cy_desc_, - cy_ptr, - temp_space.dptr_, - workspace_byte_)); - #endif - } - } - - virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_args) { - using namespace mshadow; - size_t in_expected = param_.lstm_q_ ? 4 : 3; - size_t out_expected = param_.lstm_q_ ? 
3 : 2; - if (!param_.state_outputs) - out_expected = 1; - - CHECK_EQ(in_data.size(), in_expected); - CHECK_EQ(out_data.size(), out_expected); - CHECK_EQ(in_grad.size(), in_expected); - CHECK_EQ(out_grad.size(), out_expected); - CHECK_EQ(req.size(), in_expected); - CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data"; - CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state"; - Stream *s = ctx.get_stream(); - // get input + output tensors - Tensor x = in_data[rnn_enum::kData].get(s); - Tensor dx = in_grad[rnn_enum::kData].get(s); - Tensor w = in_data[rnn_enum::kParams].get(s); - Tensor dw = in_grad[rnn_enum::kParams].get(s); - Tensor hx = in_data[rnn_enum::kState].get(s); - Tensor dhx = in_grad[rnn_enum::kState].get(s); - Tensor y = out_data[rnn_enum::kOut].get(s); - Tensor dy = out_grad[rnn_enum::kOut].get(s); - if (req[rnn_enum::kParams] != kAddTo) { - dw = mshadow::expr::ScalarExp(0.0f); - } - // only need kStateOut grad output_states is true - void * dhy_ptr = NULL; - if (param_.state_outputs) - dhy_ptr = out_grad[rnn_enum::kStateOut].get(s).dptr_; - - // Deal with lstm - void * dcx_ptr = NULL; - void * dcy_ptr = NULL; - void * cx_ptr = NULL; - - if (param_.mode == rnn_enum::kLstm) { - CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell"; - cx_ptr = (in_data[rnn_enum::kStateCell].get(s)).dptr_; - dcx_ptr = (in_grad[rnn_enum::kStateCell].get(s)).dptr_; - } - if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs) - dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get(s)).dptr_; - - CHECK_EQ(x.CheckContiguous(), true); - CHECK_EQ(w.CheckContiguous(), true); - CHECK_EQ(dw.CheckContiguous(), true); - CHECK_EQ(hx.CheckContiguous(), true); - CHECK_EQ(dhx.CheckContiguous(), true); - CHECK_EQ(y.CheckContiguous(), true); - CHECK_EQ(dy.CheckContiguous(), true); - - if (!init_cudnn_) { - Init(s, in_data, out_data); - } - - // Get temp space - int temp_size = workspace_size_; - Tensor temp_space = - ctx.requested[rnn_enum::kTempSpace].get_space_typed( - mshadow::Shape1(temp_size), s); - #if USE_CUDNN_LSTM_PROJ - CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_, - rnn_desc_, - y_data_desc_, - y.dptr_, - dy_data_desc_, - dy.dptr_, - nullptr, - nullptr, - dhy_desc_, - dhy_ptr, - dcy_desc_, - dcy_ptr, - w_desc_, - w.dptr_, - hx_desc_, - hx.dptr_, - cx_desc_, - cx_ptr, - dx_data_desc_, - dx.dptr_, - dhx_desc_, - dhx.dptr_, - dcx_desc_, - dcx_ptr, - nullptr, - nullptr, - temp_space.dptr_, - workspace_byte_, - reserve_space_.dptr, - reserve_space_byte_)); - CUDNN_CALL(cudnnRNNBackwardWeightsEx(s->dnn_handle_, - rnn_desc_, - x_data_desc_, - x.dptr_, - hx_desc_, - hx.dptr_, - y_data_desc_, - y.dptr_, - temp_space.dptr_, - workspace_byte_, - dw_desc_, - dw.dptr_, - reserve_space_.dptr, - reserve_space_byte_)); - #else - CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_, - rnn_desc_, - param_.seq_length_, - y_desc_vec_.data(), - y.dptr_, - dy_desc_vec_.data(), - dy.dptr_, - dhy_desc_, - dhy_ptr, - dcy_desc_, - dcy_ptr, - w_desc_, - w.dptr_, - hx_desc_, - hx.dptr_, - cx_desc_, - cx_ptr, - dx_desc_vec_.data(), - dx.dptr_, - dhx_desc_, - dhx.dptr_, - dcx_desc_, - dcx_ptr, - temp_space.dptr_, - workspace_byte_, - reserve_space_.dptr, - reserve_space_byte_)); - CUDNN_CALL(cudnnRNNBackwardWeights(s->dnn_handle_, - rnn_desc_, - param_.seq_length_, - x_desc_vec_.data(), - x.dptr_, - hx_desc_, - hx.dptr_, - y_desc_vec_.data(), - y.dptr_, - temp_space.dptr_, - workspace_byte_, - dw_desc_, - dw.dptr_, - reserve_space_.dptr, - 
reserve_space_byte_)); - #endif - } - - private: - inline void Init(mshadow::Stream *s, - const std::vector &in_data, - const std::vector &out_data) { - using namespace mshadow; - #if CUDNN_MAJOR >= 5 - format_ = CUDNN_TENSOR_NCHW; - #endif - size_t in_expected = param_.lstm_q_ ? 4 : 3; - size_t out_expected = param_.lstm_q_ ? 3 : 2; - if (!param_.state_outputs) - out_expected = 1; - - CHECK_EQ(in_data.size(), in_expected); - CHECK_EQ(out_data.size(), out_expected); - if (!init_cudnn_) { - init_cudnn_ = true; - // get input + output tensors - Tensor x = in_data[rnn_enum::kData].get(s); - Tensor w = in_data[rnn_enum::kParams].get(s); - param_.seq_length_ = x.shape_[0]; - param_.batch_size_ = x.shape_[1]; - param_.input_size_ = x.shape_[2]; - - // Tensor Descriptors - std::vector x_vec(param_.seq_length_); - std::vector y_vec(param_.seq_length_); - std::vector dx_vec(param_.seq_length_); - std::vector dy_vec(param_.seq_length_); - int dimA[3]; - int strideA[3]; - for (int i = 0; i < param_.seq_length_; i++) { - CUDNN_CALL(cudnnCreateTensorDescriptor(&x_vec[i])); - CUDNN_CALL(cudnnCreateTensorDescriptor(&y_vec[i])); - CUDNN_CALL(cudnnCreateTensorDescriptor(&dx_vec[i])); - CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_vec[i])); - - dimA[0] = param_.batch_size_; - dimA[1] = param_.input_size_; - dimA[2] = 1; - strideA[0] = dimA[2] * dimA[1]; - strideA[1] = dimA[2]; - strideA[2] = 1; - - CUDNN_CALL(cudnnSetTensorNdDescriptor(x_vec[i], - dtype_, - 3, - dimA, - strideA)); - CUDNN_CALL(cudnnSetTensorNdDescriptor(dx_vec[i], - dtype_, - 3, - dimA, - strideA)); - dimA[0] = param_.batch_size_; - dimA[1] = param_.bidirectional ? param_.state_size * 2 : param_.state_size; - dimA[2] = 1; - strideA[0] = dimA[2] * dimA[1]; - strideA[1] = dimA[2]; - strideA[2] = 1; - - CUDNN_CALL(cudnnSetTensorNdDescriptor(y_vec[i], - dtype_, - 3, - dimA, - strideA)); - CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_vec[i], - dtype_, - 3, - dimA, - strideA)); - } - x_desc_vec_ = x_vec; - y_desc_vec_ = y_vec; - dx_desc_vec_ = dx_vec; - dy_desc_vec_ = dy_vec; - - // set the state tensors - dimA[0] = param_.num_layers * (param_.bidirectional ? 2 : 1); - dimA[1] = param_.batch_size_; - dimA[2] = param_.state_size; - strideA[0] = dimA[2] * dimA[1]; - strideA[1] = dimA[2]; - strideA[2] = 1; - #if USE_CUDNN_LSTM_PROJ - int dimB[3]; - int strideB[3]; - dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1); - dimB[1] = param_.batch_size_; - dimB[2] = param_.projection_size.has_value() ? 
- param_.projection_size.value() : param_.state_size; - strideB[0] = dimB[2] * dimB[1]; - strideB[1] = dimB[2]; - strideB[2] = 1; - #endif - - #if USE_CUDNN_LSTM_PROJ - CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_, - dtype_, - 3, - dimB, - strideB)); - #else - CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_, - dtype_, - 3, - dimA, - strideA)); - #endif - CUDNN_CALL(cudnnSetTensorNdDescriptor(cx_desc_, - dtype_, - 3, - dimA, - strideA)); - #if USE_CUDNN_LSTM_PROJ - CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_, - dtype_, - 3, - dimB, - strideB)); - #else - CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_, - dtype_, - 3, - dimA, - strideA)); - #endif - CUDNN_CALL(cudnnSetTensorNdDescriptor(cy_desc_, - dtype_, - 3, - dimA, - strideA)); - #if USE_CUDNN_LSTM_PROJ - CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_, - dtype_, - 3, - dimB, - strideB)); - #else - CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_, - dtype_, - 3, - dimA, - strideA)); - #endif - CUDNN_CALL(cudnnSetTensorNdDescriptor(dcx_desc_, - dtype_, - 3, - dimA, - strideA)); - #if USE_CUDNN_LSTM_PROJ - CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_, - dtype_, - 3, - dimB, - strideB)); - #else - CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_, - dtype_, - 3, - dimA, - strideA)); - #endif - CUDNN_CALL(cudnnSetTensorNdDescriptor(dcy_desc_, - dtype_, - 3, - dimA, - strideA)); - - // Create Dropout descriptors - if (param_.p > 0) { - CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_byte_)); - dropout_size_ = dropout_byte_ / sizeof(DType); - dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU(s->dev_id)); - } else { - dropout_states_ = {}; - dropout_byte_ = 0; - } - CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_, s->dnn_handle_, - param_.p, // discard probability - dropout_states_.dptr, dropout_byte_, - seed_)); - // RNN descriptors - #if CUDNN_MAJOR >= 6 - cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD; - CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_, - rnn_desc_, - param_.state_size, - param_.num_layers, - dropout_desc_, - input_mode_, - direction_, - mode_, - rnn_algo, - dtype_)); - #else - CUDNN_CALL(cudnnSetRNNDescriptor(rnn_desc_, - param_.state_size, - param_.num_layers, - dropout_desc_, - input_mode_, - direction_, - mode_, - dtype_)); - #endif - #if CUDNN_MAJOR >= 7 - cudnnMathType_t math_type = CUDNN_DEFAULT_MATH; - if (cudnn_tensor_core_ && rnn_algo == CUDNN_RNN_ALGO_STANDARD) { - math_type = CUDNN_TENSOR_OP_MATH; - } - #if CUDNN_VERSION >= 7200 - if (GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion() && - (DataType::kFlag != kFloat16)) - math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION; - #endif - CUDNN_CALL(cudnnSetRNNMatrixMathType(rnn_desc_, math_type)); - #endif - #if USE_CUDNN_LSTM_PROJ - if (param_.projection_size.has_value()) { - CUDNN_CALL(cudnnSetRNNProjectionLayers(s->dnn_handle_, - rnn_desc_, - param_.projection_size.value(), - 0)); - } - #endif - // Get temp space sizes - CUDNN_CALL(cudnnGetRNNWorkspaceSize(s->dnn_handle_, - rnn_desc_, - param_.seq_length_, - x_desc_vec_.data(), - &workspace_byte_)); - CUDNN_CALL(cudnnGetRNNTrainingReserveSize(s->dnn_handle_, - rnn_desc_, - param_.seq_length_, - x_desc_vec_.data(), - &reserve_space_byte_)); - workspace_size_ = workspace_byte_ / sizeof(DType); - // Allocate the reserve space - reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU(s->dev_id)); - - // Check that number of params are correct - size_t cudnn_param_size; - CUDNN_CALL(cudnnGetRNNParamsSize(s->dnn_handle_, - rnn_desc_, - x_desc_vec_[0], 
- &cudnn_param_size, - dtype_)); - CHECK_EQ(w.shape_[0] * sizeof(DType), cudnn_param_size); - - // Set param descriptors - int dim_w[3] = {1, 1, 1}; - dim_w[0] = w.shape_[0]; - CUDNN_CALL(cudnnSetFilterNdDescriptor(w_desc_, - dtype_, - format_, - 3, - dim_w)); - CUDNN_CALL(cudnnSetFilterNdDescriptor(dw_desc_, - dtype_, - format_, - 3, - dim_w)); - - // Query weight layout - // cudnnFilterDescriptor_t m_desc; - // CHECK_EQ(cudnnCreateFilterDescriptor(&m_desc), CUDNN_STATUS_SUCCESS); - // DType *p; - // int n = 2; - // int64_t last = 0; - // if (param_.mode == rnn_enum::kLstm) n = 8; - // else if (param_.mode == rnn_enum::kGru) n = 6; - - // for (int i = 0; i < param_.num_layers*(param_.bidirectional?2:1); ++i) { - // for (int j = 0; j < n; ++j) { - // CHECK_EQ(cudnnGetRNNLinLayerMatrixParams(s->dnn_handle_, rnn_desc_, - // i, x_desc_vec_[0], w_desc_, 0, j, m_desc, (void**)&p), CUDNN_STATUS_SUCCESS); - // LOG(INFO) << ((int64_t)(p - NULL))/sizeof(DType) - last; - // last = ((int64_t)(p - NULL))/sizeof(DType); - // cudnnDataType_t t; - // cudnnTensorFormat_t f; - // int ndim = 5; - // int dims[5] = {0, 0, 0, 0, 0}; - // CHECK_EQ(cudnnGetFilterNdDescriptor(m_desc, ndim, &t, &f, &ndim, &dims[0]), - // CUDNN_STATUS_SUCCESS); - // LOG(INFO) << "w: " << i << " " << j << " " << ((int64_t)(p - NULL))/sizeof(DType); - // for (int i = 0; i < ndim; ++i) LOG(INFO) << dims[i]; - // } - // } - - // for (int i = 0; i < param_.num_layers*(param_.bidirectional?2:1); ++i) { - // for (int j = 0; j < n; ++j) { - // CHECK_EQ(cudnnGetRNNLinLayerBiasParams(s->dnn_handle_, rnn_desc_, i, x_desc_vec_[0], - // w_desc_, 0, j, m_desc, (void**)&p), CUDNN_STATUS_SUCCESS); - // LOG(INFO) << ((int64_t)(p - NULL))/sizeof(DType) - last; - // last = ((int64_t)(p - NULL))/sizeof(DType); - // LOG(INFO) << "b: " << i << " " << j << " " << ((int64_t)(p - NULL))/sizeof(DType); - // } - // } - } - } - - cudnnDataType_t dtype_; - bool init_cudnn_; - cudnnRNNDescriptor_t rnn_desc_; - cudnnRNNMode_t mode_; - cudnnDirectionMode_t direction_; - cudnnRNNInputMode_t input_mode_; - cudnnDropoutDescriptor_t dropout_desc_; - Storage::Handle dropout_states_, reserve_space_; - uint64_t seed_ = 17 + rand() % 4096; // NOLINT(runtime/threadsafe_fn) - size_t workspace_byte_, reserve_space_byte_, dropout_byte_; - int workspace_size_, dropout_size_; - std::vector x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_; - #if USE_CUDNN_LSTM_PROJ - cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, dy_data_desc_; - #endif - cudnnTensorDescriptor_t hx_desc_, cx_desc_; - cudnnTensorDescriptor_t hy_desc_, cy_desc_; - cudnnTensorDescriptor_t dhx_desc_, dcx_desc_; - cudnnTensorDescriptor_t dhy_desc_, dcy_desc_; - - cudnnFilterDescriptor_t w_desc_, dw_desc_; - // Allow TensorCore algo policy - bool cudnn_tensor_core_; - - #if CUDNN_MAJOR >= 5 - cudnnTensorFormat_t format_; - #endif - RNNParam param_; -}; -#endif // __CUDACC__ && CUDNN -} // namespace op -} // namespace mxnet - -#endif // MXNET_OPERATOR_CUDNN_RNN_INL_H_ diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 71ad331786ae..37f21ce6d126 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -26,6 +26,9 @@ #ifndef MXNET_OPERATOR_RNN_INL_H_ #define MXNET_OPERATOR_RNN_INL_H_ +#define MXNET_USE_CUDNN_RNN MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 +#define USE_CUDNN_LSTM_PROJ MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200 + #include #include #include @@ -35,6 +38,7 @@ #include #include #include +#include #include "./math.h" #include "./math_functions-inl.h" 
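// Note on the guard macros added above: both are object-like macros that expand
// textually inside #if, so "#if USE_CUDNN_LSTM_PROJ" is evaluated by the
// preprocessor as "#if MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200", which is
// exactly how the rest of this file consumes them. If a negated test were ever
// needed, a parenthesized definition would be the safer form (hypothetical
// alternative, not what this patch does):
//   #define USE_CUDNN_LSTM_PROJ ((MXNET_USE_CUDNN == 1) && (CUDNN_VERSION >= 7200))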
#include "./operator_common.h" @@ -47,7 +51,7 @@ namespace rnn_enum { enum RNNOpInputs {kData, kParams, kState, kStateCell}; enum RNNOpOutputs {kOut, kStateOut, kStateCellOut}; enum RNNModeType {kRnnRelu, kRnnTanh, kLstm, kGru}; - enum RNNOpResource {kTempSpace}; + enum RNNOpResource {kCuDNNDropoutDescSpace}; } inline int GetRnnParamSize(int num_layer, @@ -160,9 +164,8 @@ struct RNNParam : public dmlc::Parameter { uint32_t num_layers; bool bidirectional, state_outputs; int mode; - float p, pkeep_; + float p; int seq_length_, batch_size_, input_size_; - bool lstm_q_; // whether type is lstm dmlc::optional projection_size; dmlc::optional lstm_state_clip_min, lstm_state_clip_max; bool lstm_state_clip_nan; @@ -212,7 +215,6 @@ struct RNNParam : public dmlc::Parameter { } }; - /** * @params: ws: Temp workspace for gemm's output storage. * rs: Reserve space of forward intermediate data used for training. @@ -236,6 +238,7 @@ struct RNNParam : public dmlc::Parameter { * hy's shape is [num_layers, batch_size, state_size] * cy_ptr: Only used in lstm mode. pointer of tensor cy containing the cell state * for t=seq_length. cy' shape is [num_layers, batch_size, state_size] + * dropout: should be 0 <= dropout < 1 * mode: Specifies the type of RNN to compute. */ template @@ -376,58 +379,189 @@ void RNNBackward(DType* ws, } } -template -class RNNOp : public Operator{ +template +class RNNOp { public: - explicit RNNOp(RNNParam p) - :param_(p), init_space_(false), reserve_space_size_(0) { + RNNParam param_; + Context ctx_; + explicit RNNOp(RNNParam param, Context ctx) { + this->param_ = param; + this->ctx_ = ctx; + #if MXNET_USE_CUDNN_RNN + init_cudnn_ = false; + dtype_ = mshadow::DataType::kCudnnFlag; + // TensorCore algos only allowed on fp16-I/O convolutions if permitted by the global policy. + // No tests in place for fp16 RNNs, so leave TensorCore disabled for now. 
+    cudnn_tensor_core_ = false;
+    // When fp16 RNN tests are introduced, we can enable TensorCore as follows:
+//    cudnn_tensor_core =
+//      mshadow::DataType::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore();
+    // Defaults
+    input_mode_ = CUDNN_LINEAR_INPUT;  // Don't support this yet
+    // RNN Mode
+    switch (param_.mode) {
+      case rnn_enum::kRnnRelu:
+        mode_ = CUDNN_RNN_RELU;
+        break;
+      case rnn_enum::kRnnTanh:
+        mode_ = CUDNN_RNN_TANH;
+        break;
+      case rnn_enum::kLstm:
+        mode_ = CUDNN_LSTM;
+        break;
+      case rnn_enum::kGru:
+        mode_ = CUDNN_GRU;
+        break;
+      default:
+        LOG(FATAL) << "Not implemented";
+    }
+#if USE_CUDNN_LSTM_PROJ
     if (param_.projection_size.has_value()) {
-      LOG(FATAL) << "hidden layer projection is only supported for GPU with CuDNN later than 7.1.1";
+      CHECK_EQ(param_.mode, rnn_enum::kLstm)
+        << "Projection is only supported for LSTM.";
+      CHECK_GE(param_.state_size, param_.projection_size.value())
+        << "State size must be larger than projection size.";
     }
+#else
+    CHECK(!param_.projection_size.has_value())
+      << "Projection is only supported for LSTM with CuDNN version later than 7.1.1.";
+#endif
+#if USE_CUDNN_LSTM_PROJ
     if (param_.lstm_state_clip_min.has_value()
         || param_.lstm_state_clip_max.has_value()) {
-      LOG(FATAL) << "LSTM state clipping is only supported for GPU with CuDNN later than 7.2.1";
+      CHECK_EQ(param_.mode, rnn_enum::kLstm)
+        << "State clipping is only supported for LSTM.";
+      CHECK(param_.lstm_state_clip_min.has_value() && param_.lstm_state_clip_max.has_value())
+        << "lstm_state_clip_min and lstm_state_clip_max must be specified together.";
+      CHECK_GE(param_.lstm_state_clip_max.value(), param_.lstm_state_clip_min.value())
+        << "lstm_state_clip_max must be greater than or equal to lstm_state_clip_min";
+    }
+#else
+    CHECK(!param_.lstm_state_clip_min.has_value()
+          && !param_.lstm_state_clip_max.has_value())
+      << "State clipping is only supported for LSTM with CuDNN version later than 7.2.1.";
+#endif
+    // RNN Direction
+    direction_ = param_.bidirectional ?
CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; + // Create descriptors + CUDNN_CALL(cudnnCreateTensorDescriptor(&hx_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&cx_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&hy_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&cy_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&dhx_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&dcx_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&dhy_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&dcy_desc_)); + + CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc_)); + CUDNN_CALL(cudnnCreateFilterDescriptor(&dw_desc_)); + + CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_)); + CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_)); + + #if USE_CUDNN_LSTM_PROJ + CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_)); + CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_)); + CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_)); + CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dy_data_desc_)); + #endif + #else + if (ctx_.dev_type == kGPU) { + LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment."; + } + #endif + + if (ctx_.dev_type == kCPU) { + this->init_space_ = false; + this->temp_init_space_ = false; + this->reserve_cpu_space_size_ = 0; + this->temp_cpu_space_size_ = 0; + + if (param_.projection_size.has_value()) { + LOG(FATAL) << + "hidden layer projection is only supported for GPU with CuDNN later than 7.1.1"; + } + if (param_.lstm_state_clip_min.has_value() + || param_.lstm_state_clip_max.has_value()) { + LOG(FATAL) << "LSTM state clipping is only supported for GPU with CuDNN later than 7.2.1"; + } } } ~RNNOp() { - if (init_space_) { + #if MXNET_USE_CUDNN_RNN + CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(hy_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(cy_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(dhx_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(dcx_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(dhy_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(dcy_desc_)); + + CUDNN_CALL(cudnnDestroyFilterDescriptor(w_desc_)); + CUDNN_CALL(cudnnDestroyFilterDescriptor(dw_desc_)); + CUDNN_CALL(cudnnDestroyRNNDescriptor(rnn_desc_)); + CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_)); + + if (init_cudnn_) { + for (size_t i = 0; i < x_desc_vec_.size(); ++i) { + CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc_vec_[i])); + CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc_vec_[i])); + CUDNN_CALL(cudnnDestroyTensorDescriptor(dx_desc_vec_[i])); + CUDNN_CALL(cudnnDestroyTensorDescriptor(dy_desc_vec_[i])); + } + init_cudnn_ = false; + Storage::Get()->Free(temp_space_); Storage::Get()->Free(reserve_space_); - init_space_ = false; + } + #if USE_CUDNN_LSTM_PROJ + CUDNN_CALL(cudnnDestroyRNNDataDescriptor(x_data_desc_)); + CUDNN_CALL(cudnnDestroyRNNDataDescriptor(y_data_desc_)); + CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dx_data_desc_)); + CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dy_data_desc_)); + #endif + #endif + + if (ctx_.dev_type == kCPU) { + if (init_space_) { + Storage::Get()->Free(reserve_cpu_space_); + init_space_ = false; + } + if (temp_init_space_) { + Storage::Get()->Free(temp_cpu_space_); + temp_init_space_ = false; + } } } - virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { + void Forward(const OpContext &ctx, const std::vector &in_data, + const 
std::vector &req, + const std::vector &out_data) { using namespace mshadow; using namespace mshadow::expr; CHECK(param_.p >= 0.0f && param_.p < 1.0f) << "unsupported dropout value, should be 0 <= dropout < 1"; - - size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3; - size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2; - if (!param_.state_outputs) { - out_expected = 1; + size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3; + // kOut + size_t num_outputs = 1; + if (param_.state_outputs) { + // kOut, kStateOut, kStateCellOut + num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2; } - CHECK_EQ(in_data.size(), in_expected); - CHECK_EQ(out_data.size(), out_expected); - Stream *s = ctx.get_stream(); - // get input + output tensor - Tensor x = in_data[rnn_enum::kData].get(s); - Tensor w = in_data[rnn_enum::kParams].get(s); - Tensor hx = in_data[rnn_enum::kState].get(s); - Tensor y = out_data[rnn_enum::kOut].get(s); - CHECK(x.CheckContiguous()); - CHECK(w.CheckContiguous()); - CHECK(hx.CheckContiguous()); - CHECK(y.CheckContiguous()); + + CHECK_EQ(in_data.size(), num_inputs); + CHECK_EQ(out_data.size(), num_outputs); + Stream *s = ctx.get_stream(); + // get input + output tensors + Tensor x = in_data[rnn_enum::kData].get(s); + Tensor w = in_data[rnn_enum::kParams].get(s); + Tensor hx = in_data[rnn_enum::kState].get(s); + Tensor y = out_data[rnn_enum::kOut].get(s); + param_.seq_length_ = x.shape_[0]; param_.batch_size_ = x.shape_[1]; param_.input_size_ = x.shape_[2]; - const int direction = param_.bidirectional ? 2 : 1; const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode); DType* b_ptr = w.dptr_ + w.shape_[0] - bsize; @@ -438,124 +572,308 @@ class RNNOp : public Operator{ } DType* cx_ptr = NULL; DType* cy_ptr = NULL; - - if (param_.mode == rnn_enum::kLstm) { - cx_ptr = in_data[rnn_enum::kStateCell].dptr(); - if (param_.state_outputs) { - cy_ptr = out_data[rnn_enum::kStateCellOut].dptr(); - } + if (param_.mode == rnn_enum::kLstm) + cx_ptr = (in_data[rnn_enum::kStateCell].get(s)).dptr_; + if (param_.mode == rnn_enum::kLstm && param_.state_outputs) + cy_ptr = (out_data[rnn_enum::kStateCellOut].get(s)).dptr_; + + CHECK_EQ(x.CheckContiguous(), true); + CHECK_EQ(w.CheckContiguous(), true); + CHECK_EQ(hx.CheckContiguous(), true); + CHECK_EQ(y.CheckContiguous(), true); + + #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__) + if (!init_cudnn_) { + Init(ctx, s, in_data, out_data); } - // allocate temp space - const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, - param_.state_size, direction, param_.mode); - Tensor workspace = ctx.requested[rnn_enum::kTempSpace] - .get_space_typed(Shape1(workspace_size), s); + #if USE_CUDNN_LSTM_PROJ + std::vector seqLengthArray(param_.batch_size_, param_.seq_length_); + CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_, + dtype_, + CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + seqLengthArray.data(), + nullptr)); + int out_size = + (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size; + out_size = (param_.bidirectional) ? 
(out_size * 2) : out_size; + CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_, + dtype_, + CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED, + param_.seq_length_, + param_.batch_size_, + out_size, + seqLengthArray.data(), + nullptr)); + if (ctx.is_train) { + CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_, + dtype_, + CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + seqLengthArray.data(), + nullptr)); + CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_, + dtype_, + CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED, + param_.seq_length_, + param_.batch_size_, + out_size, + seqLengthArray.data(), + nullptr)); + } + #endif + + #if USE_CUDNN_LSTM_PROJ + bool clip_state = param_.lstm_state_clip_min.has_value(); + bool clip_nan = param_.lstm_state_clip_nan; + CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_, + rnn_desc_, + clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE, + clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN, + clip_state ? param_.lstm_state_clip_min.value() : 0.0, + clip_state ? param_.lstm_state_clip_max.value() : 0.0)); + #endif if (ctx.is_train) { - const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction, - param_.seq_length_, param_.batch_size_, - param_.state_size, param_.mode); - if (init_space_ && reserve_space_size_ < r_size) { - Storage::Get()->Free(reserve_space_); - init_space_ = false; + #if USE_CUDNN_LSTM_PROJ + CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_, + rnn_desc_, + x_data_desc_, + x.dptr_, + hx_desc_, + hx.dptr_, + cx_desc_, + cx_ptr, + w_desc_, + w.dptr_, + y_data_desc_, + y.dptr_, + hy_desc_, + hy_ptr, + cy_desc_, + cy_ptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + temp_space_.dptr, + workspace_byte_, + reserve_space_.dptr, + reserve_space_byte_)); + #else + CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_, + rnn_desc_, + param_.seq_length_, + x_desc_vec_.data(), + x.dptr_, + hx_desc_, + hx.dptr_, + cx_desc_, + cx_ptr, + w_desc_, + w.dptr_, + y_desc_vec_.data(), + y.dptr_, + hy_desc_, + hy_ptr, + cy_desc_, + cy_ptr, + temp_space_.dptr, + workspace_byte_, + reserve_space_.dptr, + reserve_space_byte_)); + #endif + } else { + #if USE_CUDNN_LSTM_PROJ + CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_, + rnn_desc_, + x_data_desc_, + x.dptr_, + hx_desc_, + hx.dptr_, + cx_desc_, + cx_ptr, + w_desc_, + w.dptr_, + y_data_desc_, + y.dptr_, + hy_desc_, + hy_ptr, + cy_desc_, + cy_ptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + temp_space_.dptr, + workspace_byte_)); + #else + CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_, + rnn_desc_, + param_.seq_length_, + x_desc_vec_.data(), + x.dptr_, + hx_desc_, + hx.dptr_, + cx_desc_, + cx_ptr, + w_desc_, + w.dptr_, + y_desc_vec_.data(), + y.dptr_, + hy_desc_, + hy_ptr, + cy_desc_, + cy_ptr, + temp_space_.dptr, + workspace_byte_)); + #endif + } + #endif + + if (ctx_.dev_type == kCPU) { + // allocate temp space + const size_t work_cpu_space_size = + GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, + param_.state_size, direction, param_.mode); + if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) { + Storage::Get()->Free(temp_cpu_space_); + temp_init_space_ = false; } - - if (!init_space_) { - reserve_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU()); - reserve_space_size_ = r_size; - init_space_ = true; + if (!temp_init_space_) { + temp_cpu_space_ = Storage::Get()->Alloc + (work_cpu_space_size * 
sizeof(DType), Context::CPU()); + temp_cpu_space_size_ = work_cpu_space_size; + temp_init_space_ = true; + } + DType* work_cpu_space = static_cast(temp_cpu_space_.dptr); + if (ctx.is_train) { + const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction, + param_.seq_length_, param_.batch_size_, + param_.state_size, param_.mode); + if (init_space_ && reserve_cpu_space_size_ < r_size) { + Storage::Get()->Free(reserve_cpu_space_); + init_space_ = false; + } + if (!init_space_) { + reserve_cpu_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU()); + reserve_cpu_space_size_ = r_size; + init_space_ = true; + } + + DType* reserve_space_ptr = static_cast(reserve_cpu_space_.dptr); + + RNNForwardTraining(work_cpu_space, + reserve_space_ptr, + param_.state_outputs, + param_.num_layers, + direction, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + param_.state_size, + x.dptr_, + hx.dptr_, + cx_ptr, + w.dptr_, + b_ptr, + y.dptr_, + hy_ptr, + cy_ptr, + param_.p, + param_.mode); + } else { + RNNForwardInference(work_cpu_space, + param_.state_outputs, + param_.num_layers, + direction, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + param_.state_size, + x.dptr_, + hx.dptr_, + cx_ptr, + w.dptr_, + b_ptr, + y.dptr_, + hy_ptr, + cy_ptr, + param_.mode); } - - DType* reserve_space_ptr = static_cast(reserve_space_.dptr); - RNNForwardTraining(workspace.dptr_, - reserve_space_ptr, - param_.state_outputs, - param_.num_layers, - direction, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - param_.state_size, - x.dptr_, - hx.dptr_, - cx_ptr, - w.dptr_, - b_ptr, - y.dptr_, - hy_ptr, - cy_ptr, - param_.p, - param_.mode); - } else { - RNNForwardInference(workspace.dptr_, - param_.state_outputs, - param_.num_layers, - direction, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - param_.state_size, - x.dptr_, - hx.dptr_, - cx_ptr, - w.dptr_, - b_ptr, - y.dptr_, - hy_ptr, - cy_ptr, - param_.mode); } } - virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_args) { + void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad) { using namespace mshadow; using namespace mshadow::expr; CHECK(param_.p >= 0.0f && param_.p < 1.0f) << "unsupported dropout value, should be 0 <= dropout < 1"; - size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3; - size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2; - if (!param_.state_outputs) { - out_expected = 1; + size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3; + // kOut + size_t num_outputs = 1; + if (param_.state_outputs) { + // kOut, kStateOut, kStateCellOut + num_outputs = (param_.mode == rnn_enum::kLstm) ? 
3 : 2; } - CHECK_EQ(in_data.size(), in_expected); - CHECK_EQ(out_data.size(), out_expected); - CHECK_EQ(in_grad.size(), in_expected); - CHECK_EQ(out_grad.size(), out_expected); - CHECK_EQ(req.size(), in_expected); + + CHECK_EQ(in_data.size(), num_inputs); + CHECK_EQ(out_data.size(), num_outputs); + CHECK_EQ(in_grad.size(), num_inputs); + CHECK_EQ(out_grad.size(), num_outputs); + CHECK_EQ(req.size(), num_inputs); CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data"; CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state"; - mshadow::Stream *s = ctx.get_stream(); + Stream *s = ctx.get_stream(); // get input + output tensors - Tensor x = in_data[rnn_enum::kData].get(s); - Tensor w = in_data[rnn_enum::kParams].get(s); - Tensor hx = in_data[rnn_enum::kState].get(s); - Tensor y = out_data[rnn_enum::kOut].get(s); - Tensor dx = in_grad[rnn_enum::kData].get(s); - Tensor dw = in_grad[rnn_enum::kParams].get(s); - Tensor dhx = in_grad[rnn_enum::kState].get(s); - Tensor dy = out_grad[rnn_enum::kOut].get(s); - CHECK(x.CheckContiguous()); - CHECK(w.CheckContiguous()); - CHECK(hx.CheckContiguous()); - CHECK(y.CheckContiguous()); - CHECK(dx.CheckContiguous()); - CHECK(dw.CheckContiguous()); - CHECK(dhx.CheckContiguous()); - CHECK(dy.CheckContiguous()); + Tensor x = in_data[rnn_enum::kData].get(s); + Tensor dx = in_grad[rnn_enum::kData].get(s); + Tensor w = in_data[rnn_enum::kParams].get(s); + Tensor dw = in_grad[rnn_enum::kParams].get(s); + Tensor hx = in_data[rnn_enum::kState].get(s); + Tensor dhx = in_grad[rnn_enum::kState].get(s); + Tensor y = out_data[rnn_enum::kOut].get(s); + Tensor dy = out_grad[rnn_enum::kOut].get(s); + + CHECK_EQ(x.CheckContiguous(), true); + CHECK_EQ(w.CheckContiguous(), true); + CHECK_EQ(dw.CheckContiguous(), true); + CHECK_EQ(hx.CheckContiguous(), true); + CHECK_EQ(dhx.CheckContiguous(), true); + CHECK_EQ(y.CheckContiguous(), true); + CHECK_EQ(dy.CheckContiguous(), true); + CHECK_EQ(dx.CheckContiguous(), true); + + if (req[rnn_enum::kParams] != kAddTo) { + dw = mshadow::expr::ScalarExp(0.0f); + } + param_.seq_length_ = x.shape_[0]; param_.batch_size_ = x.shape_[1]; param_.input_size_ = x.shape_[2]; const int direction = param_.bidirectional ? 
2 : 1; const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode); + DType* db_ptr = dw.dptr_ + w.shape_[0] - bsize; DType * dhy_ptr = NULL; @@ -563,260 +881,582 @@ class RNNOp : public Operator{ dhy_ptr = out_grad[rnn_enum::kStateOut].dptr(); } - DType * cx_ptr = NULL; - DType * dcx_ptr = NULL; - DType * dcy_ptr = NULL; + DType* dcx_ptr = NULL; + DType* dcy_ptr = NULL; + DType* cx_ptr = NULL; if (param_.mode == rnn_enum::kLstm) { CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell"; - cx_ptr = in_data[rnn_enum::kStateCell].dptr(); - dcx_ptr = in_grad[rnn_enum::kStateCell].dptr(); - if (param_.state_outputs) { - dcy_ptr = out_grad[rnn_enum::kStateCellOut].dptr(); - } + cx_ptr = (in_data[rnn_enum::kStateCell].get(s)).dptr_; + dcx_ptr = (in_grad[rnn_enum::kStateCell].get(s)).dptr_; } + if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs) + dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get(s)).dptr_; - // allocate temp space - const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, - param_.state_size, direction, param_.mode); - Tensor workspace = ctx.requested[rnn_enum::kTempSpace] - .get_space_typed(Shape1(workspace_size), s); - - size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction, - param_.seq_length_, param_.batch_size_, - param_.state_size, param_.mode); - if (!init_space_ || reserve_space_size_ != r_size) { - LOG(FATAL) << "Check forward init error"; + #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__) + if (!init_cudnn_) { + Init(ctx, s, in_data, out_data); } - DType* reserve_space_ptr = static_cast(reserve_space_.dptr); - RNNBackward(workspace.dptr_, - reserve_space_ptr, - param_.num_layers, - direction, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - param_.state_size, - x.dptr_, - hx.dptr_, - cx_ptr, - w.dptr_, - y.dptr_, - dy.dptr_, - dhy_ptr, - dcy_ptr, - dx.dptr_, - dhx.dptr_, - dcx_ptr, - dw.dptr_, - db_ptr, - req[rnn_enum::kData], - req[rnn_enum::kParams], - req[rnn_enum::kState], - // State cell should be present for LSTMs, but is absent for other RNNs. - param_.mode == rnn_enum::kLstm ? 
req[rnn_enum::kStateCell] : kNullOp, - param_.p, - param_.mode); - } - - private: - RNNParam param_; - bool init_space_; - size_t reserve_space_size_; - Storage::Handle reserve_space_; -}; // class RNNOp + #if USE_CUDNN_LSTM_PROJ + CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_, + rnn_desc_, + y_data_desc_, + y.dptr_, + dy_data_desc_, + dy.dptr_, + nullptr, + nullptr, + dhy_desc_, + dhy_ptr, + dcy_desc_, + dcy_ptr, + w_desc_, + w.dptr_, + hx_desc_, + hx.dptr_, + cx_desc_, + cx_ptr, + dx_data_desc_, + dx.dptr_, + dhx_desc_, + dhx.dptr_, + dcx_desc_, + dcx_ptr, + nullptr, + nullptr, + temp_space_.dptr, + workspace_byte_, + reserve_space_.dptr, + reserve_space_byte_)); + CUDNN_CALL(cudnnRNNBackwardWeightsEx(s->dnn_handle_, + rnn_desc_, + x_data_desc_, + x.dptr_, + hx_desc_, + hx.dptr_, + y_data_desc_, + y.dptr_, + temp_space_.dptr, + workspace_byte_, + dw_desc_, + dw.dptr_, + reserve_space_.dptr, + reserve_space_byte_)); + #else + CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_, + rnn_desc_, + param_.seq_length_, + y_desc_vec_.data(), + y.dptr_, + dy_desc_vec_.data(), + dy.dptr_, + dhy_desc_, + dhy_ptr, + dcy_desc_, + dcy_ptr, + w_desc_, + w.dptr_, + hx_desc_, + hx.dptr_, + cx_desc_, + cx_ptr, + dx_desc_vec_.data(), + dx.dptr_, + dhx_desc_, + dhx.dptr_, + dcx_desc_, + dcx_ptr, + temp_space_.dptr, + workspace_byte_, + reserve_space_.dptr, + reserve_space_byte_)); + CUDNN_CALL(cudnnRNNBackwardWeights(s->dnn_handle_, + rnn_desc_, + param_.seq_length_, + x_desc_vec_.data(), + x.dptr_, + hx_desc_, + hx.dptr_, + y_desc_vec_.data(), + y.dptr_, + temp_space_.dptr, + workspace_byte_, + dw_desc_, + dw.dptr_, + reserve_space_.dptr, + reserve_space_byte_)); + #endif + #endif + + if (ctx_.dev_type == kCPU) { + // allocate temp space + const size_t work_cpu_space_size = + GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, + param_.state_size, direction, param_.mode); + if (!temp_init_space_ || temp_cpu_space_size_ != work_cpu_space_size) { + LOG(FATAL) << "Check temp init error"; + } + DType* work_cpu_space = static_cast(temp_cpu_space_.dptr); + size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction, + param_.seq_length_, param_.batch_size_, + param_.state_size, param_.mode); -template -Operator* CreateOp(RNNParam param, int dtype); + if (!init_space_ || reserve_cpu_space_size_ != r_size) { + LOG(FATAL) << "Check forward init error"; + } -#if DMLC_USE_CXX11 -class RNNProp : public OperatorProperty { - public: - std::vector ListArguments() const override { - if (param_.mode == rnn_enum::kLstm) { - return {"data", "parameters", "state", "state_cell"}; - } else { - return {"data", "parameters", "state"}; + DType* reserve_space_ptr = static_cast(reserve_cpu_space_.dptr); + RNNBackward(work_cpu_space, + reserve_space_ptr, + param_.num_layers, + direction, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + param_.state_size, + x.dptr_, + hx.dptr_, + cx_ptr, + w.dptr_, + y.dptr_, + dy.dptr_, + dhy_ptr, + dcy_ptr, + dx.dptr_, + dhx.dptr_, + dcx_ptr, + dw.dptr_, + db_ptr, + req[rnn_enum::kData], + req[rnn_enum::kParams], + req[rnn_enum::kState], + // State cell should be present for LSTMs, but is absent for other RNNs. + param_.mode == rnn_enum::kLstm ? 
req[rnn_enum::kStateCell] : kNullOp, + param_.p, + param_.mode); } } - std::vector ListOutputs() const override { - std::vector outputs = {"output"}; - if (!param_.state_outputs) - return outputs; - else - outputs.emplace_back("state"); - if (param_.mode == rnn_enum::kLstm) - outputs.emplace_back("state_cell"); - return outputs; - } - - int NumOutputs() const override { - int mode_num = (param_.mode == rnn_enum::kLstm) ? 2 : 1; - int num_outputs = param_.state_outputs ? (mode_num + 1) : 1; - return num_outputs; - } - - void Init(const std::vector >& kwargs) override { - param_.Init(kwargs); - } - std::map GetParams() const override { - return param_.__DICT__(); - } - - bool InferShape(mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape, - mxnet::ShapeVector *aux_shape) const override { + private: + inline void Init(const OpContext &ctx, + mshadow::Stream *s, + const std::vector &in_data, + const std::vector &out_data) { using namespace mshadow; - if (param_.mode == rnn_enum::kLstm) { - CHECK_EQ(in_shape->size(), 4U) << "Input:[data, parameters, state, cell_state]"; - } else { - CHECK_EQ(in_shape->size(), 3U) << "Input:[data, parameters, state]"; - } - const mxnet::TShape &dshape = (*in_shape)[rnn_enum::kData]; - if (dshape.ndim() == 0) return false; - CHECK_EQ(dshape.ndim(), 3U) \ - << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]"; - // data: [sequence len, batch, input dimension] - int batch_size = dshape[1]; - int input_size = dshape[2]; - int numDirections = param_.bidirectional ? 2 : 1; - int total_layers = numDirections * param_.num_layers; // double for bidirectional - int layer_size = (param_.projection_size.has_value()) ? - param_.projection_size.value() : param_.state_size; - SHAPE_ASSIGN_CHECK(*in_shape, - rnn_enum::kState, - Shape3(total_layers, batch_size, layer_size)); - if (param_.mode == rnn_enum::kLstm) - SHAPE_ASSIGN_CHECK(*in_shape, - rnn_enum::kStateCell, - Shape3(total_layers, batch_size, param_.state_size)); - - // calculate parameter vector length - int param_size = GetRnnParamSize(param_.num_layers, - input_size, - param_.state_size, - numDirections, - param_.mode, - param_.projection_size); - SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size)); - - out_shape->clear(); - // output: [sequence len, batch, output size] - mxnet::TShape oshape = dshape; - if (param_.projection_size.has_value()) { - oshape[2] = numDirections * param_.projection_size.value(); - } else { - oshape[2] = numDirections * param_.state_size; + size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3; + // kOut + size_t num_outputs = 1; + if (param_.state_outputs) { + // kOut, kStateOut, kStateCellOut + num_outputs = (param_.mode == rnn_enum::kLstm) ? 
3 : 2; } - out_shape->push_back(oshape); - if (!param_.state_outputs) { - return true; - } else { - // outStateShape: [layer_num, batch, state size] - mxnet::TShape outStateShape = dshape; - outStateShape[0] = total_layers; - outStateShape[1] = batch_size; - if (param_.projection_size.has_value()) { - outStateShape[2] = param_.projection_size.value(); - } else { - outStateShape[2] = param_.state_size; + + CHECK_EQ(in_data.size(), num_inputs); + CHECK_EQ(out_data.size(), num_outputs); + + #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__) + #if CUDNN_MAJOR >= 5 + format_ = CUDNN_TENSOR_NCHW; + #endif + + if (!init_cudnn_) { + init_cudnn_ = true; + // get input + output tensors + Tensor x = in_data[rnn_enum::kData].get(s); + Tensor w = in_data[rnn_enum::kParams].get(s); + param_.seq_length_ = x.shape_[0]; + param_.batch_size_ = x.shape_[1]; + param_.input_size_ = x.shape_[2]; + + // Tensor Descriptors + std::vector x_vec(param_.seq_length_); + std::vector y_vec(param_.seq_length_); + std::vector dx_vec(param_.seq_length_); + std::vector dy_vec(param_.seq_length_); + int dimA[3]; + int strideA[3]; + for (int i = 0; i < param_.seq_length_; i++) { + CUDNN_CALL(cudnnCreateTensorDescriptor(&x_vec[i])); + CUDNN_CALL(cudnnCreateTensorDescriptor(&y_vec[i])); + CUDNN_CALL(cudnnCreateTensorDescriptor(&dx_vec[i])); + CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_vec[i])); + + dimA[0] = param_.batch_size_; + dimA[1] = param_.input_size_; + dimA[2] = 1; + strideA[0] = dimA[2] * dimA[1]; + strideA[1] = dimA[2]; + strideA[2] = 1; + + CUDNN_CALL(cudnnSetTensorNdDescriptor(x_vec[i], + dtype_, + 3, + dimA, + strideA)); + CUDNN_CALL(cudnnSetTensorNdDescriptor(dx_vec[i], + dtype_, + 3, + dimA, + strideA)); + dimA[0] = param_.batch_size_; + dimA[1] = param_.bidirectional ? param_.state_size * 2 : param_.state_size; + dimA[2] = 1; + strideA[0] = dimA[2] * dimA[1]; + strideA[1] = dimA[2]; + strideA[2] = 1; + + CUDNN_CALL(cudnnSetTensorNdDescriptor(y_vec[i], + dtype_, + 3, + dimA, + strideA)); + CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_vec[i], + dtype_, + 3, + dimA, + strideA)); } - out_shape->push_back(outStateShape); - // Deal with lstm cell state - if (param_.mode == rnn_enum::kLstm) { - mxnet::TShape cellStateShape = dshape; - cellStateShape[0] = total_layers; - cellStateShape[1] = batch_size; - cellStateShape[2] = param_.state_size; - out_shape->push_back(cellStateShape); + x_desc_vec_ = x_vec; + y_desc_vec_ = y_vec; + dx_desc_vec_ = dx_vec; + dy_desc_vec_ = dy_vec; + + // set the state tensors + dimA[0] = param_.num_layers * (param_.bidirectional ? 2 : 1); + dimA[1] = param_.batch_size_; + dimA[2] = param_.state_size; + strideA[0] = dimA[2] * dimA[1]; + strideA[1] = dimA[2]; + strideA[2] = 1; + #if USE_CUDNN_LSTM_PROJ + int dimB[3]; + int strideB[3]; + dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1); + dimB[1] = param_.batch_size_; + dimB[2] = param_.projection_size.has_value() ? 
+ param_.projection_size.value() : param_.state_size; + strideB[0] = dimB[2] * dimB[1]; + strideB[1] = dimB[2]; + strideB[2] = 1; + #endif + #if USE_CUDNN_LSTM_PROJ + CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_, + dtype_, + 3, + dimB, + strideB)); + #else + CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_, + dtype_, + 3, + dimA, + strideA)); + #endif + CUDNN_CALL(cudnnSetTensorNdDescriptor(cx_desc_, + dtype_, + 3, + dimA, + strideA)); + #if USE_CUDNN_LSTM_PROJ + CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_, + dtype_, + 3, + dimB, + strideB)); + #else + CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_, + dtype_, + 3, + dimA, + strideA)); + #endif + CUDNN_CALL(cudnnSetTensorNdDescriptor(cy_desc_, + dtype_, + 3, + dimA, + strideA)); + #if USE_CUDNN_LSTM_PROJ + CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_, + dtype_, + 3, + dimB, + strideB)); + #else + CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_, + dtype_, + 3, + dimA, + strideA)); + #endif + CUDNN_CALL(cudnnSetTensorNdDescriptor(dcx_desc_, + dtype_, + 3, + dimA, + strideA)); + #if USE_CUDNN_LSTM_PROJ + CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_, + dtype_, + 3, + dimB, + strideB)); + #else + CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_, + dtype_, + 3, + dimA, + strideA)); + #endif + CUDNN_CALL(cudnnSetTensorNdDescriptor(dcy_desc_, + dtype_, + 3, + dimA, + strideA)); + + // Create Dropout descriptors + DType* dropout_states_ = NULL; + if (param_.p > 0) { + ctx.requested[rnn_enum::kCuDNNDropoutDescSpace].get_cudnn_dropout_desc + (&dropout_desc_, s, 1.0f - param_.p, seed_); + } else { + dropout_byte_ = 0; } - return true; - } - } - bool InferType(std::vector *in_type, - std::vector *out_type, - std::vector *aux_type) const override { - CHECK_GE(in_type->size(), 1U); - int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; - for (size_t i = 0; i < in_type->size(); ++i) { - if ((*in_type)[i] == -1) { - (*in_type)[i] = dtype; - } else { - UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]); + CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_, s->dnn_handle_, + param_.p, // discard probability + dropout_states_, dropout_byte_, + seed_)); + + // RNN descriptors + #if CUDNN_MAJOR >= 6 + cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD; + CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_, + rnn_desc_, + param_.state_size, + param_.num_layers, + dropout_desc_, + input_mode_, + direction_, + mode_, + rnn_algo, + dtype_)); + #else + CUDNN_CALL(cudnnSetRNNDescriptor(rnn_desc_, + param_.state_size, + param_.num_layers, + dropout_desc_, + input_mode_, + direction_, + mode_, + dtype_)); + #endif + #if CUDNN_MAJOR >= 7 + cudnnMathType_t math_type = CUDNN_DEFAULT_MATH; + if (cudnn_tensor_core_ && rnn_algo == CUDNN_RNN_ALGO_STANDARD) { + math_type = CUDNN_TENSOR_OP_MATH; + } + #if CUDNN_VERSION >= 7200 + if (GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion() && + (DataType::kFlag != kFloat16)) + math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION; + #endif + CUDNN_CALL(cudnnSetRNNMatrixMathType(rnn_desc_, math_type)); + #endif + #if USE_CUDNN_LSTM_PROJ + if (param_.projection_size.has_value()) { + CUDNN_CALL(cudnnSetRNNProjectionLayers(s->dnn_handle_, + rnn_desc_, + param_.projection_size.value(), + 0)); } + #endif + // Get temp space sizes + CUDNN_CALL(cudnnGetRNNWorkspaceSize(s->dnn_handle_, + rnn_desc_, + param_.seq_length_, + x_desc_vec_.data(), + &workspace_byte_)); + CUDNN_CALL(cudnnGetRNNTrainingReserveSize(s->dnn_handle_, + rnn_desc_, + param_.seq_length_, + x_desc_vec_.data(), + 
&reserve_space_byte_)); + workspace_size_ = workspace_byte_ / sizeof(DType); + // Allocate the reserve space + reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU(s->dev_id)); + // Allocate the temp space + temp_space_ = Storage::Get()->Alloc(workspace_byte_, Context::GPU(s->dev_id)); + // Check that number of params are correct + size_t cudnn_param_size; + CUDNN_CALL(cudnnGetRNNParamsSize(s->dnn_handle_, + rnn_desc_, + x_desc_vec_[0], + &cudnn_param_size, + dtype_)); + CHECK_EQ(w.shape_[0] * sizeof(DType), cudnn_param_size); + // Set param descriptors + int dim_w[3] = {1, 1, 1}; + dim_w[0] = w.shape_[0]; + CUDNN_CALL(cudnnSetFilterNdDescriptor(w_desc_, + dtype_, + format_, + 3, + dim_w)); + CUDNN_CALL(cudnnSetFilterNdDescriptor(dw_desc_, + dtype_, + format_, + 3, + dim_w)); + + // Query weight layout + // cudnnFilterDescriptor_t m_desc; + // CHECK_EQ(cudnnCreateFilterDescriptor(&m_desc), CUDNN_STATUS_SUCCESS); + // DType *p; + // int n = 2; + // int64_t last = 0; + // if (param_.mode == rnn_enum::kLstm) n = 8; + // else if (param_.mode == rnn_enum::kGru) n = 6; + + // for (int i = 0; i < param_.num_layers*(param_.bidirectional?2:1); ++i) { + // for (int j = 0; j < n; ++j) { + // CHECK_EQ(cudnnGetRNNLinLayerMatrixParams(s->dnn_handle_, rnn_desc_, + // i, x_desc_vec_[0], w_desc_, 0, j, m_desc, (void**)&p), CUDNN_STATUS_SUCCESS); + // LOG(INFO) << ((int64_t)(p - NULL))/sizeof(DType) - last; + // last = ((int64_t)(p - NULL))/sizeof(DType); + // cudnnDataType_t t; + // cudnnTensorFormat_t f; + // int ndim = 5; + // int dims[5] = {0, 0, 0, 0, 0}; + // CHECK_EQ(cudnnGetFilterNdDescriptor(m_desc, ndim, &t, &f, &ndim, &dims[0]), + // CUDNN_STATUS_SUCCESS); + // LOG(INFO) << "w: " << i << " " << j << " " << ((int64_t)(p - NULL))/sizeof(DType); + // for (int i = 0; i < ndim; ++i) LOG(INFO) << dims[i]; + // } + // } + + // for (int i = 0; i < param_.num_layers*(param_.bidirectional?2:1); ++i) { + // for (int j = 0; j < n; ++j) { + // CHECK_EQ(cudnnGetRNNLinLayerBiasParams(s->dnn_handle_, rnn_desc_, i, x_desc_vec_[0], + // w_desc_, 0, j, m_desc, (void**)&p), CUDNN_STATUS_SUCCESS); + // LOG(INFO) << ((int64_t)(p - NULL))/sizeof(DType) - last; + // last = ((int64_t)(p - NULL))/sizeof(DType); + // LOG(INFO) << "b: " << i << " " << j << " " << ((int64_t)(p - NULL))/sizeof(DType); + // } + // } } - out_type->clear(); - out_type->push_back(dtype); - if (!param_.state_outputs) { - return true; + #endif + } + #if MXNET_USE_CUDNN_RNN + cudnnDataType_t dtype_; + bool init_cudnn_; + cudnnRNNDescriptor_t rnn_desc_; + cudnnRNNMode_t mode_; + cudnnDirectionMode_t direction_; + cudnnRNNInputMode_t input_mode_; + cudnnDropoutDescriptor_t dropout_desc_; + Storage::Handle reserve_space_, temp_space_; + uint64_t seed_ = 17 + rand() % 4096; // NOLINT(runtime/threadsafe_fn) + size_t workspace_byte_, reserve_space_byte_, dropout_byte_; + int workspace_size_; + std::vector x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_; + #if USE_CUDNN_LSTM_PROJ + cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, dy_data_desc_; + #endif + cudnnTensorDescriptor_t hx_desc_, cx_desc_; + cudnnTensorDescriptor_t hy_desc_, cy_desc_; + cudnnTensorDescriptor_t dhx_desc_, dcx_desc_; + cudnnTensorDescriptor_t dhy_desc_, dcy_desc_; + + cudnnFilterDescriptor_t w_desc_, dw_desc_; + // Allow TensorCore algo policy + bool cudnn_tensor_core_; + + #if CUDNN_MAJOR >= 5 + cudnnTensorFormat_t format_; + #endif + #endif + bool init_space_, temp_init_space_; + size_t reserve_cpu_space_size_, temp_cpu_space_size_; + 
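
The CHECK_EQ against cudnnGetRNNParamsSize above only holds if MXNet's host-side parameter count (GetRnnParamSize in rnn-inl.h) and cuDNN agree on the flat weight layout. As a rough, host-only illustration of that bookkeeping — a sketch assuming cuDNN's standard layout (per-gate input and recurrent weight matrices plus two bias vectors) and no LSTM projection; count_rnn_params is a hypothetical helper, not the function the operator actually calls:

// Sketch only: count the flat parameter-vector length that the
// cudnnGetRNNParamsSize() check above is expected to match, assuming the
// standard cuDNN layout and no LSTM projection.
#include <cstdint>
#include <iostream>

int64_t count_rnn_params(int num_layers, int input_size, int state_size,
                         int directions /* 1 or 2 */, int gates /* 1, 3 or 4 */) {
  int64_t total = 0;
  for (int layer = 0; layer < num_layers; ++layer) {
    // Layers above the first see the (possibly bidirectional) hidden state as input.
    const int64_t in = (layer == 0) ? input_size
                                    : static_cast<int64_t>(state_size) * directions;
    // W_x (gates x state x in) + W_h (gates x state x state) + two bias sets per gate.
    const int64_t per_dir = gates * (static_cast<int64_t>(state_size) * in +
                                     static_cast<int64_t>(state_size) * state_size +
                                     2L * state_size);
    total += per_dir * directions;
  }
  return total;
}

int main() {
  // Example: 2-layer bidirectional LSTM (4 gates), input 128, hidden 256.
  std::cout << count_rnn_params(2, 128, 256, 2, 4) << " weights expected\n";
  return 0;
}

Multiplying the result by sizeof(DType) gives the byte count that the CHECK_EQ above compares against cudnn_param_size.
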
Storage::Handle reserve_cpu_space_, temp_cpu_space_; +}; // class RNNOp + +static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs, + const Context ctx, + const mxnet::ShapeVector &in_shapes, + const std::vector &in_types) { + const RNNParam& param = nnvm::get(attrs.parsed); + OpStatePtr state = OpStatePtr(); + MSHADOW_REAL_TYPE_SWITCH(in_types[rnn_enum::kData], DType, { + if (ctx.dev_type == kGPU) { + state = OpStatePtr::Create>(param, ctx); } else { - out_type->push_back(dtype); - // Deal with lstm cell state - if (param_.mode == rnn_enum::kLstm) - out_type->push_back(dtype); - return true; + state = OpStatePtr::Create>(param, ctx); } - } - - OperatorProperty* Copy() const override { - auto ptr = new RNNProp(); - ptr->param_ = param_; - return ptr; - } - - std::string TypeString() const override { - return "RNN"; - } + }); + return state; +} - std::vector DeclareBackwardDependency( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data) const override { - std::vector dep = {in_data[rnn_enum::kData], in_data[rnn_enum::kParams], - in_data[rnn_enum::kState], out_data[rnn_enum::kOut], out_grad[rnn_enum::kOut]}; +template +void RNNStatefulCompute(const OpStatePtr& state, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + int dtype = inputs[rnn_enum::kData].type_flag_; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + RNNOp& op = state.get_state>(); + op.Forward(ctx, inputs, req, outputs); + }); +} - if (param_.state_outputs) { - dep.push_back(out_data[rnn_enum::kStateOut]); - dep.push_back(out_grad[rnn_enum::kStateOut]); +/* +index description +0: x +1: w +2: hx +3: y +4: dy +5: hy +6: dhy +7: cx +8: cy +9: dcy +*/ +template +void RNNStatefulGradCompute(const OpStatePtr& state, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + std::vector in_data(inputs.begin(), inputs.begin() + 3); + std::vector out_data{inputs[3]}; + std::vector out_grad{inputs[4]}; + const std::vector &in_grad = outputs; + + int dtype = inputs[rnn_enum::kData].type_flag_; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + RNNOp& op = state.get_state>(); + const RNNParam& param = op.param_; + int index = 5; + if (param.state_outputs) { + out_data.push_back(inputs[index++]); + out_grad.push_back(inputs[index++]); } - if (param_.mode == rnn_enum::kLstm) { - dep.push_back(in_data[rnn_enum::kStateCell]); - if (param_.state_outputs) { - dep.push_back(out_data[rnn_enum::kStateCellOut]); - dep.push_back(out_grad[rnn_enum::kStateCellOut]); + if (param.mode == rnn_enum::kLstm) { + in_data.push_back(inputs[index++]); + if (param.state_outputs) { + out_data.push_back(inputs[index++]); + out_grad.push_back(inputs[index]); } } - return dep; - } - - std::vector ForwardResource( - const mxnet::ShapeVector &in_shape) const override { - return {ResourceRequest::kTempSpace}; - } - - std::vector BackwardResource( - const mxnet::ShapeVector &in_shape) const override { - return {ResourceRequest::kTempSpace}; - } - Operator* CreateOperator(Context ctx) const override { - LOG(FATAL) << "Not Implemented"; - return NULL; - } - - Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape, - std::vector *in_type) const override; + op.Backward(ctx, out_grad, in_data, out_data, req, in_grad); + }); +} - private: - RNNParam param_; -}; // class RNNProp -#endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet + #endif // MXNET_OPERATOR_RNN_INL_H_ diff --git a/src/operator/rnn.cc 
b/src/operator/rnn.cc index 621b9eb110e7..74c563afceb1 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -27,24 +27,142 @@ namespace mxnet { namespace op { -template<> -Operator *CreateOp(RNNParam param, int dtype) { - Operator *op = nullptr; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new RNNOp(param); - }); - return op; + +DMLC_REGISTER_PARAMETER(RNNParam); +static inline std::vector ListArguments(const RNNParam& param_) { + if (param_.mode == rnn_enum::kLstm) { + return {"data", "parameters", "state", "state_cell"}; + } else { + return {"data", "parameters", "state"}; + } } -Operator *RNNProp::CreateOperatorEx(Context ctx, - mxnet::ShapeVector *in_shape, - std::vector *in_type) const { - DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); +static bool RNNShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { + const RNNParam& param_ = nnvm::get(attrs.parsed); + using namespace mshadow; + if (param_.mode == rnn_enum::kLstm) { + CHECK_EQ(in_shape->size(), 4U) << "Needed input:[data, parameters, state, cell_state]," + << " got in_shape->size(): " << in_shape->size(); + } else { + CHECK_EQ(in_shape->size(), 3U) << + "Needed input:[data, parameters, state], got in_shape->size(): " << in_shape->size(); + } + const TShape &dshape = (*in_shape)[rnn_enum::kData]; + if (dshape.ndim() == 0) return false; + CHECK_EQ(dshape.ndim(), 3U) \ + << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]"; + // data: [sequence len, batch, input dimension] + int batch_size = dshape[1]; + int input_size = dshape[2]; + int numDirections = param_.bidirectional ? 2 : 1; + int total_layers = numDirections * param_.num_layers; // double for bidirectional + int layer_size = (param_.projection_size.has_value()) ? 
+ param_.projection_size.value() : param_.state_size; + SHAPE_ASSIGN_CHECK(*in_shape, + rnn_enum::kState, + Shape3(total_layers, batch_size, layer_size)); + if (param_.mode == rnn_enum::kLstm) { + SHAPE_ASSIGN_CHECK(*in_shape, + rnn_enum::kStateCell, + Shape3(total_layers, batch_size, param_.state_size)); + } + + // calculate parameter vector length + int param_size = GetRnnParamSize(param_.num_layers, + input_size, + param_.state_size, + numDirections, + param_.mode, + param_.projection_size); + SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size)); + out_shape->clear(); + // output: [sequence len, batch, output size] + TShape oshape = dshape; + if (param_.projection_size.has_value()) { + oshape[2] = numDirections * param_.projection_size.value(); + } else { + oshape[2] = numDirections * param_.state_size; + } + out_shape->push_back(oshape); + if (param_.state_outputs) { + // outStateShape: [layer_num, batch, state size] + TShape outStateShape = dshape; + outStateShape[0] = total_layers; + outStateShape[1] = batch_size; + if (param_.projection_size.has_value()) { + outStateShape[2] = param_.projection_size.value(); + } else { + outStateShape[2] = param_.state_size; + } + out_shape->push_back(outStateShape); + // Deal with lstm cell state + if (param_.mode == rnn_enum::kLstm) { + TShape cellStateShape = dshape; + cellStateShape[0] = total_layers; + cellStateShape[1] = batch_size; + cellStateShape[2] = param_.state_size; + out_shape->push_back(cellStateShape); + } + } + return true; } -DMLC_REGISTER_PARAMETER(RNNParam); +static bool RNNType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + const RNNParam& param_ = nnvm::get(attrs.parsed); + if (param_.mode == rnn_enum::kLstm) { + CHECK_EQ(in_type->size(), 4U); + } else { + CHECK_EQ(in_type->size(), 3U); + } + int dtype = (*in_type)[0]; + CHECK_NE(dtype, -1) << "First input must have specified type"; + for (size_t i = 0; i < in_type->size(); ++i) { + if ((*in_type)[i] == -1) { + TYPE_ASSIGN_CHECK(*in_type, i, dtype); + } else { + UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]); + } + } + out_type->clear(); + out_type->push_back(dtype); + if (param_.state_outputs) { + out_type->push_back(dtype); + // Deal with lstm cell state + if (param_.mode == rnn_enum::kLstm) + out_type->push_back(dtype); + } + return true; +} -MXNET_REGISTER_OP_PROPERTY(RNN, RNNProp) +struct RNNGrad { + const char *op_name; + std::vector operator()(const nnvm::NodePtr &n, + const std::vector &ograd) const { + const RNNParam& params = nnvm::get(n->attrs.parsed); + std::vector heads{ n->inputs[rnn_enum::kData], + n->inputs[rnn_enum::kParams], n->inputs[rnn_enum::kState] }; + heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kOut, 0}); + heads.push_back(ograd[rnn_enum::kOut]); + if (params.state_outputs) { + heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kStateOut, 0}); + heads.push_back(ograd[rnn_enum::kStateOut]); + } + if (params.mode == rnn_enum::kLstm) { + heads.push_back(n->inputs[rnn_enum::kStateCell]); + if (params.state_outputs) { + heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kStateCellOut, 0}); + heads.push_back(ograd[rnn_enum::kStateCellOut]); + } + } + return MakeGradNode(op_name, n, heads, n->attrs.dict); + } +}; + +NNVM_REGISTER_OP(RNN) .describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are implemented, with both multi-layer and bidirectional support. 
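
For readers tracing RNNShape above, a small standalone sketch of the shapes it assigns, using a hypothetical configuration (bidirectional 2-layer LSTM, seq_len 10, batch 32, input 128, hidden 256, no projection); the constants here are illustrative only, and the parameters vector length comes from GetRnnParamSize as sketched earlier:

#include <cstdio>

int main() {
  // Hypothetical configuration; mirrors the quantities used in RNNShape.
  const int seq_len = 10, batch = 32, input_size = 128;
  const int state_size = 256, num_layers = 2;
  const bool bidirectional = true, lstm = true, state_outputs = true;

  const int dirs = bidirectional ? 2 : 1;
  const int total_layers = dirs * num_layers;  // doubled for bidirectional
  const int layer_size = state_size;           // projection_size would override this

  std::printf("data        : (%d, %d, %d)\n", seq_len, batch, input_size);
  std::printf("state (hx)  : (%d, %d, %d)\n", total_layers, batch, layer_size);
  if (lstm)
    std::printf("state_cell  : (%d, %d, %d)\n", total_layers, batch, state_size);
  std::printf("output      : (%d, %d, %d)\n", seq_len, batch, dirs * layer_size);
  if (state_outputs) {
    std::printf("state out   : (%d, %d, %d)\n", total_layers, batch, layer_size);
    if (lstm)
      std::printf("cell out    : (%d, %d, %d)\n", total_layers, batch, state_size);
  }
  return 0;
}
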
@@ -97,7 +215,49 @@ The definition of GRU here is slightly different from paper but compatible with z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\ - \end{array})code") + \end{array} +)code" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_num_inputs([](const NodeAttrs& attrs) { + const RNNParam& params = nnvm::get(attrs.parsed); + return params.mode == rnn_enum::kLstm ? 4 : 3; +}) +.set_num_outputs([](const NodeAttrs& attrs) { + const RNNParam& params = nnvm::get(attrs.parsed); + // kOut + int num_outputs = 1; + if (params.state_outputs) { + // kOut, kStateOut, kStateCellOut + num_outputs = (params.mode == rnn_enum::kLstm) ? 3 : 2; + } + + return num_outputs; +}) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const RNNParam& params = nnvm::get(attrs.parsed); + return ListArguments(params); +}) +.set_attr("FInferShape", RNNShape) +.set_attr("FInferType", RNNType) +.set_attr("FCreateOpState", CreateRNNState) +.set_attr("FStatefulCompute", RNNStatefulCompute) +.set_attr("FGradient", RNNGrad{"_backward_RNN"}) +.set_attr("FResourceRequestEx", + [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) { + std::vector request; + const RNNParam& param = nnvm::get(attrs.parsed); + if (param.p == 0) return request; + if (dev_mask == kGPU) { +#if MXNET_USE_CUDNN_RNN + if (1.0f - param.p > 0) { + request.emplace_back(ResourceRequest::kCuDNNDropoutDesc); + return request; + } +#endif + } + return request; +}) .add_argument("data", "NDArray-or-Symbol", "Input data to RNN") .add_argument("parameters", "NDArray-or-Symbol", "Vector of all RNN trainable parameters concatenated") @@ -105,5 +265,15 @@ The definition of GRU here is slightly different from paper but compatible with .add_argument("state_cell", "NDArray-or-Symbol", "initial cell state for LSTM networks (only for LSTM)") .add_arguments(RNNParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_RNN) +.set_num_outputs([](const NodeAttrs& attrs) { + const RNNParam& params = nnvm::get(attrs.parsed); + return params.mode == rnn_enum::kLstm ? 4 : 3; +}) +.set_attr_parser(ParamParser) +.set_attr("TIsLayerOpBackward", true) +.set_attr("TIsBackward", true) +.set_attr("FStatefulCompute", RNNStatefulGradCompute); } // namespace op } // namespace mxnet diff --git a/src/operator/rnn.cu b/src/operator/rnn.cu index 402a8cf5f503..77bb95522711 100644 --- a/src/operator/rnn.cu +++ b/src/operator/rnn.cu @@ -26,24 +26,14 @@ #include "./rnn-inl.h" #include -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 -#include "./cudnn_rnn-inl.h" -#endif // MXNET_USE_CUDNN && CUDNN_MAJOR namespace mxnet { namespace op { -template<> -Operator* CreateOp(RNNParam param, int dtype) { - Operator *op = NULL; -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new CuDNNRNNOp(param); - }) -#else - LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment."; -#endif // MXNET_USE_CUDNN && CUDNN_MAJOR - return op; -} +NNVM_REGISTER_OP(RNN) +.set_attr("FStatefulCompute", RNNStatefulCompute); + +NNVM_REGISTER_OP(_backward_RNN) +.set_attr("FStatefulCompute", RNNStatefulGradCompute); } // namespace op } // namespace mxnet
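
As a closing cross-check of the index comment above RNNStatefulGradCompute and the head ordering built by RNNGrad, here is a sketch (illustrative only, with the same optional pieces as the operator: state outputs and the LSTM cell state) that prints the backward op's input layout and the matching _backward_RNN output count:

#include <cstdio>
#include <string>
#include <vector>

int main() {
  // Flags mirror RNNParam::state_outputs and mode == rnn_enum::kLstm.
  const bool state_outputs = true, lstm = true;

  // Order must match RNNGrad: x, w, hx, y, dy, then hy/dhy, then cx, cy, dcy.
  std::vector<std::string> grad_inputs = {"x", "w", "hx", "y", "dy"};
  if (state_outputs) { grad_inputs.push_back("hy"); grad_inputs.push_back("dhy"); }
  if (lstm) {
    grad_inputs.push_back("cx");
    if (state_outputs) { grad_inputs.push_back("cy"); grad_inputs.push_back("dcy"); }
  }
  for (size_t i = 0; i < grad_inputs.size(); ++i)
    std::printf("%zu: %s\n", i, grad_inputs[i].c_str());

  // _backward_RNN emits dx, dw, dhx (and dcx for LSTM), matching set_num_outputs.
  std::printf("backward outputs: %d\n", lstm ? 4 : 3);
  return 0;
}

With both flags set the printed layout reproduces the 0-9 comment above RNNStatefulGradCompute; flipping either flag shows how the index bookkeeping in that function skips the absent entries.
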