MKL-DNN RNN backward path enhancement (#17183)
* Flush memory before RNN backward primitive (see the sketch after this list)

* Add gluon rnn unit test for gradients check

* Cache reorder

* Re-write rnn supporting check

* Update OpSignature.AddSign to avoid potential hash collision for
  rnn-packed memory

* Get the data type from mkldnn memory descriptor when setting grad handle
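
On the first bullet: the MKL-DNN RNN backward primitive accumulates into the
diff-weights buffers, so stale gradients from an earlier iteration must be
cleared before each execution. A minimal sketch of such a flush, assuming a
hypothetical helper name (the real call sites are in mkldnn_rnn.cc):

// flush_diff_weights is a hypothetical helper, not MXNet code.
#include <cstring>
#include <mkldnn.hpp>

void flush_diff_weights(const mkldnn::memory &diff_weights_layer,
                        const mkldnn::memory &diff_weights_iter) {
  // get_desc().get_size() is the byte size of the underlying buffer;
  // zero it so the backward primitive accumulates from a clean slate.
  std::memset(diff_weights_layer.get_data_handle(), 0,
              diff_weights_layer.get_desc().get_size());
  std::memset(diff_weights_iter.get_data_handle(), 0,
              diff_weights_iter.get_desc().get_size());
}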
zixuanweeei authored and pengzhao-intel committed Jan 6, 2020
1 parent 3e638b4 commit 83a23b0
Showing 6 changed files with 322 additions and 193 deletions.
9 changes: 6 additions & 3 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -132,9 +132,12 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
   return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4);
 }
 
-static inline bool SupportMKLDNNRNN(const NDArray &input) {
-  int ndim = input.shape().ndim();
-  return (input.dtype() == mshadow::kFloat32) && (ndim == 3);
+static inline bool SupportMKLDNNRnn(const NDArray &input) {
+  if (input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 3
+      && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
+    return true;
+  }
+  return false;
 }
 
 static inline bool SupportMKLDNNQuantize(int dtype) {
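
The rewritten check gates the whole MKL-DNN RNN path behind the
MXNET_USE_MKLDNN_RNN environment variable, read through dmlc::GetEnv with a
default of 1 (enabled). A standalone sketch of that gate, with a hypothetical
helper name:

#include <dmlc/parameter.h>  // dmlc::GetEnv

// Returns false when the user exports MXNET_USE_MKLDNN_RNN=0, letting MXNet
// fall back to its native RNN implementation at runtime.
inline bool MkldnnRnnEnabled() {
  return dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1) != 0;
}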
9 changes: 6 additions & 3 deletions src/operator/nn/mkldnn/mkldnn_rnn-inl.h
@@ -369,9 +369,10 @@ class MKLDNNRnnBackward {
   void SetDataGradsMem(void* diff_src, void* diff_state, void* diff_statecell,
                        void* diff_out, void* diff_state_out, void* diff_statecell_out,
                        const int dtype = mshadow::kFloat32);
-  void CommitWeightsDiff(void* diff_weights, void* diff_bias,
-                         const OpReqType req,
-                         const int dtype = mshadow::kFloat32);
+  void SetNativeWeightsGrads() const;
+  void CommitWeightsGrads(void* diff_weights, void* diff_bias,
+                          const OpReqType req,
+                          const int dtype = mshadow::kFloat32);
 
   const mkldnn::primitive& GetBwd() const { return *bwd_.primitive_; }
   const mkldnn_args_map_t& GetArgsMap() const { return net_args_; }
@@ -386,6 +387,8 @@
 
   mkldnn_shared_mem_t diff_weights_layer_;
   mkldnn_shared_mem_t diff_weights_iter_;
+  mkldnn_shared_mem_t diff_weights_layer_r_;
+  mkldnn_shared_mem_t diff_weights_iter_r_;
   mkldnn_shared_mem_t diff_bias_;
 
   mkldnn_args_map_t net_args_;
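
For orientation, a hypothetical sketch of the call sequence these
declarations imply (the driver function and buffer names are placeholders;
the real logic lives in mkldnn_rnn.cc):

#include "./mkldnn_rnn-inl.h"

// Placeholder driver: bind gradient buffers, run the cached backward
// primitive, then commit weight/bias gradients per the request type.
void RunRnnBackward(MKLDNNRnnBackward *bwd,
                    void *dx, void *dhx, void *dcx,
                    void *dy, void *dhy, void *dcy,
                    void *dw, void *db, OpReqType req) {
  bwd->SetDataGradsMem(dx, dhx, dcx, dy, dhy, dcy, mshadow::kFloat32);
  MKLDNNStream::Get()->RegisterPrimArgs(bwd->GetBwd(), bwd->GetArgsMap());
  MKLDNNStream::Get()->Submit();
  // CommitWeightsGrads reorders the packed diff weights back to MXNet's
  // native layout; SetNativeWeightsGrads() presumably refreshes the native
  // handles first (its exact call site is in the .cc file).
  bwd->CommitWeightsGrads(dw, db, req, mshadow::kFloat32);
}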