Skip to content

Commit

Permalink
Fix bugs in NDArray.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Nov 7, 2017
1 parent cd53fb4 commit f5624a4
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,20 +255,19 @@ void NDArray::set_fresh_out_grad(bool state) const {
}

#if MXNET_USE_MKLDNN == 1
static inline bool same_shape(const TShape &shape, mkldnn_dims_t dims, int ndims) {
static inline bool same_shape(const TShape &shape, mkldnn::memory::primitive_desc pd) {
int ndims = pd.desc().data.ndims;
if (shape.ndim() != ndims)
return false;
for (int i = 0; i < ndims; i++)
if (shape[i] != dims[i])
if (shape[i] != pd.desc().data.dims[i])
return false;
return true;
}

void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
if (Mkl_mem_ && same_shape(shape, Mkl_mem_->get_primitive_desc().desc().data.dims,
Mkl_mem_->get_primitive_desc().desc().data.ndims)) {
if (Mkl_mem_ && same_shape(shape, Mkl_mem_->get_primitive_desc()))
return;
}

mkldnn::memory::dims dims(shape.ndim());
for (size_t i = 0; i < dims.size(); i++)
Expand Down Expand Up @@ -304,6 +303,10 @@ static int GetTypeSize(int dtype) {

std::shared_ptr<const mkldnn::memory> NDArray::GetMKLDNNData(
const mkldnn::memory::primitive_desc &desc) const {
// If the array size doesn't match, we should reset MKL memory.
if (ptr_->Mkl_mem_ && !same_shape(shape(), ptr_->Mkl_mem_->get_primitive_desc()))
ptr_->Mkl_mem_ = nullptr;

if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
return nullptr;
Expand All @@ -319,6 +322,10 @@ std::shared_ptr<const mkldnn::memory> NDArray::GetMKLDNNData(

std::shared_ptr<const mkldnn::memory> NDArray::GetMKLDNNDataReorder(
const mkldnn::memory::primitive_desc &desc) const {
// If the array size doesn't match, we should reset MKL memory.
if (ptr_->Mkl_mem_ && !same_shape(shape(), ptr_->Mkl_mem_->get_primitive_desc()))
ptr_->Mkl_mem_ = nullptr;

if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
return nullptr;
Expand Down Expand Up @@ -388,6 +395,7 @@ void NDArray::SetTBlob() const {
} else if (stype == kMKLDNNStorage) {
// TODO we may really need to convert format.
CHECK_EQ(byte_offset_, 0);
ptr_->SetMKLMem(shape_, dtype_);
dptr = (char *) ptr_->Mkl_mem_->get_data_handle();
#endif
} else {
Expand Down

0 comments on commit f5624a4

Please sign in to comment.