diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc b/src/operator/nn/mkldnn/mkldnn_rnn.cc index 51fbc56271d0..c8f1d45814f5 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn.cc +++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc @@ -47,15 +47,6 @@ inline int GetRnnGatesNum(int mode) { } } -// Bug in oneDNN <= 1.6 in memory descriptor comparision operators. -// for specific dims and strides in descriptors == operator can return `true` -// but get_size() function will return different size -// TODO(bgawrych): Remove with oneDNN 1.7 upgrade -static inline bool CheckMemDescEquality(const mkldnn::memory::desc &left, - const mkldnn::memory::desc &right) { - return left == right && left.get_size() == right.get_size(); -} - void MKLDNNRnnLayerParam::SetDims() { const int ngates = GetRnnGatesNum(mode); //* NOTES: LBR-GRU's new gate formula needs two bias. So it has one more bias with LBR-GRU @@ -599,13 +590,13 @@ void MKLDNNRnnForwardTraining::SetTrnMem(const MKLDNNRnnForward& fwd) { weights_iter_ = mkldnn_shared_mem_t(new memory(fwd_trn_.GetIterDesc(), cpu_engine)); // fill weights memory using the reordered weights of fwd_inference primitive - if (CheckMemDescEquality(fwd.weights_layer_r_->get_desc(), fwd_trn_.GetLayerDesc())) { + if (fwd.weights_layer_r_->get_desc() == fwd_trn_.GetLayerDesc()) { weights_layer_->set_data_handle(fwd.weights_layer_r_->get_data_handle()); } else { MKLDNNMemoryReorder(*fwd.weights_layer_r_, *weights_layer_); } - if (CheckMemDescEquality(fwd.weights_iter_r_->get_desc(), fwd_trn_.GetIterDesc())) { + if (fwd.weights_iter_r_->get_desc() == fwd_trn_.GetIterDesc()) { weights_iter_->set_data_handle(fwd.weights_iter_r_->get_data_handle()); } else { MKLDNNMemoryReorder(*fwd.weights_iter_r_, *weights_iter_); @@ -729,7 +720,7 @@ void MKLDNNRnnBackward::FetchDataWeightsMem(const MKLDNNRnnForwardTraining& fwd) const mkldnn::memory* valid_mem; switch (kv.first) { case MKLDNN_ARG_WEIGHTS_LAYER: { - if (CheckMemDescEquality(bwd_.weights_layer_desc_, fwd.fwd_trn_.GetLayerDesc())) { + if (bwd_.weights_layer_desc_ == fwd.fwd_trn_.GetLayerDesc()) { this->weights_layer_->set_data_handle(kv.second.get_data_handle()); } else { MKLDNNMemoryReorder(*fwd.weights_layer_, *this->weights_layer_); @@ -737,7 +728,7 @@ void MKLDNNRnnBackward::FetchDataWeightsMem(const MKLDNNRnnForwardTraining& fwd) valid_mem = this->weights_layer_.get(); } break; case MKLDNN_ARG_WEIGHTS_ITER: { - if (CheckMemDescEquality(bwd_.weights_iter_desc_, fwd.fwd_trn_.GetIterDesc())) { + if (bwd_.weights_iter_desc_ == fwd.fwd_trn_.GetIterDesc()) { this->weights_iter_->set_data_handle(kv.second.get_data_handle()); } else { MKLDNNMemoryReorder(*fwd.weights_iter_, *this->weights_iter_); @@ -771,14 +762,14 @@ void MKLDNNRnnBackward::SetWeightsGradsMem() { this->diff_weights_iter_r_ = std::make_shared( native_iter_desc, cpu_engine); - if (CheckMemDescEquality(native_layer_desc, bwd_.diff_weights_layer_desc_)) { + if (native_layer_desc == bwd_.diff_weights_layer_desc_) { this->diff_weights_layer_ = std::make_shared( bwd_.diff_weights_layer_desc_, cpu_engine, diff_weights_layer_r_->get_data_handle()); } else { this->diff_weights_layer_ = std::make_shared( bwd_.diff_weights_layer_desc_, cpu_engine); } - if (CheckMemDescEquality(native_iter_desc, bwd_.diff_weights_iter_desc_)) { + if (native_iter_desc == bwd_.diff_weights_iter_desc_) { this->diff_weights_iter_ = std::make_shared( bwd_.diff_weights_iter_desc_, cpu_engine, diff_weights_iter_r_->get_data_handle()); } else { @@ -830,12 +821,10 @@ void MKLDNNRnnBackward::SetDataGradsMem( } void MKLDNNRnnBackward::SetNativeWeightsGrads() const { - if (!CheckMemDescEquality(this->diff_weights_layer_->get_desc(), - this->diff_weights_layer_r_->get_desc())) { + if (this->diff_weights_layer_->get_desc() != this->diff_weights_layer_r_->get_desc()) { MKLDNNMemoryReorder(*this->diff_weights_layer_, *this->diff_weights_layer_r_); } - if (!CheckMemDescEquality(this->diff_weights_iter_->get_desc(), - this->diff_weights_iter_r_->get_desc())) { + if (this->diff_weights_iter_->get_desc() != this->diff_weights_iter_r_->get_desc()) { MKLDNNMemoryReorder(*this->diff_weights_iter_, *this->diff_weights_iter_r_); } } @@ -854,11 +843,9 @@ void MKLDNNRnnBackward::CommitWeightsGrads(void* diff_weights, void* diff_bias, void* diff_weights_layer_ptr = this->diff_weights_layer_->get_data_handle(); void* diff_weights_iter_ptr = this->diff_weights_iter_->get_data_handle(); - if (!CheckMemDescEquality(this->diff_weights_layer_->get_desc(), - this->diff_weights_layer_r_->get_desc())) + if (this->diff_weights_layer_->get_desc() != this->diff_weights_layer_r_->get_desc()) diff_weights_layer_ptr = this->diff_weights_layer_r_->get_data_handle(); - if (!CheckMemDescEquality(this->diff_weights_iter_->get_desc(), - this->diff_weights_iter_r_->get_desc())) + if (this->diff_weights_iter_->get_desc() != this->diff_weights_iter_r_->get_desc()) diff_weights_iter_ptr = this->diff_weights_iter_r_->get_data_handle(); const int num_layer = param.num_layer;