From 0a921a4432ac657b47be44af86ab498fee66f964 Mon Sep 17 00:00:00 2001
From: Nick Guletskii <nick@nickguletskii.com>
Date: Fri, 22 May 2020 23:28:31 +0300
Subject: [PATCH] Fix CPU-only RNNOp Forward

---
 src/operator/rnn-inl.h | 111 ++++++++++++++++++++---------------------
 1 file changed, 55 insertions(+), 56 deletions(-)

diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index ede1d5f4717f..fdce937e50d1 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -842,46 +842,65 @@ class RNNOp {
 #endif  // MXNET_USE_CUDNN_GE_7200
   }
 #endif  // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
+#if !defined(__CUDACC__)
+    int projection_size = 0;
+    if (param_.projection_size.has_value()) {
+      projection_size = param_.projection_size.value();
+    }
 
-    if (ctx_.dev_type == kCPU) {
-      int projection_size = 0;
+    // allocate temp space
+    const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+        param_.state_size, projection_size, direction, param_.mode);
+    if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) {
+      temp_cpu_space_size_ = work_cpu_space_size;
+      temp_cpu_space_ = NDArray(TShape({static_cast<index_t>(temp_cpu_space_size_)}), ctx_,
+                                false, in_data[rnn_enum::kData].type_flag_);
+      temp_init_space_ = true;
+    }
+    DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);
+
+    if (ctx.is_train || ctx.need_grad) {
+      mshadow::Random<xpu, unsigned> *prnd = ctx.requested[0].get_random<xpu, unsigned>(s);
+      std::mt19937 &rnd_engine = prnd->GetRndEngine();
+
+      // allocate reserve space
       if (param_.projection_size.has_value()) {
-        projection_size = param_.projection_size.value();
+        LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
       }
 
-      // allocate temp space
-      const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-          param_.state_size, projection_size, direction, param_.mode);
-      if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) {
-        temp_cpu_space_size_ = work_cpu_space_size;
-        temp_cpu_space_ = NDArray(TShape({static_cast<index_t>(temp_cpu_space_size_)}), ctx_,
+      const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                                   param_.seq_length_, param_.batch_size_,
+                                                   param_.state_size, param_.mode);
+      if (!init_space_ || reserve_cpu_space_size_ < r_size) {
+        reserve_cpu_space_size_ = r_size;
+        reserve_cpu_space_ = NDArray(TShape({static_cast<index_t>(reserve_cpu_space_size_)}), ctx_,
                                   false, in_data[rnn_enum::kData].type_flag_);
-        temp_init_space_ = true;
+        init_space_ = true;
       }
-      DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);
-
-      if (ctx.is_train || ctx.need_grad) {
-        mshadow::Random<xpu, unsigned> *prnd = ctx.requested[0].get_random<xpu, unsigned>(s);
-        std::mt19937 &rnd_engine = prnd->GetRndEngine();
-
-        // allocate reserve space
-        if (param_.projection_size.has_value()) {
-          LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
-        }
-
-        const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                                     param_.seq_length_, param_.batch_size_,
-                                                     param_.state_size, param_.mode);
-        if (!init_space_ || reserve_cpu_space_size_ < r_size) {
-          reserve_cpu_space_size_ = r_size;
-          reserve_cpu_space_ = NDArray(TShape({static_cast<index_t>(reserve_cpu_space_size_)}), ctx_,
-                                       false, in_data[rnn_enum::kData].type_flag_);
-          init_space_ = true;
-        }
-        DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.data().dptr_);
+      DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.data().dptr_);
 
-        RNNForwardTraining<DType>(work_cpu_space,
-                                  reserve_space_ptr,
+      RNNForwardTraining<DType>(work_cpu_space,
+                                reserve_space_ptr,
+                                param_.state_outputs,
+                                param_.num_layers,
+                                direction,
+                                param_.seq_length_,
+                                param_.batch_size_,
+                                param_.input_size_,
+                                param_.state_size,
+                                x.dptr_,
+                                hx.dptr_,
+                                cx_ptr,
+                                w.dptr_,
+                                b_ptr,
+                                y.dptr_,
+                                hy_ptr,
+                                cy_ptr,
+                                param_.p,
+                                param_.mode,
+                                rnd_engine);
+    } else {
+      RNNForwardInference<DType>(work_cpu_space,
                                  param_.state_outputs,
                                  param_.num_layers,
                                  direction,
@@ -889,6 +908,7 @@ class RNNOp {
                                  param_.batch_size_,
                                  param_.input_size_,
                                  param_.state_size,
+                                 projection_size,
                                  x.dptr_,
                                  hx.dptr_,
                                  cx_ptr,
@@ -897,30 +917,9 @@ class RNNOp {
                                  y.dptr_,
                                  hy_ptr,
                                  cy_ptr,
-                                 param_.p,
-                                 param_.mode,
-                                 rnd_engine);
-      } else {
-        RNNForwardInference<DType>(work_cpu_space,
-                                   param_.state_outputs,
-                                   param_.num_layers,
-                                   direction,
-                                   param_.seq_length_,
-                                   param_.batch_size_,
-                                   param_.input_size_,
-                                   param_.state_size,
-                                   projection_size,
-                                   x.dptr_,
-                                   hx.dptr_,
-                                   cx_ptr,
-                                   w.dptr_,
-                                   b_ptr,
-                                   y.dptr_,
-                                   hy_ptr,
-                                   cy_ptr,
-                                   param_.mode);
-      }
+                                 param_.mode);
     }
+#endif  // !defined(__CUDACC__)
   }
 
   void Backward(const OpContext &ctx,