
mkldnn s8 conv API change for master #13903

Merged · 1 commit · Jan 25, 2019
10 changes: 7 additions & 3 deletions include/mxnet/ndarray.h
@@ -694,9 +694,13 @@ class NDArray {
  /*
   * Create NDArray from mkldnn memory.
   * mkldnn_mem The mkldnn memory to be managed.
-  * static_data If true, mkldnn memory won't be freed on destruction.
   */
-  explicit NDArray(const mkldnn::memory *mkldnn_mem, bool static_data = true);
+  explicit NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem);
+  /*
+   * Create NDArray from mkldnn memory descriptor.
+   * mem_pd The mkldnn memory descriptor to be created.
+   */
+  explicit NDArray(mkldnn::memory::primitive_desc mem_pd);
  /*
   * Test if the data is stored in one of special MKLDNN format.
   */
@@ -776,7 +780,7 @@ class NDArray {
  /*!
   * \ Fix mkldnn memory descriptor mismatch from NDArray.
   */
-  void UpdateMKLDNNMemDesc();
+  void UpdateMKLDNNMemDesc(mkldnn::memory::format format);
#endif

39 changes: 24 additions & 15 deletions src/ndarray/ndarray.cc
@@ -168,16 +168,28 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {

#if MXNET_USE_MKLDNN == 1

-NDArray::NDArray(const mkldnn::memory *mkldnn_mem, bool static_data)
Contributor:
It seems the old implementation is incorrect. What if static_data=false?

Contributor Author:
If static_data=false, then the NDArray takes responsibility for freeing this memory on destruction. But that carries a double-free risk if mkldnn_mem is also freed outside. I agreed it's not a good design, so I changed it to std::shared_ptr<mkldnn::memory> to remove any chance of a double free.
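
A minimal standalone sketch of the ownership hazard discussed here, using stand-in types rather than the real MXNet/MKL-DNN classes:

```cpp
#include <cstdio>
#include <memory>

struct Memory {  // stand-in for mkldnn::memory
  ~Memory() { std::puts("Memory freed"); }
};

// Old shape of the API: a raw pointer plus a flag that decides who frees.
// If static_data is false and the caller also frees the buffer, it is
// destroyed twice: undefined behavior.
struct OldNDArray {
  Memory *mem_;
  bool static_data_;
  OldNDArray(Memory *mem, bool static_data)
      : mem_(mem), static_data_(static_data) {}
  ~OldNDArray() {
    if (!static_data_) delete mem_;
  }
};

// New shape of the API: shared ownership. The last owner frees, exactly once.
struct NewNDArray {
  std::shared_ptr<Memory> mem_;
  explicit NewNDArray(const std::shared_ptr<Memory> &mem) : mem_(mem) {}
};

int main() {
  auto mem = std::make_shared<Memory>();
  {
    NewNDArray arr(mem);  // use_count == 2
  }                       // arr destroyed, use_count back to 1
  return 0;               // "Memory freed" prints exactly once, here
}
```

With shared ownership the question of who frees never needs a flag; the reference count answers it.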

+NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd)
    : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
-  auto mem_pd = mkldnn_mem->get_primitive_desc();
  auto mem_desc = mem_pd.desc();
  shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);
  dtype_ = get_mxnet_type(mem_desc.data.data_type);
-  auto data = TBlob(mkldnn_mem->get_data_handle(), shape_, cpu::kDevMask, dtype_);
-  ptr_ = std::make_shared<Chunk>(data, 0);
+  ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
+  ptr_->CheckAndAlloc(mem_pd.get_size());
+  ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(mem_pd, ptr_->shandle.dptr);
-  ptr_->static_data = static_data;
}

+NDArray::NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem)
+    : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
+  auto mem_pd = mkldnn_mem->get_primitive_desc();
+  auto mem_desc = mem_pd.desc();
+  shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);
+  dtype_ = get_mxnet_type(mem_desc.data.data_type);
+  ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
+  ptr_->shandle.dptr = mkldnn_mem->get_data_handle();
+  ptr_->shandle.size = mem_pd.get_size();
+  ptr_->delay_alloc = false;
+  ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(mkldnn_mem);
+  ptr_->static_data = true;
+}

NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const {
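
For concreteness, a hypothetical caller of the two constructors implemented above (the shape, data type, and format are illustrative; mkldnn 0.x API assumed):

```cpp
#include <memory>
#include <mkldnn.hpp>
#include <mxnet/ndarray.h>

void ConstructExamples() {
  mkldnn::engine engine(mkldnn::engine::cpu, 0);

  // Path 1: from a primitive descriptor. Per the hunk above, the NDArray
  // allocates its own buffer of mem_pd.get_size() bytes (CheckAndAlloc)
  // and wraps it in an MKLDNNMemory.
  mkldnn::memory::desc md({1, 32, 28, 28}, mkldnn::memory::data_type::s8,
                          mkldnn::memory::format::nhwc);
  mkldnn::memory::primitive_desc pd(md, engine);
  mxnet::NDArray owned(pd);

  // Path 2: from an existing mkldnn memory. Per the hunk above, the NDArray
  // points its storage handle at the existing buffer, shares ownership
  // through the shared_ptr, and marks the chunk static_data.
  auto mem = std::make_shared<mkldnn::memory>(pd);
  mxnet::NDArray wrapped(mem);
}
```

Path 2 is the one the conv operator below uses to hand a temporary mkldnn buffer back out as an NDArray.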
@@ -710,19 +722,16 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &
  return ptr_->mkl_mem_->GetRaw();
}

-void NDArray::UpdateMKLDNNMemDesc() {
+void NDArray::UpdateMKLDNNMemDesc(mkldnn::memory::format format) {
  const mkldnn::memory *mem = GetMKLDNNData();
  auto mem_desc = mem->get_primitive_desc().desc();
  auto this_dtype = get_mkldnn_type(dtype());
-  if (this_dtype != mem_desc.data.data_type) {
-    mkldnn::memory::desc data_md(
-        mkldnn::memory::dims(mem_desc.data.dims,
-                             mem_desc.data.dims + mem_desc.data.ndims),
-        this_dtype, static_cast<mkldnn::memory::format>(mem_desc.data.format));
-    mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine());
-    ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr));
-    MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
-  }
+  mkldnn::memory::desc data_md(
+      mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims),
+      this_dtype, format);
+  mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine());
+  ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr));
+  MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
}
#endif
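
The semantic change is worth spelling out: the old body rebuilt the descriptor only when the NDArray's dtype disagreed with the stored memory, and it kept whatever format the memory already had; the new body always rebuilds it, with a format the caller supplies. A sketch of the intended call pattern (the helper name is invented for illustration):

```cpp
#include <mkldnn.hpp>
#include <mxnet/ndarray.h>

// Hypothetical helper: stamp a convolution primitive's destination format
// onto the NDArray that holds that convolution's output buffer.
void AdoptConvDstFormat(
    const mkldnn::convolution_forward::primitive_desc &conv_pd,
    mxnet::NDArray *out) {
  auto dst_format = static_cast<mkldnn::memory::format>(
      conv_pd.dst_primitive_desc().desc().data.format);
  out->UpdateMKLDNNMemDesc(dst_format);
}
```

This is exactly the pattern the fused-sum path in mkldnn_conv.cc adopts below.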

12 changes: 9 additions & 3 deletions src/operator/subgraph/mkldnn/mkldnn_conv.cc
@@ -261,8 +261,12 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
}
if (!inplace_) {
  auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
-  const_cast<NDArray &>(outputs[kOut]).CopyFrom(*in_mkl_mem);
-  output = NDArray(outputs[kOut].GetMKLDNNData());
+  auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
+  mkldnn_mem_ptr tmp_mem(
+      new mkldnn::memory(in_mkl_mem->get_primitive_desc(), out_mkl_mem->get_data_handle()));
+  MKLDNNStream::Get()->RegisterMem(tmp_mem);
+  mxnet::MKLDNNCopy(*in_mkl_mem, tmp_mem.get());
+  output = NDArray(tmp_mem);
}
}
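
The replacement avoids NDArray::CopyFrom, which would reinterpret the data through the output's own (possibly mismatched) descriptor; instead it copies into a temporary mkldnn::memory that aliases the output buffer while carrying the input's primitive_desc. The wrap step in isolation, as a sketch (mkldnn 0.x API; the function name is invented here):

```cpp
#include <memory>
#include <mkldnn.hpp>

// Wrap an existing buffer in a mkldnn::memory that describes it with `pd`
// but does not own it: in mkldnn 0.x, memory(pd, handle) adopts the
// pointer without copying or freeing it.
std::shared_ptr<mkldnn::memory> WrapExistingBuffer(
    const mkldnn::memory::primitive_desc &pd, void *handle) {
  return std::make_shared<mkldnn::memory>(pd, handle);
}
```

Registering the temporary on the MKLDNNStream, as the new code does, keeps it alive until the stream's queued primitives have executed.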

@@ -388,7 +392,9 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,

if (mkldnn_param.with_sum) {
  auto out = const_cast<NDArray &>(outputs[kOut]);
-  out.UpdateMKLDNNMemDesc();
+  auto format = static_cast<mkldnn::memory::format>(
+      fwd_->fwd_pd.dst_primitive_desc().desc().data.format);
+  out.UpdateMKLDNNMemDesc(format);
}
}
