This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

GPU RNN to use TempSpace resource for workspace. #15056

Merged · 5 commits · May 25, 2019
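Summary of the change: the GPU RNN operator previously allocated its own persistent cuDNN workspace with Storage::Get()->Alloc and freed it in the destructor; it now asks the engine for per-call scratch via the kTempSpace resource. A minimal sketch of the pattern, reusing names from the diff below (illustrative, not the PR's literal code):

    // Before: a buffer the op owned for its whole lifetime.
    // temp_space_ = Storage::Get()->Alloc(workspace_byte_, Context::GPU(s->dev_id));

    // After: scratch requested from the OpContext on every Forward/Backward;
    // the engine owns the memory and may reuse the same pool for other ops.
    mshadow::Tensor<gpu, 1, DType> temp_space =
        ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
            mshadow::Shape1(workspace_size_), s);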
44 changes: 26 additions & 18 deletions src/operator/rnn-inl.h
@@ -549,7 +549,6 @@ class RNNOp {
CUDNN_CALL(cudnnDestroyTensorDescriptor(dy_desc_vec_[i]));
}
init_cudnn_ = false;
- Storage::Get()->Free(temp_space_);
Storage::Get()->Free(reserve_space_);
}
#if MXNET_USE_CUDNN_GE_7200
@@ -677,6 +676,12 @@ class RNNOp {
Init(ctx, s, in_data, out_data);
}

+ // Get temp space
+ int temp_size = workspace_size_;
+ Tensor<gpu, 1, DType> temp_space =
+     ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
+         mshadow::Shape1(temp_size), s);

#if MXNET_USE_CUDNN_GE_7200

cudnnRNNDataLayout_t layout_t;
@@ -770,7 +775,7 @@ class RNNOp {
nullptr,
nullptr,
nullptr,
- temp_space_.dptr,
+ temp_space.dptr_,
workspace_byte_,
reserve_space_.dptr,
reserve_space_byte_));
@@ -792,7 +797,7 @@ class RNNOp {
hy_ptr,
cy_desc_,
cy_ptr,
- temp_space_.dptr,
+ temp_space.dptr_,
workspace_byte_,
reserve_space_.dptr,
reserve_space_byte_));
@@ -823,7 +828,7 @@ class RNNOp {
nullptr,
nullptr,
nullptr,
- temp_space_.dptr,
+ temp_space.dptr_,
workspace_byte_));
#else
CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
@@ -843,7 +848,7 @@ class RNNOp {
hy_ptr,
cy_desc_,
cy_ptr,
- temp_space_.dptr,
+ temp_space.dptr_,
workspace_byte_));
#endif
}
@@ -1061,6 +1066,12 @@ class RNNOp {
Init(ctx, s, in_data, out_data);
}

+ // Get temp space
+ int temp_size = workspace_size_;
+ Tensor<gpu, 1, DType> temp_space =
+     ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
+         mshadow::Shape1(temp_size), s);

#if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
rnn_desc_,
@@ -1088,7 +1099,7 @@ class RNNOp {
dcx_ptr,
nullptr,
nullptr,
- temp_space_.dptr,
+ temp_space.dptr_,
workspace_byte_,
reserve_space_.dptr,
reserve_space_byte_));
@@ -1100,7 +1111,7 @@ class RNNOp {
hx.dptr_,
y_data_desc_,
y.dptr_,
- temp_space_.dptr,
+ temp_space.dptr_,
workspace_byte_,
dw_desc_,
dw.dptr_,
@@ -1130,7 +1141,7 @@ class RNNOp {
dhx.dptr_,
dcx_desc_,
dcx_ptr,
- temp_space_.dptr,
+ temp_space.dptr_,
workspace_byte_,
reserve_space_.dptr,
reserve_space_byte_));
@@ -1143,7 +1154,7 @@ class RNNOp {
hx.dptr_,
y_desc_vec_.data(),
y.dptr_,
- temp_space_.dptr,
+ temp_space.dptr_,
workspace_byte_,
dw_desc_,
dw.dptr_,
@@ -1378,17 +1389,16 @@ class RNNOp {
strideA));

// Create Dropout descriptors
- DType* dropout_states_ = NULL;
if (param_.p > 0) {
ctx.requested[rnn_enum::kCuDNNDropoutDescSpace].get_cudnn_dropout_desc
(&dropout_desc_, s, 1.0f - param_.p, seed_);
- } else {
-   dropout_byte_ = 0;
}

+ // Only update the probability by passing in a null dropout_states ptr
+ DType* dropout_states = NULL;
+ size_t dropout_bytes = 0;
CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_, s->dnn_handle_,
param_.p, // discard probability
- dropout_states_, dropout_byte_,
+ dropout_states, dropout_bytes,
seed_));

// RNN descriptors
@@ -1469,8 +1479,6 @@ class RNNOp {
workspace_size_ = workspace_byte_ / sizeof(DType);
// Allocate the reserve space
reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU(s->dev_id));
- // Allocate the temp space
- temp_space_ = Storage::Get()->Alloc(workspace_byte_, Context::GPU(s->dev_id));
// Check that number of params are correct
size_t cudnn_param_size;
CUDNN_CALL(cudnnGetRNNParamsSize(s->dnn_handle_,
@@ -1539,9 +1547,9 @@ class RNNOp {
cudnnDirectionMode_t direction_;
cudnnRNNInputMode_t input_mode_;
cudnnDropoutDescriptor_t dropout_desc_;
- Storage::Handle reserve_space_, temp_space_;
+ Storage::Handle reserve_space_;
uint64_t seed_ = 17 + rand() % 4096; // NOLINT(runtime/threadsafe_fn)
- size_t workspace_byte_, reserve_space_byte_, dropout_byte_;
+ size_t workspace_byte_, reserve_space_byte_;
int workspace_size_;
std::vector<cudnnTensorDescriptor_t> x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_;
#if MXNET_USE_CUDNN_GE_7200
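Worth noting between the two files: only the workspace moves to the pooled resource. The cuDNN reserve space must persist from the forward pass to the backward pass, so it keeps its explicit allocation. A sketch of the two lifetimes, with names taken from the hunks above (illustrative only):

    // Per-call scratch: only valid inside the current Forward()/Backward(),
    // so it can safely come from the engine's shared temp-space pool.
    mshadow::Tensor<gpu, 1, DType> temp_space =
        ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
            mshadow::Shape1(workspace_size_), s);

    // Cross-call state: the backward pass reads what the forward pass wrote
    // into the reserve space, so it stays an owned allocation, freed in ~RNNOp().
    reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU(s->dev_id));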
35 changes: 19 additions & 16 deletions src/operator/rnn.cc
@@ -167,6 +167,22 @@ static bool RNNType(const nnvm::NodeAttrs& attrs,
return true;
}

+ static std::vector<ResourceRequest> RNNResourceEx(const NodeAttrs& attrs, const int dev_mask,
+                                                   const DispatchMode dispatch_mode) {
+   std::vector<ResourceRequest> request;
+   if (dev_mask == kGPU) {
+ #if MXNET_USE_CUDNN_RNN
+     request.emplace_back(ResourceRequest::kTempSpace);

+     const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
+     if (param.p != 0 && 1.0f - param.p > 0) {
+       request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
+     }
+ #endif
+   }
+   return request;
+ }

inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
@@ -703,21 +719,7 @@ The definition of GRU here is slightly different from paper but compatible with
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeCPU)
#endif
.set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"})
- .set_attr<FResourceRequestEx>("FResourceRequestEx",
-   [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
-     std::vector<ResourceRequest> request;
-     if (dev_mask == kGPU) {
- #if MXNET_USE_CUDNN_RNN
-       request.emplace_back(ResourceRequest::kTempSpace);

-       const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
-       if (param.p != 0 && 1.0f - param.p > 0) {
-         request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
-       }
- #endif
-     }
-     return request;
-   })
+ .set_attr<FResourceRequestEx>("FResourceRequestEx", RNNResourceEx)
.add_argument("data", "NDArray-or-Symbol", "Input data to RNN")
.add_argument("parameters", "NDArray-or-Symbol",
"Vector of all RNN trainable parameters concatenated")
@@ -737,6 +739,7 @@ NNVM_REGISTER_OP(_backward_RNN)
.set_attr_parser(ParamParser<RNNParam>)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
- .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulGradCompute<cpu>);
+ .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulGradCompute<cpu>)
+ .set_attr<FResourceRequestEx>("FResourceRequestEx", RNNResourceEx);
} // namespace op
} // namespace mxnet
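One detail in RNNResourceEx above: the dropout-descriptor resource is requested only when dropout is actually in effect. Assuming param.p is a dropout probability in [0, 1], the guard is equivalent to this arguably clearer spelling (a sketch, not the PR's code):

    // param.p != 0 && 1.0f - param.p > 0  is  0 < p < 1  for p in [0, 1]:
    // p == 0 means no dropout, so no cuDNN dropout descriptor is needed.
    const bool dropout_active = param.p > 0.0f && param.p < 1.0f;
    if (dropout_active) {
      request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
    }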