diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 4ba13ca6498a..5de42e19a657 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -694,9 +694,13 @@ class NDArray { /* * Create NDArray from mkldnn memory. * mkldnn_mem The mkldnn memory to be managed. - * static_data If true, mkldnn memory won't be freed on destruction. */ - explicit NDArray(const mkldnn::memory *mkldnn_mem, bool static_data = true); + explicit NDArray(const std::shared_ptr &mkldnn_mem); + /* + * Create NDArray from mkldnn memory descriptor. + * mem_pd The mkldnn memory descriptor to be created. + */ + explicit NDArray(mkldnn::memory::primitive_desc mem_pd); /* * Test if the data is stored in one of special MKLDNN format. */ @@ -776,7 +780,7 @@ class NDArray { /*! * \ Fix mkldnn memory descriptor mismatch from NDArray. */ - void UpdateMKLDNNMemDesc(); + void UpdateMKLDNNMemDesc(mkldnn::memory::format format); #endif /*! diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 251bfb3f0e1f..0f0fed24d4e6 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -168,16 +168,28 @@ nnvm::Symbol NDArray::get_autograd_symbol() const { #if MXNET_USE_MKLDNN == 1 -NDArray::NDArray(const mkldnn::memory *mkldnn_mem, bool static_data) +NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd) : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) { - auto mem_pd = mkldnn_mem->get_primitive_desc(); auto mem_desc = mem_pd.desc(); shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims); dtype_ = get_mxnet_type(mem_desc.data.data_type); - auto data = TBlob(mkldnn_mem->get_data_handle(), shape_, cpu::kDevMask, dtype_); - ptr_ = std::make_shared(data, 0); + ptr_ = std::make_shared(shape_, Context::CPU(), true, dtype_); + ptr_->CheckAndAlloc(mem_pd.get_size()); ptr_->mkl_mem_ = std::make_shared(mem_pd, ptr_->shandle.dptr); - ptr_->static_data = static_data; +} + +NDArray::NDArray(const std::shared_ptr &mkldnn_mem) + : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) { + auto mem_pd = mkldnn_mem->get_primitive_desc(); + auto mem_desc = mem_pd.desc(); + shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims); + dtype_ = get_mxnet_type(mem_desc.data.data_type); + ptr_ = std::make_shared(shape_, Context::CPU(), true, dtype_); + ptr_->shandle.dptr = mkldnn_mem->get_data_handle(); + ptr_->shandle.size = mem_pd.get_size(); + ptr_->delay_alloc = false; + ptr_->mkl_mem_ = std::make_shared(mkldnn_mem); + ptr_->static_data = true; } NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const { @@ -710,19 +722,16 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc & return ptr_->mkl_mem_->GetRaw(); } -void NDArray::UpdateMKLDNNMemDesc() { +void NDArray::UpdateMKLDNNMemDesc(mkldnn::memory::format format) { const mkldnn::memory *mem = GetMKLDNNData(); auto mem_desc = mem->get_primitive_desc().desc(); auto this_dtype = get_mkldnn_type(dtype()); - if (this_dtype != mem_desc.data.data_type) { - mkldnn::memory::desc data_md( - mkldnn::memory::dims(mem_desc.data.dims, - mem_desc.data.dims + mem_desc.data.ndims), - this_dtype, static_cast(mem_desc.data.format)); - mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine()); - ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr)); - MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); - } + mkldnn::memory::desc data_md( + mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims), + this_dtype, format); + mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine()); + ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr)); + MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); } #endif diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index 65e0e5c4b27a..6a8feaedec87 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -261,8 +261,12 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, } if (!inplace_) { auto in_mkl_mem = inputs[in_sum].GetMKLDNNData(); - const_cast(outputs[kOut]).CopyFrom(*in_mkl_mem); - output = NDArray(outputs[kOut].GetMKLDNNData()); + auto out_mkl_mem = outputs[kOut].GetMKLDNNData(); + mkldnn_mem_ptr tmp_mem( + new mkldnn::memory(in_mkl_mem->get_primitive_desc(), out_mkl_mem->get_data_handle())); + MKLDNNStream::Get()->RegisterMem(tmp_mem); + mxnet::MKLDNNCopy(*in_mkl_mem, tmp_mem.get()); + output = NDArray(tmp_mem); } } @@ -388,7 +392,9 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, if (mkldnn_param.with_sum) { auto out = const_cast(outputs[kOut]); - out.UpdateMKLDNNMemDesc(); + auto format = static_cast( + fwd_->fwd_pd.dst_primitive_desc().desc().data.format); + out.UpdateMKLDNNMemDesc(format); } }