Fix CPU-only RNNOp Forward
nickguletskii committed May 22, 2020
1 parent 567518b commit 0a921a4
Showing 1 changed file with 55 additions and 56 deletions.

src/operator/rnn-inl.h
@@ -842,53 +842,73 @@ class RNNOp {
 #endif  // MXNET_USE_CUDNN_GE_7200
   }
 #endif  // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
-    if (ctx_.dev_type == kCPU) {
-      int projection_size = 0;
-      if (ctx.is_train || ctx.need_grad) {
-        mshadow::Random<cpu, unsigned> *prnd = ctx.requested[0].get_random<xpu, unsigned int>(s);
-        std::mt19937 &rnd_engine = prnd->GetRndEngine();
-
-        // allocate reserve space
-        if (param_.projection_size.has_value()) {
-          projection_size = param_.projection_size.value();
-          LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
-        }
-
-        // allocate temp space
-        const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-            param_.state_size, projection_size, direction, param_.mode);
-        if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) {
-          temp_cpu_space_size_ = work_cpu_space_size;
-          temp_cpu_space_ = NDArray(TShape({static_cast<dim_t>(temp_cpu_space_size_)}), ctx_,
-                                    false, in_data[rnn_enum::kData].type_flag_);
-          temp_init_space_ = true;
-        }
-        DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);
-
-        const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                                     param_.seq_length_, param_.batch_size_,
-                                                     param_.state_size, param_.mode);
-        if (!init_space_ || reserve_cpu_space_size_ < r_size) {
-          reserve_cpu_space_size_ = r_size;
-          reserve_cpu_space_ = NDArray(TShape({static_cast<dim_t>(reserve_cpu_space_size_)}), ctx_,
-                                       false, in_data[rnn_enum::kData].type_flag_);
-          init_space_ = true;
-        }
-        DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.data().dptr_);
-
-        RNNForwardTraining<DType>(work_cpu_space,
-                                  reserve_space_ptr,
+#if !defined(__CUDACC__)
+    int projection_size = 0;
+    if (param_.projection_size.has_value()) {
+      projection_size = param_.projection_size.value();
+    }
+
+    // allocate temp space
+    const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+        param_.state_size, projection_size, direction, param_.mode);
+    if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) {
+      temp_cpu_space_size_ = work_cpu_space_size;
+      temp_cpu_space_ = NDArray(TShape({static_cast<dim_t>(temp_cpu_space_size_)}), ctx_,
+                                false, in_data[rnn_enum::kData].type_flag_);
+      temp_init_space_ = true;
+    }
+    DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);
+
+    if (ctx.is_train || ctx.need_grad) {
+      mshadow::Random<cpu, unsigned> *prnd = ctx.requested[0].get_random<xpu, unsigned int>(s);
+      std::mt19937 &rnd_engine = prnd->GetRndEngine();
+
+      // allocate reserve space
+      if (param_.projection_size.has_value()) {
+        LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
+      }
+
+      const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                                   param_.seq_length_, param_.batch_size_,
+                                                   param_.state_size, param_.mode);
+      if (!init_space_ || reserve_cpu_space_size_ < r_size) {
+        reserve_cpu_space_size_ = r_size;
+        reserve_cpu_space_ = NDArray(TShape({static_cast<dim_t>(reserve_cpu_space_size_)}), ctx_,
+                                     false, in_data[rnn_enum::kData].type_flag_);
+        init_space_ = true;
+      }
+      DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.data().dptr_);
+
+      RNNForwardTraining<DType>(work_cpu_space,
+                                reserve_space_ptr,
                                 param_.state_outputs,
                                 param_.num_layers,
                                 direction,
                                 param_.seq_length_,
                                 param_.batch_size_,
                                 param_.input_size_,
                                 param_.state_size,
                                 x.dptr_,
                                 hx.dptr_,
                                 cx_ptr,
                                 w.dptr_,
                                 b_ptr,
                                 y.dptr_,
                                 hy_ptr,
                                 cy_ptr,
                                 param_.p,
                                 param_.mode,
                                 rnd_engine);
@@ -897,30 +917,9 @@ class RNNOp {
-      } else {
-        RNNForwardInference<DType>(work_cpu_space,
-                                   param_.state_outputs,
-                                   param_.num_layers,
-                                   direction,
-                                   param_.seq_length_,
-                                   param_.batch_size_,
-                                   param_.input_size_,
-                                   param_.state_size,
-                                   projection_size,
-                                   x.dptr_,
-                                   hx.dptr_,
-                                   cx_ptr,
-                                   w.dptr_,
-                                   b_ptr,
-                                   y.dptr_,
-                                   hy_ptr,
-                                   cy_ptr,
-                                   param_.mode);
-      }
-    }
+    } else {
+      RNNForwardInference<DType>(work_cpu_space,
+                                 param_.state_outputs,
+                                 param_.num_layers,
+                                 direction,
+                                 param_.seq_length_,
+                                 param_.batch_size_,
+                                 param_.input_size_,
+                                 param_.state_size,
+                                 projection_size,
+                                 x.dptr_,
+                                 hx.dptr_,
+                                 cx_ptr,
+                                 w.dptr_,
+                                 b_ptr,
+                                 y.dptr_,
+                                 hy_ptr,
+                                 cy_ptr,
+                                 param_.mode);
+    }
+#endif  // !defined(__CUDACC__)
   }
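In short, the diff hoists the temp-workspace allocation (and the projection_size lookup) out of the training-only branch so the inference path also gets its buffer, and replaces the runtime ctx_.dev_type == kCPU test with the compile-time #if !defined(__CUDACC__) guard. A minimal sketch of the hoisting pattern, using invented stand-ins (Workspace, forward, rnn_forward_training, rnn_forward_inference) rather than the real MXNet types:

#include <cstddef>
#include <cstdio>
#include <vector>

// Stand-in for the cached temp buffer; not MXNet's NDArray machinery.
struct Workspace {
  std::vector<float> buf;
};

void rnn_forward_training(const Workspace& ws)  { std::printf("train, workspace=%zu\n", ws.buf.size()); }
void rnn_forward_inference(const Workspace& ws) { std::printf("infer, workspace=%zu\n", ws.buf.size()); }

void forward(bool is_train, std::size_t work_size, Workspace* ws) {
  // Hoisted allocation: both branches below need the workspace, so it is
  // grown before the train/inference split. In the pre-fix layout the
  // equivalent allocation sat inside the training branch only.
  if (ws->buf.size() < work_size) {
    ws->buf.resize(work_size);
  }
  if (is_train) {
    rnn_forward_training(*ws);
  } else {
    rnn_forward_inference(*ws);
  }
}

int main() {
  Workspace ws;
  forward(false, 1024, &ws);  // inference-only call now also gets a buffer
  forward(true, 2048, &ws);   // training grows the cached buffer on demand
  return 0;
}

The #if !defined(__CUDACC__) guard, by contrast, keeps the CPU path out of the nvcc compilation pass altogether: a compile-time analogue of the deleted runtime device check.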

