diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 7f879fcac395..d41b5b4f030b 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -122,7 +122,7 @@ struct RNNParam : public dmlc::Parameter { } }; -inline int GetRnnParamSize(int num_layer, +inline index_t GetRnnParamSize(int num_layer, index_t input_size, int state_size, int direction, diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h index f91c7ec4eaa7..459345797936 100644 --- a/src/operator/rnn_impl.h +++ b/src/operator/rnn_impl.h @@ -145,7 +145,7 @@ void LstmForwardTraining(DType* ws, const int total_layers = D * L; Tensor hx(hx_ptr, Shape3(total_layers, N, H)); Tensor cx(cx_ptr, Shape3(total_layers, N, H)); - const int b_size = 2 * H * 4; + const index_t b_size = 2 * H * 4; const index_t r_size = D * T * N * H * 6; const index_t y_offset = T * N * H * 5; const index_t cell_size = N * H; @@ -298,7 +298,7 @@ void LstmForwardInference(DType* ws, const int total_layers = D * L; Tensor hx(hx_ptr, Shape3(total_layers, N, P ? P : H)); Tensor cx(cx_ptr, Shape3(total_layers, N, H)); - const int b_size = 2 * H * 4; + const index_t b_size = 2 * H * 4; const index_t cell_size = N * H; const index_t projection_size = (P ? P : H) * N; DType* y_tmp_ptr = ws + (T + 1) * cell_size * 4 + cell_size * 2; @@ -553,7 +553,7 @@ void LstmBackward(DType* ws, Tensor cx(cx_ptr, Shape3(total_layers, N, H)); Tensor dhx(dhx_ptr, Shape3(total_layers, N, H)); Tensor dcx(dcx_ptr, Shape3(total_layers, N, H)); - const int b_size = 2 * H * 4; + const index_t b_size = 2 * H * 4; const index_t r_size = D * T * N * H * 6; const index_t y_offset = T * N * H * 5; const index_t w_size1 = (I + H) * H * 4; // first layer