Skip to content

Commit

Permalink
improve hashing.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Mar 29, 2018
1 parent 8f3194f commit 58854be
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 3 deletions.
3 changes: 3 additions & 0 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,9 @@ class NDArray {
bool IsDefaultData() const {
return ptr_->IsDefault();
}

const mkldnn::memory::desc &GetMKLDNNDesc() const;

/*
* All functions below return a raw pointer to mkldnn memory. Actually there
* is a shared pointer that hold the memory either in NDArray or in MKLDNN
Expand Down
5 changes: 5 additions & 0 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {

#if MXNET_USE_MKLDNN == 1

const mkldnn::memory::desc &NDArray::GetMKLDNNDesc() const {
CHECK(ptr_->mkl_mem_);
return ptr_->mkl_mem_->GetDesc();
}

NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const {
CHECK(!is_none()) << "NDArray is not initialized";
CHECK_GE(shape_.Size(), shape.Size())
Expand Down
4 changes: 4 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,10 @@ class MKLDNNMemory {
return size;
}

const mkldnn::memory::desc &GetDesc() const {
return desc;
}

mkldnn::memory::primitive_desc GetPrimitiveDesc() const {
return mem->get_primitive_desc();
}
Expand Down
5 changes: 2 additions & 3 deletions src/operator/operator_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,7 @@ class OpSignature {
*/

#if MXNET_USE_MKLDNN == 1
void AddSign(const mkldnn::memory &mem) {
auto desc = mem.get_primitive_desc().desc();
void AddSign(const mkldnn::memory::desc &desc) {
hash = hash * 2 + desc.data.format;
eles.push_back(desc.data.format);
hash = hash * 2 + desc.data.data_type;
Expand All @@ -541,7 +540,7 @@ class OpSignature {
void AddSign(const NDArray &arr) {
#if MXNET_USE_MKLDNN == 1
if (arr.IsMKLDNNData()) {
AddSign(*(arr.GetMKLDNNData()));
AddSign(arr.GetMKLDNNDesc());
} else {
#endif
hash = hash * 2 + arr.dtype();
Expand Down

0 comments on commit 58854be

Please sign in to comment.