diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index 02dd204fc77d..897d930915a5 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -551,7 +551,7 @@ inline std::ostream& operator<<(std::ostream &out, const Context &ctx) {
 #define ADD_FILELINE "\n\nDefined in " __FILE__ ":L" STRINGIZE(__LINE__)
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 constexpr size_t kMKLDNNAlign = 64;
 #endif
 
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 176aa0aaa197..fc4375b493a7 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -37,7 +37,7 @@
 #include
 #include
 #include
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <mkldnn.hpp>
 #endif
 #include "./base.h"
@@ -699,7 +699,7 @@ class NDArray {
     ptr_->CheckAndAllocAuxData(i, aux_shape);
   }
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   /*
    * Create NDArray from mkldnn memory.
    * mkldnn_mem The mkldnn memory to be managed.
   */
@@ -709,7 +709,7 @@ class NDArray {
   * Create NDArray from mkldnn memory descriptor.
   * mem_pd The mkldnn memory descriptor to be created.
   */
-  explicit NDArray(mkldnn::memory::primitive_desc mem_pd);
+  explicit NDArray(const mkldnn::memory::desc &md);
  /*
   * Test if the data is stored in one of special MKLDNN format.
   */
@@ -737,15 +737,14 @@ class NDArray {
   * This function returns mkldnn::memory with the given primitive_desc
   * as long as the array size meets the required size in the given primitive_desc.
   */
-  const mkldnn::memory *GetMKLDNNData(
-      const mkldnn::memory::primitive_desc &desc) const;
+  const mkldnn::memory *GetMKLDNNData(const mkldnn::memory::desc &md) const;
  /*
   * This function returns mkldnn::memory with the given primitive_desc.
   * The returned mkldnn::memory will have the same physical layout as
   * the given primitive_desc.
   */
  const mkldnn::memory *GetMKLDNNDataReorder(
-      const mkldnn::memory::primitive_desc &desc) const;
+      const mkldnn::memory::desc &md) const;
  /*
   * This function copies data from mkldnn memory.
   */
@@ -755,8 +754,7 @@ class NDArray {
   * This function allocates memory for array and creates mkldnn memory
   * with the specified format.
   */
-  mkldnn::memory *CreateMKLDNNData(
-      const mkldnn::memory::primitive_desc &desc);
+  mkldnn::memory *CreateMKLDNNData(const mkldnn::memory::desc &md);
  /*
   * These are the async version of the methods above.
   * It changes the layout of this NDArray. It happens after all accesses to
   * the array are complete.
   */
  void Reorder2DefaultAsync();
-  void MKLDNNDataReorderAsync(const mkldnn::memory::primitive_desc &desc);
+  void MKLDNNDataReorderAsync(const mkldnn::memory::desc &md);
  /*
   * This creates a new NDArray with the reordered data.
   */
@@ -789,7 +787,7 @@ class NDArray {
  /*!
   * \ Fix mkldnn memory descriptor mismatch from NDArray.
   */
-  void UpdateMKLDNNMemDesc(mkldnn::memory::format format);
+  void UpdateMKLDNNMemDesc(mkldnn::memory::format_tag format);
 #endif
 
  /*!
@@ -827,7 +825,7 @@ class NDArray {
    */
  std::vector<Storage::Handle> aux_handles;
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
  /*! This is created when data is stored in MKLDNN format. */
  std::shared_ptr<MKLDNNMemory> mkl_mem_;
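The signature changes above track the MKLDNN 1.0 API, where `mkldnn::memory::primitive_desc` is gone: a plain `mkldnn::memory::desc` plus an engine is all that is needed to construct a memory object. A minimal sketch of the 1.0 idiom, with illustrative shapes that are not part of the patch:

```cpp
#include <mkldnn.hpp>

int main() {
  // In MKLDNN 1.0 a memory::desc fully describes shape, type and layout;
  // the engine is passed separately when the memory object is built.
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::memory::desc md({8, 3, 224, 224},
                          mkldnn::memory::data_type::f32,
                          mkldnn::memory::format_tag::nchw);
  mkldnn::memory mem(md, eng);  // buffer allocated by the library
  return md.get_size() == mem.get_desc().get_size() ? 0 : 1;
}
```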
@@ -986,7 +984,7 @@ class NDArray {
   inline void CheckAndAlloc(void) {
     if (delay_alloc) {
       shandle = Storage::Get()->Alloc(shandle.size, shandle.ctx);
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
       mkl_mem_ = nullptr;
 #endif
       delay_alloc = false;
@@ -1001,7 +999,7 @@ class NDArray {
     dbytes = std::max(dbytes, static_cast<uint64_t>(shandle.size));
     if (delay_alloc) {
       shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
       mkl_mem_ = nullptr;
 #endif
       delay_alloc = false;
@@ -1010,7 +1008,7 @@ class NDArray {
       Storage::Get()->Free(shandle);
       // init storage
       shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
       mkl_mem_ = nullptr;
 #endif
     }
@@ -1046,7 +1044,7 @@ class NDArray {
     // and allocate new storage
     void CheckAndAllocData(const mxnet::TShape &shape, int dtype);
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
     // Have MKL memory reference to the data in the default storage
     // or create memory for MKLDNN.
     void SetMKLMem(const mxnet::TShape &shape, int dtype);
@@ -1054,7 +1052,7 @@ class NDArray {
     // save the result in shandle.
     void Reorder2Default();
     // Reorder data to a specified layout.
-    void MKLDNNDataReorder(const mkldnn::memory::primitive_desc &desc);
+    void MKLDNNDataReorder(const mkldnn::memory::desc &md);
     bool IsMKLDNN() const;
     bool IsDefault() const;
 #endif
diff --git a/mkldnn.mk b/mkldnn.mk
index 64c49598f9dd..2d09ebc7bff5 100644
--- a/mkldnn.mk
+++ b/mkldnn.mk
@@ -35,12 +35,14 @@ mkldnn_build: $(MKLDNN_LIBFILE)
 
 $(MKLDNN_LIBFILE):
 	mkdir -p $(MKLDNNROOT)
+	mkdir -p $(MKLDNNROOT)/lib
 	cmake $(MKLDNN_SUBMODDIR) -DCMAKE_INSTALL_PREFIX=$(MKLDNNROOT) -B$(MKLDNN_BUILDDIR) -DMKLDNN_ARCH_OPT_FLAGS="" -DMKLDNN_BUILD_TESTS=OFF -DMKLDNN_BUILD_EXAMPLES=OFF -DMKLDNN_ENABLE_JIT_PROFILING=OFF
 	$(MAKE) -C $(MKLDNN_BUILDDIR) VERBOSE=1
 	$(MAKE) -C $(MKLDNN_BUILDDIR) install
 	mkdir -p $(MXNET_LIBDIR)
 	if [ -f "$(MKLDNN_LIB64FILE)" ]; then \
 		cp $(MKLDNNROOT)/lib64/libmkldnn* $(MXNET_LIBDIR); \
+		cp $(MKLDNNROOT)/lib64/libmkldnn* $(MKLDNNROOT)/lib/; \
 	else \
 		cp $(MKLDNNROOT)/lib/libmkldnn* $(MXNET_LIBDIR); \
 	fi
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index cc21dd242a2d..97daa29492f6 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -31,9 +31,6 @@
 #include
 #include
 #include
-#if MXNET_USE_MKLDNN == 1
-#include <mkldnn.hpp>
-#endif
 #include "./ndarray_function.h"
 #include "../common/utils.h"
 #include "../operator/tensor/matrix_op-inl.h"
@@ -106,7 +103,7 @@ void NDArray::SetShapeFromChunk() {
 struct ChunkMem {
   Storage::Handle h;
   std::vector<Storage::Handle> aux_h;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   std::shared_ptr<MKLDNNMemory> mem;
 #endif
 };
@@ -116,14 +113,14 @@ NDArray::Chunk::~Chunk() {
   ChunkMem mem;
   mem.h = this->shandle;
   mem.aux_h = this->aux_handles;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   // We want to delete mkldnn memory after deleting the variable.
   mem.mem = this->mkl_mem_;
 #endif
   if (auto engine = engine_ref_.lock()) {
     engine->DeleteVariable([mem, skip_free](RunContext s) {
       if (skip_free == false) {
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
         if (mem.mem) {
           CHECK_LE(mem.mem->GetSize(), mem.h.size);
           CHECK_EQ(mem.mem->GetDataHandle(), mem.h.dptr);
@@ -147,7 +144,7 @@ void NDArray::Chunk::CheckAndAllocData(const mxnet::TShape &shape, int dtype) {
     Storage::Get()->Free(shandle);
     // init storage
     shandle = Storage::Get()->Alloc(dbytes, ctx);
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
     mkl_mem_ = nullptr;
 #endif
   }
@@ -175,27 +172,25 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {
   return ret;
 }
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
-NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd)
+NDArray::NDArray(const mkldnn::memory::desc &md)
     : storage_type_(kDefaultStorage), entry_(nullptr) {
-  auto mem_desc = mem_pd.desc();
-  shape_ = mxnet::TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);
-  dtype_ = get_mxnet_type(mem_desc.data.data_type);
+  shape_ = mxnet::TShape(md.data.dims, md.data.dims + md.data.ndims);
+  dtype_ = get_mxnet_type(md.data.data_type);
   ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
-  ptr_->CheckAndAlloc(mem_pd.get_size());
-  ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(mem_pd, ptr_->shandle.dptr);
+  ptr_->CheckAndAlloc(md.get_size());
+  ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(md, ptr_->shandle.dptr);
 }
 
 NDArray::NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem)
     : storage_type_(kDefaultStorage), entry_(nullptr) {
-  auto mem_pd = mkldnn_mem->get_primitive_desc();
-  auto mem_desc = mem_pd.desc();
+  auto mem_desc = mkldnn_mem->get_desc();
   shape_ = mxnet::TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);
   dtype_ = get_mxnet_type(mem_desc.data.data_type);
   ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
   ptr_->shandle.dptr = mkldnn_mem->get_data_handle();
-  ptr_->shandle.size = mem_pd.get_size();
+  ptr_->shandle.size = mem_desc.get_size();
   ptr_->delay_alloc = false;
   ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(mkldnn_mem);
   ptr_->static_data = true;
@@ -214,22 +209,24 @@ NDArray NDArray::MKLDNNDataReshape(const mxnet::TShape &shape) const {
     NDArray ret(shape, ctx(), true, dtype());
     // We shouldn't submit the reorder primitive here because submit will
     // be called in operators.
-    mkldnn_memory_format_t format = ptr_->mkl_mem_->GetDefaultFormat();
-    CHECK_NE(format, ptr_->mkl_mem_->GetFormat());
-    mkldnn::memory::primitive_desc def_pd = ptr_->mkl_mem_->GetPrimitiveDesc(format);
-    mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd);
+    mkldnn_format_tag_t format = ptr_->mkl_mem_->GetDefaultFormat();
+    CHECK(ptr_->IsMKLDNN());
+    mkldnn::memory::desc def_desc = ptr_->mkl_mem_->GetDesc(format);
+    mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_desc);
     MKLDNNStream *stream = MKLDNNStream::Get();
     std::shared_ptr<mkldnn::memory> curr_mem = ptr_->mkl_mem_->GetMem();
     stream->RegisterMem(curr_mem);
-    stream->RegisterPrim(mkldnn::reorder(*curr_mem, *def_mem));
+    std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, *curr_mem},
+                                                  {MKLDNN_ARG_TO, *def_mem}});
+    stream->RegisterPrimArgs(mkldnn::reorder(*curr_mem, *def_mem), args);
     // def_mem points to a memory region in the temp space. It's only valid
     // inside an operator. As such, the returned NDArray can only be valid
     // inside an operator and the shared pointer doesn't need to do anything
     // when it's destroyed.
-    auto tmp = std::shared_ptr<mkldnn::memory>(def_mem, [](mkldnn::memory *mem){});
+    auto tmp = std::shared_ptr<mkldnn::memory>(def_mem, [](mkldnn::memory *mem) {});
     ret.ptr_->mkl_mem_.reset(new MKLDNNMemory(tmp));
     ret.ptr_->shandle.dptr = def_mem->get_data_handle();
-    ret.ptr_->shandle.size = def_mem->get_primitive_desc().get_size();
+    ret.ptr_->shandle.size = def_mem->get_desc().get_size();
     ret.ptr_->delay_alloc = false;
     ret.ptr_->static_data = true;
     ret.byte_offset_ = byte_offset_;
@@ -237,7 +234,6 @@ NDArray NDArray::MKLDNNDataReshape(const mxnet::TShape &shape) const {
     return ret;
   }
 }
-
 #endif
 
 NDArray NDArray::Reshape(const mxnet::TShape &shape) const {
@@ -391,7 +387,7 @@ void NDArray::set_fresh_out_grad(bool state) const {
   info.fresh_out_grad = state;
 }
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 bool NDArray::Chunk::IsMKLDNN() const {
   if (storage_type != kDefaultStorage)
     return false;
@@ -415,57 +411,56 @@ void NDArray::Chunk::Reorder2Default() {
   if (mkl_mem_ == nullptr)
     return;
 
-  mkldnn_memory_format_t format = mkl_mem_->GetDefaultFormat();
-  if (format == mkl_mem_->GetFormat())
+  if (IsDefault())
     return;
 
-  mkldnn::memory::primitive_desc def_pd = mkl_mem_->GetPrimitiveDesc(format);
-  mkldnn_mem_ptr def_mem(new mkldnn::memory(def_pd));
+  mkldnn_format_tag_t format = mkl_mem_->GetDefaultFormat();
+  mkldnn::memory::desc def_desc = mkl_mem_->GetDesc(format);
+  mkldnn_mem_ptr def_mem(new mkldnn::memory(def_desc, CpuEngine::Get()->get_engine()));
   mkl_mem_->ReorderTo(def_mem.get());
-  CHECK(shandle.size >= def_pd.get_size());
-  CheckAndAlloc(def_pd.get_size());
+  CHECK(shandle.size >= def_desc.get_size());
+  CheckAndAlloc(def_desc.get_size());
   // TODO(zhengda) We need to avoid memory copy here.
-  memcpy(shandle.dptr, def_mem->get_data_handle(), def_pd.get_size());
+  memcpy(shandle.dptr, def_mem->get_data_handle(), def_desc.get_size());
   mkl_mem_ = nullptr;
 }
 
-void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::primitive_desc &pd) {
+void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::desc &md) {
   // If the memory already uses the specified layout, don't do anything.
-  if (mkl_mem_ != nullptr && mkl_mem_->SameFormat(pd))
+  if (mkl_mem_ != nullptr && mkl_mem_->SameFormat(md))
     return;
 
-  mkldnn::memory::primitive_desc _pd = pd;
-  mkldnn::memory::desc _desc = _pd.desc();
-  mkldnn_memory_format_t def_format = GetDefaultFormat(_desc);
+
   // If the memory is default, don't do anything.
-  if (def_format == _desc.data.format && IsDefault())
+  if (!mxnet::IsMKLDNN(md) && IsDefault())
     return;
 
-  // If the specified layout is default, we should use Reorder2Default.
-  if (def_format == _desc.data.format) {
+  if (!mxnet::IsMKLDNN(md)) {
+    // If the specified layout is default, we should use Reorder2Default.
     Reorder2Default();
     return;
   }
 
+  auto engine = CpuEngine::Get()->get_engine();
+  mkldnn::stream s(engine);
 
-  std::shared_ptr<mkldnn::memory> new_mem(new mkldnn::memory(pd));
+  std::shared_ptr<mkldnn::memory> new_mem(new mkldnn::memory(md, engine));
   std::shared_ptr<mkldnn::memory> old_mem;
   if (IsDefault()) {
-    mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(pd, def_format);
-    old_mem.reset(new mkldnn::memory(def_pd, shandle.dptr));
+    mkldnn_format_tag_t def_format = GetDefaultFormat(md);
+    mkldnn::memory::desc def_desc = GetDesc(md, def_format);
+    old_mem.reset(new mkldnn::memory(def_desc, engine, shandle.dptr));
   } else {
     old_mem = this->mkl_mem_->GetMem();
   }
-  CHECK(old_mem->get_primitive_desc().desc().data.ndims == _desc.data.ndims);
+  CHECK(old_mem->get_desc().data.ndims == md.data.ndims);
 
   // This may be called in MKLDNN operators. We can't use MKLDNNStream here.
-  std::vector<mkldnn::primitive> net;
-  net.push_back(mkldnn::reorder(*old_mem, *new_mem));
-  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
+  mkldnn::reorder(*old_mem, *new_mem).execute(s, *old_mem, *new_mem);
 
-  CHECK(shandle.size >= pd.get_size());
-  CheckAndAlloc(pd.get_size());
+  CHECK(shandle.size >= md.get_size());
+  CheckAndAlloc(md.get_size());
   // TODO(zhengda) We need to avoid memory copy here.
-  memcpy(shandle.dptr, new_mem->get_data_handle(), pd.get_size());
-  mkl_mem_.reset(new MKLDNNMemory(pd, shandle.dptr));
+  memcpy(shandle.dptr, new_mem->get_data_handle(), md.get_size());
+  mkl_mem_.reset(new MKLDNNMemory(md, shandle.dptr));
 }
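In MKLDNN 1.0 a reorder runs eagerly through primitive::execute on a stream, which replaces the 0.x pattern of collecting primitives in a net and submitting it, as the hunk above does. A self-contained sketch of that idiom, with illustrative tensor sizes:

```cpp
#include <mkldnn.hpp>

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream s(eng);
  mkldnn::memory::dims dims = {2, 3, 4, 5};
  // Same logical tensor in two different physical layouts.
  mkldnn::memory::desc src_md(dims, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::nchw);
  mkldnn::memory::desc dst_md(dims, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::nhwc);
  mkldnn::memory src(src_md, eng);
  mkldnn::memory dst(dst_md, eng);
  // 1.0 style: construct the reorder and execute it immediately.
  mkldnn::reorder(src, dst).execute(s, src, dst);
  s.wait();  // block until the primitive finishes
  return 0;
}
```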
 
 void NDArray::Chunk::SetMKLMem(const mxnet::TShape &shape, int dtype) {
@@ -486,43 +481,35 @@ void NDArray::Chunk::SetMKLMem(const mxnet::TShape &shape, int dtype) {
   } else {
     LOG(FATAL) << "MKLDNN doesn't support " << shape.ndim() << " dimensions";
   }
-  mkldnn::memory::format layout = mkldnn::memory::format::format_undef;
+  mkldnn::memory::format_tag layout = mkldnn::memory::format_tag::undef;
   switch (dims.size()) {
-    case 1: layout = mkldnn::memory::format::x; break;
-    case 2: layout = mkldnn::memory::format::nc; break;
-    case 3: layout = mkldnn::memory::format::ncw; break;
-    case 4: layout = mkldnn::memory::format::nchw; break;
-    // This isn't the right layout when the data has 5 dimensions in MXNet.
-    // MXNet interprets 5 dimensions as ncdhw, but MKLDNN doesn't have
-    // a corresponding format.
-    case 5: layout = mkldnn::memory::format::goihw; break;
+    case 1: layout = mkldnn::memory::format_tag::a; break;
+    case 2: layout = mkldnn::memory::format_tag::ab; break;
+    case 3: layout = mkldnn::memory::format_tag::abc; break;
+    case 4: layout = mkldnn::memory::format_tag::abcd; break;
+    case 5: layout = mkldnn::memory::format_tag::abcde; break;
+    default:
+      LOG(FATAL) << "Not implemented dimension (" << dims.size() << ") for MKLDNN";
   }
   mkldnn::memory::desc data_md{dims, get_mkldnn_type(dtype), layout};
-  auto cpu_engine = CpuEngine::Get()->get_engine();
   if (shandle.dptr == nullptr) {
     CHECK(delay_alloc);
     CheckAndAlloc();
   }
-  mkldnn::memory::primitive_desc pd(data_md, cpu_engine);
-  CHECK(shandle.size >= pd.get_size());
-  mkl_mem_.reset(new MKLDNNMemory(pd, shandle.dptr));
+  CHECK(shandle.size >= data_md.get_size());
+  mkl_mem_.reset(new MKLDNNMemory(data_md, shandle.dptr));
 }
 
-const mkldnn::memory *NDArray::GetMKLDNNData(
-    const mkldnn::memory::primitive_desc &desc) const {
+const mkldnn::memory *NDArray::GetMKLDNNData(const mkldnn::memory::desc &desc) const {
   if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
     LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
     return nullptr;
   }
   const mkldnn::memory *mem = GetMKLDNNData();
-  mkldnn::memory::primitive_desc _desc = desc;
-  mkldnn::memory::desc desc1 = mem->get_primitive_desc().desc();
-  mkldnn::memory::desc desc2 = _desc.desc();
+  mkldnn::memory::desc desc1 = mem->get_desc();
   // The MKL memory has the same format and shape as required,
   // or both use the default format, we can return the MKL memory.
-  if (mem->get_primitive_desc() == desc
-      || (desc1.data.format == GetDefaultFormat(desc1)
-          && desc2.data.format == GetDefaultFormat(desc2))) {
+  if (desc1 == desc || ((!mxnet::IsMKLDNN(desc1)) && (!mxnet::IsMKLDNN(desc)))) {
     return GetMKLDNNExact(mem, desc);
   } else {
     return nullptr;
   }
 }
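The default layouts above are now expressed with the rank-generic tags a through abcde. For the common ranks these coincide with the named tags, e.g. abcd and nchw describe the same plain row-major layout, so the resulting descriptors compare equal. A quick sanity sketch, illustrative only:

```cpp
#include <mkldnn.hpp>
#include <iostream>

int main() {
  mkldnn::memory::dims dims = {2, 3, 4, 5};
  // format_tag::abcd and format_tag::nchw name the same plain layout,
  // so the two descriptors compare equal.
  mkldnn::memory::desc md_abcd(dims, mkldnn::memory::data_type::f32,
                               mkldnn::memory::format_tag::abcd);
  mkldnn::memory::desc md_nchw(dims, mkldnn::memory::data_type::f32,
                               mkldnn::memory::format_tag::nchw);
  std::cout << (md_abcd == md_nchw) << std::endl;  // prints 1
  return 0;
}
```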
 
@@ -530,8 +517,8 @@
 const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
-    const mkldnn::memory::primitive_desc &new_pd) const {
-  if (new_pd.get_size() != shape().Size() * GetTypeSize(dtype_)) {
+    const mkldnn::memory::desc &new_desc) const {
+  if (new_desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
     LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
     return nullptr;
   }
@@ -540,24 +527,24 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
   const mkldnn::memory *mem = GetMKLDNNData();
   // If the memory descriptor matches, it's easy.
   MKLDNNStream *stream = MKLDNNStream::Get();
-  if (mem->get_primitive_desc() == new_pd) {
-    return GetMKLDNNExact(mem, new_pd);
+  if (mem->get_desc() == new_desc) {
+    return GetMKLDNNExact(mem, new_desc);
   }
 
-  mkldnn::memory::primitive_desc _pd = new_pd;
-  mkldnn::memory::desc desc1 = mem->get_primitive_desc().desc();
-  mkldnn::memory::desc desc2 = _pd.desc();
+  mkldnn::memory::desc desc1 = mem->get_desc();
+  mkldnn::memory::desc desc2 = new_desc;
   // Now we need to determine if we should reorder the memory.
   // If both use the default formats, we think we don't need to reorder.
-  if (desc1.data.format == GetDefaultFormat(desc1) &&
-      desc2.data.format == GetDefaultFormat(desc2)) {
-    mkldnn_mem_ptr ret(new mkldnn::memory(new_pd, mem->get_data_handle()));
+  if ((!mxnet::IsMKLDNN(desc1)) && (!mxnet::IsMKLDNN(desc2))) {
+    mkldnn_mem_ptr ret(new mkldnn::memory(new_desc,
+        CpuEngine::Get()->get_engine(), mem->get_data_handle()));
     stream->RegisterMem(ret);
     return ret.get();
   } else if (same_shape(desc1, desc2)) {
     // If they have the same shape, we can reorder data directly.
-    mkldnn::memory *ret = TmpMemMgr::Get()->Alloc(new_pd);
-    stream->RegisterPrim(mkldnn::reorder(*mem, *ret));
+    mkldnn::memory *ret = TmpMemMgr::Get()->Alloc(new_desc);
+    std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, *mem },
+                                                  {MKLDNN_ARG_TO, *ret}});
+    stream->RegisterPrimArgs(mkldnn::reorder(*mem, *ret), args);
     return ret;
   } else {
     // If they have different shapes, we need to reshape the array first.
@@ -568,11 +555,13 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
       required_shape[i] = desc2.data.dims[i];
     NDArray reshaped = MKLDNNDataReshape(required_shape);
     const mkldnn::memory *ret = reshaped.GetMKLDNNData();
-    if (ret->get_primitive_desc() == new_pd) {
-      return GetMKLDNNExact(ret, new_pd);
+    if (ret->get_desc() == new_desc) {
+      return GetMKLDNNExact(ret, new_desc);
     } else {
-      mkldnn::memory *ret2 = TmpMemMgr::Get()->Alloc(new_pd);
-      stream->RegisterPrim(mkldnn::reorder(*ret, *ret2));
+      mkldnn::memory *ret2 = TmpMemMgr::Get()->Alloc(new_desc);
+      std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, *ret},
+                                                    {MKLDNN_ARG_TO, *ret2}});
+      stream->RegisterPrimArgs(mkldnn::reorder(*ret, *ret2), args);
       return ret2;
     }
   }
@@ -583,18 +572,18 @@ NDArray NDArray::Reorder2Default() const {
   if (ptr_->mkl_mem_ == nullptr)
     return *this;
 
-  mkldnn_memory_format_t format = ptr_->mkl_mem_->GetDefaultFormat();
-  if (format == ptr_->mkl_mem_->GetFormat())
+  if (!ptr_->mkl_mem_->IsMKLDNN())
     return *this;
 
   // create new ndarray from mkldnn layout
-  mkldnn::memory::desc from_desc = ptr_->mkl_mem_->GetPrimitiveDesc().desc();
+  mkldnn::memory::desc from_desc = ptr_->mkl_mem_->GetDesc();
   mxnet::TShape tshape(from_desc.data.ndims, -1);
   for (int i = 0; i < from_desc.data.ndims; i++) tshape[i] = from_desc.data.dims[i];
   NDArray ret(tshape, ctx(), false, dtype());
-  mkldnn::memory::primitive_desc def_pd = ptr_->mkl_mem_->GetPrimitiveDesc(format);
-  CHECK(ret.ptr_->shandle.size >= def_pd.get_size());
-  mkldnn::memory def_mem(def_pd, ret.ptr_->shandle.dptr);
+  mkldnn_format_tag_t format = ptr_->mkl_mem_->GetDefaultFormat();
+  mkldnn::memory::desc def_desc = ptr_->mkl_mem_->GetDesc(format);
+  CHECK(ret.ptr_->shandle.size >= def_desc.get_size());
+  mkldnn::memory def_mem(def_desc, CpuEngine::Get()->get_engine(), ret.ptr_->shandle.dptr);
   ptr_->mkl_mem_->ReorderTo(&def_mem);
   // reshape as needed
   ret.shape_ = shape_;
@@ -615,21 +604,16 @@ void NDArray::Reorder2DefaultAsync() {
       FnProperty::kNormal, 0, "Reorder2Default");
 }
 
-void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::primitive_desc &desc) {
+void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::desc &desc) {
   std::vector<Engine::VarHandle> const_vars;
   std::vector<Engine::VarHandle> mutable_vars(1, this->var());
   NDArray tmp = *this;
-  const auto version = this->version();
   Engine::Get()->PushAsync(
-      [tmp, version, desc](RunContext ctx, Engine::CallbackOnComplete on_complete) {
-        // MXNet will try to reuse NDArray from memory planning, so we need to ensure
-        // the NDArray is still holding the original trunk data.
-        if (tmp.version() == version) {
-          tmp.ptr_->MKLDNNDataReorder(desc);
-        }
-        on_complete();
-      },
-      ctx(), const_vars, mutable_vars, FnProperty::kNormal, 0, "Reorder");
+      [tmp, desc](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+        tmp.ptr_->MKLDNNDataReorder(desc);
+        on_complete();
+      }, ctx(), const_vars, mutable_vars,
+      FnProperty::kNormal, 0, "Reorder");
 }
 
 const mkldnn::memory *NDArray::GetMKLDNNData() const {
@@ -652,14 +636,12 @@ const mkldnn::memory *NDArray::GetMKLDNNData() const {
     mkldnn::memory::dims dims(shape().ndim());
     for (size_t i = 0; i < dims.size(); i++) dims[i] = shape()[i];
-    mkldnn::memory::format cpp_format = static_cast<mkldnn::memory::format>(
+    mkldnn::memory::format_tag cpp_format = static_cast<mkldnn::memory::format_tag>(
         GetDefaultFormat(shape().ndim()));
     mkldnn::memory::data_type cpp_type = get_mkldnn_type(dtype_);
     mkldnn::memory::desc data_md(dims, cpp_type, cpp_format);
-    mkldnn::memory::primitive_desc new_pd(data_md,
-                                          CpuEngine::Get()->get_engine());
-
-    std::shared_ptr<mkldnn::memory> ret(new mkldnn::memory(new_pd, off_addr));
+    std::shared_ptr<mkldnn::memory> ret(
+        new mkldnn::memory(data_md, CpuEngine::Get()->get_engine(), off_addr));
     MKLDNNStream::Get()->RegisterMem(ret);
     return ret.get();
   } else {
@@ -682,7 +664,7 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
   if (ptr_->mkl_mem_ && ptr_->mkl_mem_->GetRaw() == &mem)
     return;
 
-  CHECK(mem.get_primitive_desc().get_size() == shape().Size() * GetTypeSize(dtype_))
+  CHECK(mem.get_desc().get_size() == shape().Size() * GetTypeSize(dtype_))
       << "The size of NDArray doesn't match the requested MKLDNN memory desc";
   // If this array uses MKLDNN layout, we have to make sure it's not a view.
   // Otherwise, we'll have to change the layout inside the array.
@@ -694,28 +676,23 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
   MKLDNNCopy(mem, this_mem);
 }
 
-mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &desc) {
+mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::desc &desc) {
   if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
-    LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
+    LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc ";
     return nullptr;
   }
-
-  mkldnn::memory::primitive_desc _desc = desc;
-  mkldnn_memory_format_t required_format = _desc.desc().data.format;
-  mkldnn_memory_format_t def_format = GetDefaultFormat(_desc.desc());
-  // If the required format is a default format, we don't need to worry about the shape.
-  // If the shape isn't the same, it actually implicitly reshapes data.
-  if (required_format == def_format && !IsView()) {
+  bool isDefaultFormat = IsDefaultFormat(desc);
+  if (isDefaultFormat && !IsView()) {
     ptr_->SetMKLMem(shape_, dtype_);
     MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
     return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc);
-  } else if (required_format == def_format) {
+  } else if (isDefaultFormat) {
    ptr_->CheckAndAlloc();
    CHECK(ptr_->shandle.dptr);
    // When this is a view and a user wants the default layout, we can simply
    // create a new mkldnn memory that points to the right memory.
-    std::shared_ptr<mkldnn::memory> mem(new mkldnn::memory(
-        desc, static_cast<char *>(ptr_->shandle.dptr) + byte_offset_));
+    std::shared_ptr<mkldnn::memory> mem(new mkldnn::memory(desc,
+        CpuEngine::Get()->get_engine(), static_cast<char *>(ptr_->shandle.dptr) + byte_offset_));
     MKLDNNStream::Get()->RegisterMem(mem);
     return mem.get();
   } else if (IsView()) {
@@ -729,7 +706,7 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &
   if (ptr_->mkl_mem_)
     CHECK(ptr_->mkl_mem_->GetDataHandle() == ptr_->shandle.dptr);
-  if (ptr_->mkl_mem_ && ptr_->mkl_mem_->GetPrimitiveDesc() == desc) {
+  if (ptr_->mkl_mem_ && ptr_->mkl_mem_->GetDesc() == desc) {
     MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
     return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc);
   }
@@ -741,15 +718,14 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &
   return ptr_->mkl_mem_->GetRaw();
 }
 
-void NDArray::UpdateMKLDNNMemDesc(mkldnn::memory::format format) {
+void NDArray::UpdateMKLDNNMemDesc(mkldnn::memory::format_tag format) {
   const mkldnn::memory *mem = GetMKLDNNData();
-  auto mem_desc = mem->get_primitive_desc().desc();
+  auto mem_desc = mem->get_desc();
   auto this_dtype = get_mkldnn_type(dtype());
   mkldnn::memory::desc data_md(
       mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims),
       this_dtype, format);
-  mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine());
-  ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr));
+  ptr_->mkl_mem_.reset(new MKLDNNMemory(data_md, ptr_->shandle.dptr));
   MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
 }
 #endif
@@ -760,7 +736,7 @@ void NDArray::SetTBlob() const {
   char *dptr = static_cast<char *>(ptr_->shandle.dptr);
   auto stype = storage_type();
   if (stype == kDefaultStorage) {
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
     CHECK(!IsMKLDNNData()) << "We can't generate TBlob for MKLDNN data. "
         << "Please use Reorder2Default() to generate a new NDArray first";
 #endif
@@ -1102,7 +1078,7 @@ inline void CopyFromToRspImpl(const NDArray& from, const NDArray& to, RunContext
 // Make a copy of a dense NDArray
 template <typename from_xpu, typename to_xpu>
 inline void CopyFromToDnsImpl(const NDArray& from, const NDArray& to, RunContext ctx) {
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   // If neither is MKLDNN, we can copy data normally.
   if (!from.IsMKLDNNData() && !to.IsMKLDNNData()) {
 #endif
@@ -1111,7 +1087,7 @@ inline void CopyFromToDnsImpl(const NDArray& from, const NDArray& to, RunContext
     TBlob tmp = to.data();
     ndarray::Copy<from_xpu, to_xpu>(from.data(), &tmp, from.ctx(), to.ctx(), ctx);
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   } else if (SupportMKLDNN(from.dtype(), from.shape())
              && SupportMKLDNN(to.dtype(), to.shape())
             && from.ctx().dev_mask() == cpu::kDevMask
             && to.ctx().dev_mask() == cpu::kDevMask) {
    // If we copy data directly, we need to make sure both NDArrays are supported
    // by MKLDNN.
     auto from_mem = from.GetMKLDNNData();
     auto to_mem = to.GetMKLDNNData();
-    if (from_mem->get_primitive_desc() == to_mem->get_primitive_desc()) {
-      size_t size = std::min(from_mem->get_primitive_desc().get_size(),
-                             to_mem->get_primitive_desc().get_size());
+    if (from_mem->get_desc() == to_mem->get_desc()) {
+      size_t size = std::min(from_mem->get_desc().get_size(),
+                             to_mem->get_desc().get_size());
       memcpy(to_mem->get_data_handle(), from_mem->get_data_handle(), size);
     } else {
       const_cast<NDArray &>(to).CopyFrom(*from_mem);
@@ -1638,8 +1614,9 @@ void NDArray::Save(dmlc::Stream *strm) const {
       this->WaitToRead();
       nd_cpu = *this;
 #if MXNET_USE_MKLDNN == 1
-      if (nd_cpu.IsMKLDNNData())
+      if (nd_cpu.IsMKLDNNData()) {
         nd_cpu = nd_cpu.Reorder2Default();
+      }
 #endif
       save_data = nd_cpu.data();
     }
@@ -2020,9 +1997,10 @@ void NDArray::SyncCopyToCPU(void *data, size_t size) const {
     this->WaitToRead();
     RunContext rctx{this->ctx(), nullptr, nullptr, false};
     NDArray src = *this;
-#if MXNET_USE_MKLDNN == 1
-    if (src.IsMKLDNNData())
+#if MXNET_USE_MKLDNN == 100
+    if (src.IsMKLDNNData()) {
       src = this->Reorder2Default();
+    }
 #endif
     ndarray::Copy<cpu, cpu>(src.data(), &dst, Context::CPU(), Context::CPU(), rctx);
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 5db9d6e5defc..85d42ff48d35 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -46,7 +46,7 @@
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BASE_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BASE_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include
 #include
 #include
@@ -58,7 +58,7 @@
 #include "mxnet/ndarray.h"
 #include "mxnet/resource.h"
 #include "mxnet/op_attr_types.h"
-using namespace mkldnn;
+
 namespace mxnet {
 
 // ===== CpuEngine =======================================
@@ -79,7 +79,7 @@ class CpuEngine {
   mkldnn::engine &get_engine() { return _cpu_engine; }
 
  protected:
-  CpuEngine() : _cpu_engine(mkldnn::engine::cpu, 0) {}
+  CpuEngine() : _cpu_engine(mkldnn::engine::kind::cpu, 0) {}
   ~CpuEngine() {}
 
  private:
@@ -92,27 +92,22 @@ struct data_type_enum {};
 
 template <>
 struct data_type_enum<float> {
-  enum { type = mkldnn::memory::data_type::f32 };
+  enum { type = static_cast<int>(mkldnn::memory::data_type::f32) };
 };
 
 template <>
 struct data_type_enum<int32_t> {
-  enum { type = mkldnn::memory::data_type::s32 };
-};
-
-template <>
-struct data_type_enum<int16_t> {
-  enum { type = mkldnn::memory::data_type::s16 };
+  enum { type = static_cast<int>(mkldnn::memory::data_type::s32) };
 };
 
 template <>
 struct data_type_enum<int8_t> {
-  enum { type = mkldnn::memory::data_type::s8 };
+  enum { type = static_cast<int>(mkldnn::memory::data_type::s8) };
 };
 
 template <>
 struct data_type_enum<uint8_t> {
-  enum { type = mkldnn::memory::data_type::u8 };
+  enum { type = static_cast<int>(mkldnn::memory::data_type::u8) };
 };
 
 static inline bool SupportMKLDNNArray(int dtype, const mxnet::TShape &shape) {
@@ -198,7 +193,7 @@ static int GetTypeSize(int dtype) {
 static inline size_t GetArraySize(const NDArray &arr) {
   if (arr.IsMKLDNNData()) {
-    return arr.GetMKLDNNData()->get_primitive_desc().get_size();
+    return arr.GetMKLDNNData()->get_desc().get_size();
   }
   return arr.shape().Size() * GetTypeSize(arr.dtype());
 }
@@ -215,7 +210,7 @@ static inline mkldnn::memory::data_type get_mkldnn_type(int dtype) {
       return mkldnn::memory::data_type::u8;
     default:
       LOG(FATAL) << "unknown type for MKLDNN";
-      return mkldnn::memory::data_type::data_undef;
+      return mkldnn::memory::data_type::undef;
   }
 }
 
@@ -253,7 +248,7 @@ inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr, int dtype = -1
   mkldnn::memory::dims dims(ndim);
   dtype = (dtype == -1) ? arr.dtype() : dtype;
   for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i];
-  return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format::any};
+  return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any};
 }
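`format_tag::any` survives the migration: it tells a primitive to pick whatever layout is fastest, and the concrete layout is then read back from the primitive descriptor. A hedged sketch of that flow with an illustrative convolution, not taken from this patch:

```cpp
#include <mkldnn.hpp>

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  // 'any' lets the primitive choose the optimal layout; concrete tags are
  // only required for memories the user actually allocates.
  memory::desc src_md({1, 16, 8, 8}, memory::data_type::f32, memory::format_tag::any);
  memory::desc wei_md({32, 16, 3, 3}, memory::data_type::f32, memory::format_tag::any);
  memory::desc dst_md({1, 32, 6, 6}, memory::data_type::f32, memory::format_tag::any);
  convolution_forward::desc d(prop_kind::forward_inference,
                              algorithm::convolution_direct,
                              src_md, wei_md, dst_md,
                              {1, 1}, {0, 0}, {0, 0});
  convolution_forward::primitive_desc pd(d, eng);
  // The primitive reports the layouts it actually wants, e.g. a blocked one.
  memory src(pd.src_desc(), eng);
  memory dst(pd.dst_desc(), eng);
  return 0;
}
```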
 inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr,
@@ -278,7 +273,7 @@ inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr,
                                   static_cast<int>(arr.shape()[C]),
                                   static_cast<int>(arr.shape()[H]),
                                   static_cast<int>(arr.shape()[W])};
     }
-    return mkldnn::memory::desc{tz, get_mkldnn_type(dtype), mkldnn::memory::format::any};
+    return mkldnn::memory::desc{tz, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any};
   }
 }
 
@@ -344,19 +339,24 @@ class TmpMemMgr {
     this->est_size = 0;
   }
 
-  mkldnn::memory *Alloc(const mkldnn::memory::primitive_desc &pd);
+  mkldnn::memory *Alloc(const mkldnn::memory::desc &md);
 };
 
+typedef std::unordered_map<int, mkldnn::memory> mkldnn_args_map_t;
 class MKLDNNStream {
-  std::vector<mkldnn::primitive> net;
+  std::vector<std::pair<mkldnn::primitive, mkldnn_args_map_t> > net_prim_args;
   // Here we hold all memory related to the operators in the stream.
   std::vector<std::shared_ptr<const mkldnn::memory> > mem_holder;
+  mkldnn::stream s;
 
  public:
   static MKLDNNStream *Get();
 
-  void RegisterPrim(const mkldnn::primitive &prim) {
-    net.push_back(prim);
+  MKLDNNStream(): s(CpuEngine::Get()->get_engine()) {}
+
+  void RegisterPrimArgs(const mkldnn::primitive &prim,
+                        const mkldnn_args_map_t &args) {
+    net_prim_args.emplace_back(prim, args);
   }
 
   void RegisterMem(std::shared_ptr<const mkldnn::memory> mem) {
@@ -364,7 +364,7 @@ class MKLDNNStream {
   }
 
   bool HasOps() const {
-    return !net.empty();
+    return !net_prim_args.empty();
   }
 
  /*
@@ -373,9 +373,11 @@ class MKLDNNStream {
   * might want to separate mkldnn execution and memory cleanup.
   */
  void Submit(bool cleanup = true) {
-    if (!net.empty()) {
-      mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
-      net.clear();
+    if (!net_prim_args.empty()) {
+      for (auto &v : net_prim_args) {
+        v.first.execute(s, v.second);
+      }
+      net_prim_args.clear();
     }
     if (cleanup)
       Cleanup();
@@ -397,18 +399,18 @@
 typedef std::pair<OutDataOp, mkldnn::memory *> mkldnn_output_t;
 void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem);
 
 /*
- * Here we want to get MKLDNN memory whose primitive desc is exactly the same as
+ * Here we want to get MKLDNN memory whose desc is exactly the same as
  * the given one. operator== can't guarantee that. == can return true even if
 * the formats are different. I need to double check its format.
 */
 static inline mkldnn::memory *GetMKLDNNExact(
-    const mkldnn::memory *mem, mkldnn::memory::primitive_desc desc) {
-  mkldnn::memory::primitive_desc src_desc = mem->get_primitive_desc();
-  if (desc == src_desc && desc.desc().data.format == src_desc.desc().data.format) {
+    const mkldnn::memory *mem, const mkldnn::memory::desc &desc) {
+  mkldnn::memory::desc src_desc = mem->get_desc();
+  if (desc == src_desc) {
     return const_cast<mkldnn::memory *>(mem);
   } else {
     std::shared_ptr<mkldnn::memory> ret(new mkldnn::memory(
-        desc, mem->get_data_handle()));
+        desc, CpuEngine::Get()->get_engine(), mem->get_data_handle()));
     MKLDNNStream::Get()->RegisterMem(ret);
     return ret.get();
   }
 }
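RegisterPrimArgs replaces the old RegisterPrim because 1.0 primitives take their memory arguments at execution time rather than at construction. A short sketch of the intended usage, written as a hypothetical helper against the types this header now declares:

```cpp
#include "./mkldnn_base-inl.h"

namespace mxnet {
// Hypothetical sketch: queue a reorder on the singleton stream and run it.
void SketchQueuedReorder(const mkldnn::memory &src, const mkldnn::memory &dst) {
  mkldnn_args_map_t args = {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, dst}};
  MKLDNNStream *stream = MKLDNNStream::Get();
  stream->RegisterPrimArgs(mkldnn::reorder(src, dst), args);
  stream->Submit();  // executes every queued primitive with its own args map
}
}  // namespace mxnet
```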
@@ -426,10 +428,10 @@
  * the output back to the output NDArray.
 */
 mkldnn_output_t CreateMKLDNNMem(const NDArray &out_arr,
-                                const mkldnn::memory::primitive_desc &desc,
+                                const mkldnn::memory::desc &desc,
                                 OpReqType req, const NDArray* in_arr = nullptr);
 mkldnn_output_t CreateMKLDNNWeightGrad(const NDArray &out_arr,
-                                       const mkldnn::memory::primitive_desc &desc,
+                                       const mkldnn::memory::desc &desc,
                                        OpReqType req);
 /* This function has to be used with one of the functions above. */
 void CommitOutput(const NDArray &arr, const mkldnn_output_t &res);
@@ -458,13 +460,15 @@ static inline void CreateDefaultInputs(const std::vector<NDArray> &arrs,
 const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups);
 
 const mkldnn::memory *GetWeights(const NDArray &arr,
-                                 const mkldnn::memory::primitive_desc &target_pd,
+                                 const mkldnn::memory::desc &target_md,
                                  int num_groups);
 
-mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc);
-mkldnn_memory_format_t GetDefaultFormat(int num_dims);
-mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd,
-                                                mkldnn_memory_format_t format);
+bool IsDefaultFormat(const mkldnn::memory::desc &desc);
+bool IsMKLDNN(const mkldnn::memory::desc &desc);
+
+mkldnn_format_tag_t GetDefaultFormat(const mkldnn::memory::desc &md);
+mkldnn_format_tag_t GetDefaultFormat(int num_dims);
+mkldnn::memory::desc GetDesc(const mkldnn::memory::desc &md, const mkldnn_format_tag_t &format);
 
 inline bool same_shape(const mxnet::TShape &shape, const mkldnn_dims_t dims, int ndims) {
   if (shape.ndim() != ndims)
     return false;
@@ -492,7 +496,7 @@ inline bool same_shape(const mxnet::TShape &shape, int dtype,
 }
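The usual call pattern for these helpers inside an operator, sketched with hypothetical names (the input array, output descriptor and req come from whatever the operator has at hand); this is a sketch assuming the declarations above, not code from the patch:

```cpp
#include <mxnet/ndarray.h>
#include "./mkldnn_base-inl.h"

namespace mxnet {
// Hypothetical helper: acquire the output through CreateMKLDNNMem so that
// WriteTo / WriteInplace / AddTo are all honoured, queue a primitive that
// writes into it, then commit the result and submit the stream.
void SketchCopyWithReq(const NDArray &in_arr, const NDArray &out_arr,
                       const mkldnn::memory::desc &out_md, OpReqType req) {
  const mkldnn::memory *in_mem = in_arr.GetMKLDNNData();
  mkldnn_output_t out = CreateMKLDNNMem(out_arr, out_md, req);
  mkldnn_args_map_t args = {{MKLDNN_ARG_FROM, *in_mem},
                            {MKLDNN_ARG_TO, *out.second}};
  MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(*in_mem, *out.second), args);
  CommitOutput(out_arr, out);  // copies or accumulates back if needed
  MKLDNNStream::Get()->Submit();
}
}  // namespace mxnet
```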
 /*
- * There is a large overhead of getting mkldnn::memory::primitive_desc from
+ * There is a large overhead of getting mkldnn::memory::desc from
  * mkldnn::memory. This class is created to cache the metadata of mkldnn memory
  * to provide a much more lightweight method to access them.
  */
@@ -502,16 +506,15 @@ class MKLDNNMemory {
   size_t size;  // The number of bytes.
 
  public:
-  MKLDNNMemory(mkldnn::memory::primitive_desc pd, void *addr): desc(pd.desc()) {
-    mem.reset(new mkldnn::memory(pd, addr));
-    size = pd.get_size();
+  MKLDNNMemory(mkldnn::memory::desc md, void *addr): desc(md) {
+    mem.reset(new mkldnn::memory(md, CpuEngine::Get()->get_engine(), addr));
+    size = desc.get_size();
   }
 
   explicit MKLDNNMemory(std::shared_ptr<mkldnn::memory> mem): desc(
-      mem->get_primitive_desc().desc()) {
+      mem->get_desc()) {
     this->mem = mem;
-    mkldnn::memory::primitive_desc pd = mem->get_primitive_desc();
-    size = pd.get_size();
+    size = desc.get_size();
   }
 
   void SetDataHandle(void *handle) {
@@ -534,28 +537,29 @@ class MKLDNNMemory {
     return size;
   }
 
-  mkldnn::memory::primitive_desc GetPrimitiveDesc() const {
-    return mem->get_primitive_desc();
+  mkldnn::memory::desc GetDesc() const {
+    return mem->get_desc();
   }
 
-  mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn_memory_format_t format) const {
-    return mxnet::GetPrimitiveDesc(mem->get_primitive_desc(), format);
+  mkldnn::memory::desc GetDesc(mkldnn_format_tag_t format) const {
+    mkldnn::memory::dims dims(desc.data.dims, desc.data.dims + desc.data.ndims);
+    mkldnn::memory::data_type cpp_type =
+        static_cast<mkldnn::memory::data_type>(desc.data.data_type);
+    mkldnn::memory::desc data_md(dims, cpp_type,
+                                 static_cast<mkldnn::memory::format_tag>(format));
+    return data_md;
   }
 
-  mkldnn_memory_format_t GetDefaultFormat() const {
+  mkldnn_format_tag_t GetDefaultFormat() const {
     return mxnet::GetDefaultFormat(desc);
   }
 
-  mkldnn_memory_format_t GetFormat() const {
-    return desc.data.format;
-  }
-
   bool IsMKLDNN() const {
-    return GetFormat() != GetDefaultFormat();
+    return mxnet::IsMKLDNN(desc);
   }
 
-  bool SameFormat(mkldnn::memory::primitive_desc pd) const {
-    return mem->get_primitive_desc() == pd;
+  bool SameFormat(mkldnn::memory::desc md) const {
+    return mem->get_desc() == md;
   }
 
   bool SameFormat(const mxnet::TShape &shape, int dtype) const {
@@ -563,9 +567,8 @@ class MKLDNNMemory {
   }
 
   void ReorderTo(mkldnn::memory *other) const {
-    std::vector<mkldnn::primitive> net;
-    net.push_back(mkldnn::reorder(*mem, *other));
-    mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
+    mkldnn::stream s(CpuEngine::Get()->get_engine());
+    mkldnn::reorder(*mem, *other).execute(s, *mem, *other);
   }
 };
 
@@ -621,7 +624,7 @@ bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
   if (debug) check.CopyResult(outputs, indice);
 
 struct MKLDNNPostEltwiseParam {
-  mkldnn::algorithm alg = mkldnn::algorithm::algorithm_undef;
+  mkldnn::algorithm alg = mkldnn::algorithm::undef;
   float scale = 1.f;
   float alpha = 0.f;
   float beta = 1.f;
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index a13337b122c3..31ffbbb471c4 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include <atomic>
 #include "./mkldnn_base-inl.h"
 
@@ -54,27 +54,27 @@ void *AlignMem(void *mem, size_t size, size_t alignment, size_t *space) {
   return reinterpret_cast<void *>(addr);
 }
 
-mkldnn::memory *TmpMemMgr::Alloc(const mkldnn::memory::primitive_desc &pd) {
+mkldnn::memory *TmpMemMgr::Alloc(const mkldnn::memory::desc &md) {
   // We need to include the size of the memory used for alignment.
-  this->est_size += pd.get_size() + alignment;
-  void *mem = AlignMem(this->curr_mem, pd.get_size(), alignment, &this->curr_size);
+  this->est_size += md.get_size() + alignment;
+  void *mem = AlignMem(this->curr_mem, md.get_size(), alignment, &this->curr_size);
   if (mem) {
     // The memory is allocated from the temporary memory space in the
     // operator. It'll only become invalid after we exit from the operator.
-    mkldnn_mem_ptr ret(new mkldnn::memory(pd, mem));
+    mkldnn_mem_ptr ret(new mkldnn::memory(md, CpuEngine::Get()->get_engine(), mem));
     MKLDNNStream::Get()->RegisterMem(ret);
     CHECK_EQ(mem, mem);
-    this->curr_size -= pd.get_size();
-    this->curr_mem = static_cast<char *>(mem) + pd.get_size();
+    this->curr_size -= md.get_size();
+    this->curr_mem = static_cast<char *>(mem) + md.get_size();
     return ret.get();
   } else {
     // If curr_mem has been initialized and we still reach here. It means
     // the current allocated memory isn't enough.
     if (this->curr_mem && dmlc::GetEnv("MXNET_MKLDNN_DEBUG", false)) {
-      LOG(WARNING) << "Allocate " << pd.get_size()
+      LOG(WARNING) << "Allocate " << md.get_size()
                    << " bytes with malloc directly";
     }
-    mkldnn_mem_ptr ret(new mkldnn::memory(pd));
+    mkldnn_mem_ptr ret(new mkldnn::memory(md, CpuEngine::Get()->get_engine()));
     MKLDNNStream::Get()->RegisterMem(ret);
     return ret.get();
   }
 }
 
@@ -82,85 +82,88 @@ void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem) {
   MKLDNNStream *stream = MKLDNNStream::Get();
+  mkldnn::memory::desc from_desc = mem.get_desc();
+  mkldnn::memory::desc this_desc = this_mem->get_desc();
+  mkldnn_format_tag_t from_def_format = GetDefaultFormat(from_desc);
+  mkldnn_format_tag_t this_def_format = GetDefaultFormat(this_desc);
 
-  mkldnn::memory::primitive_desc from_pd = mem.get_primitive_desc();
-  mkldnn::memory::desc from_desc = from_pd.desc();
-  mkldnn::memory::primitive_desc this_pd = this_mem->get_primitive_desc();
-  mkldnn::memory::desc this_desc = this_pd.desc();
-  mkldnn_memory_format_t from_def_format = GetDefaultFormat(from_desc);
-  mkldnn_memory_format_t this_def_format = GetDefaultFormat(this_desc);
-  // It's possible that the memory and the NDArray don't have the same shape.
-  if (!same_shape(this_desc, from_desc)
-      // If the source memory uses the default layout, we can reshape directly.
-      && from_def_format == from_desc.data.format) {
+  if (!same_shape(this_desc, from_desc) && IsDefaultFormat(from_desc)) {
     // In this case, we can simply create a new MKLDNN memory for the required
     // shape.
     mkldnn::memory::dims dims(this_desc.data.dims,
                               this_desc.data.dims + this_desc.data.ndims);
     auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
-    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
-    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
-    mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
-    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
+    mkldnn::memory::desc data_md(dims, this_dtype,
+                                 static_cast<mkldnn::memory::format_tag>(this_def_format));
+
+    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(data_md, mem.get_engine(), mem.get_data_handle()));
     stream->RegisterMem(tmp_mem);
-    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
+    std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, *tmp_mem},
+                                                  {MKLDNN_ARG_TO, *this_mem}});
+    stream->RegisterPrimArgs(mkldnn::reorder(*tmp_mem, *this_mem), args);
  } else if (!same_shape(this_desc, from_desc)) {
     // In this case, the source memory stores data in a customized layout. We
     // need to reorganize the data in memory before we can reshape.
-    mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(from_pd, from_def_format);
-    mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd);
-    stream->RegisterPrim(mkldnn::reorder(mem, *def_mem));
+    mkldnn::memory::desc def_desc = GetDesc(from_desc, from_def_format);
+    mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_desc);
+    std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, mem},
+                                                  {MKLDNN_ARG_TO, *def_mem}});
+    stream->RegisterPrimArgs(mkldnn::reorder(mem, *def_mem), args);
+
     // Now we can reshape it
-    mkldnn::memory::dims dims(this_desc.data.dims,
-                              this_desc.data.dims + this_desc.data.ndims);
-    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
-    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
-    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
-    mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
-    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle()));
+    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(this_desc,
+        mem.get_engine(), def_mem->get_data_handle()));
     stream->RegisterMem(tmp_mem);
-    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
-  } else if (from_pd == this_pd) {
+    args = {{MKLDNN_ARG_FROM, *tmp_mem}, {MKLDNN_ARG_TO, *this_mem}};
+    stream->RegisterPrimArgs(mkldnn::reorder(*tmp_mem, *this_mem), args);
+  } else if (this_desc == from_desc) {
+    std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, mem},
+                                                  {MKLDNN_ARG_TO, *this_mem}});
     // If the layout is the same, we can just copy data.
-    stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
+    stream->RegisterPrimArgs(mkldnn::reorder(mem, *this_mem), args);
   } else {
     // If both are not using the default layouts. There isn't much we can do,
     // other than reorder data layout directly.
-    if (this_def_format != this_desc.data.format
-        && from_def_format != from_desc.data.format) {
-      stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
-    } else if (this_def_format == this_desc.data.format) {
+    if (!IsDefaultFormat(this_desc) && !IsDefaultFormat(from_desc)) {
+      std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, mem},
+                                                    {MKLDNN_ARG_TO, *this_mem}});
+      stream->RegisterPrimArgs(mkldnn::reorder(mem, *this_mem), args);
+    } else if (IsDefaultFormat(this_desc)) {
       // If the dest mem uses the default memory layout, we can simply use
      // the default format of the source memory to improve perf of reorder.
-      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(from_pd,
-                                                           from_def_format);
-      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, this_mem->get_data_handle()));
+      mkldnn::memory::desc desc = GetDesc(from_desc, from_def_format);
+      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(desc,
+          mem.get_engine(), this_mem->get_data_handle()));
       stream->RegisterMem(tmp_mem);
-      stream->RegisterPrim(mkldnn::reorder(mem, *tmp_mem));
+      std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, mem},
+                                                    {MKLDNN_ARG_TO, *tmp_mem}});
+      stream->RegisterPrimArgs(mkldnn::reorder(mem, *tmp_mem), args);
     } else {
       // If the src mem uses the default memory layout, we can use
      // the default format of the source memory to improve perf.
-      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(this_pd,
-                                                           this_def_format);
-      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
+      mkldnn::memory::desc desc = GetDesc(this_desc, this_def_format);
+      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(desc,
+          this_mem->get_engine(), mem.get_data_handle()));
       stream->RegisterMem(tmp_mem);
-      stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
+      std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, *tmp_mem},
+                                                    {MKLDNN_ARG_TO, *this_mem}});
+      stream->RegisterPrimArgs(mkldnn::reorder(*tmp_mem, *this_mem), args);
     }
   }
 }
 
 bool CanWriteTo(const NDArray &out_arr,
                 const NDArray &in_arr,
-                const mkldnn::memory::primitive_desc &desc) {
+                const mkldnn::memory::desc &desc) {
   auto in_mem = in_arr.GetMKLDNNData();
   bool add_same = in_mem->get_data_handle() == out_arr.GetMKLDNNData()->get_data_handle();
-  bool pdesc_same = out_arr.GetMKLDNNData()->get_primitive_desc() == desc &&
-      in_mem->get_primitive_desc() == desc;
+  bool pdesc_same = out_arr.GetMKLDNNData()->get_desc() == desc &&
+      in_mem->get_desc() == desc;
   return add_same && pdesc_same;
 }
 
 mkldnn_output_t CreateMKLDNNMem(const NDArray &out_arr,
-                                const mkldnn::memory::primitive_desc &desc,
+                                const mkldnn::memory::desc &desc,
                                 OpReqType req,
                                 const NDArray* in_arr) {
   if (kAddTo == req) {
@@ -188,7 +191,7 @@ mkldnn_output_t CreateMKLDNNMem(const NDArray &out_arr,
 }
 
 mkldnn_output_t CreateMKLDNNWeightGrad(const NDArray &out_arr,
-                                       const mkldnn::memory::primitive_desc &desc,
+                                       const mkldnn::memory::desc &desc,
                                        OpReqType req) {
   if (kAddTo == req) {
     auto tmp = TmpMemMgr::Get()->Alloc(desc);
@@ -197,10 +200,8 @@ mkldnn_output_t CreateMKLDNNWeightGrad(const NDArray &out_arr,
     auto tmp = TmpMemMgr::Get()->Alloc(desc);
     return mkldnn_output_t(OutDataOp::CopyBack, tmp);
   } else {
-    auto _desc = desc;
-    auto def_format = GetDefaultFormat(_desc.desc());
     mkldnn::memory *mem = nullptr;
-    if (def_format == _desc.desc().data.format) {
+    if (IsDefaultFormat(desc)) {
       mem = const_cast<NDArray &>(out_arr).CreateMKLDNNData(desc);
     }
     if (mem == nullptr) {
@@ -217,8 +218,8 @@ void CommitOutput(const NDArray &arr, const mkldnn_output_t &res) {
     const_cast<NDArray &>(arr).CopyFrom(*res.second);
   } else if (res.first == AddBack) {
     auto res_memory = res.second;
-    auto target_pd = arr.GetMKLDNNData()->get_primitive_desc();
-    auto mem = arr.GetMKLDNNData(res.second->get_primitive_desc());
+    auto target_pd = arr.GetMKLDNNData()->get_desc();
+    auto mem = arr.GetMKLDNNData(res.second->get_desc());
     if (mem == nullptr) {
       auto tmp_memory = TmpMemMgr::Get()->Alloc(target_pd);
       MKLDNNCopy(*res_memory, tmp_memory);
@@ -232,12 +233,12 @@ void CommitOutput(const NDArray &arr, const mkldnn_output_t &res) {
 const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) {
   const auto type = get_mkldnn_type(arr.dtype());
   auto tz = mkldnn::memory::dims{0};
-  auto format = mkldnn::memory::format::format_undef;
+  auto format_tag = mkldnn::memory::format_tag::undef;
   auto engine = CpuEngine::Get()->get_engine();
   const int O = 0, I = 1, H = 2, W = 3;
   if (arr.shape().ndim() == 2) {
     tz = mkldnn::memory::dims{static_cast<int>(arr.shape()[O]),
                               static_cast<int>(arr.shape()[I])};
-    format = mkldnn::memory::format::oi;
+    format_tag = mkldnn::memory::format_tag::oi;
   } else if (arr.shape().ndim() == 3) {
     tz = num_groups > 1 ?
         mkldnn::memory::dims{num_groups,
                              static_cast<int>(arr.shape()[O] / num_groups),
@@ -246,7 +247,8 @@ const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) {
         : mkldnn::memory::dims{static_cast<int>(arr.shape()[O]),
                                static_cast<int>(arr.shape()[I]),
                                static_cast<int>(arr.shape()[H])};
-    format = num_groups > 1 ? mkldnn::memory::format::goiw : mkldnn::memory::format::oiw;
+    format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goiw
+                                : mkldnn::memory::format_tag::oiw;
   } else if (arr.shape().ndim() == 4) {
     tz = num_groups > 1 ?
         mkldnn::memory::dims{num_groups,
                              static_cast<int>(arr.shape()[O] / num_groups),
@@ -256,168 +258,99 @@ const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) {
         : mkldnn::memory::dims{
             static_cast<int>(arr.shape()[O]), static_cast<int>(arr.shape()[I]),
             static_cast<int>(arr.shape()[H]), static_cast<int>(arr.shape()[W])};
-    format = num_groups > 1 ? mkldnn::memory::format::goihw : mkldnn::memory::format::oihw;
+    format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goihw
+                                : mkldnn::memory::format_tag::oihw;
   } else {
     LOG(FATAL) << "The weight array has an unsupported number of dimensions";
   }
-  const auto md = mkldnn::memory::desc{tz, type, format};
-  const auto pd = mkldnn::memory::primitive_desc{md, engine};
-  return arr.GetMKLDNNData(pd);
+  const auto md = mkldnn::memory::desc{tz, type, format_tag};
+  return arr.GetMKLDNNData(md);
 }
 
 const mkldnn::memory *GetWeights(const NDArray &arr,
-                                 const mkldnn::memory::primitive_desc &target_pd,
+                                 const mkldnn::memory::desc &target_desc,
                                  int num_groups) {
-  const mkldnn::memory *mem = arr.GetMKLDNNData(target_pd);
+  const mkldnn::memory *mem = arr.GetMKLDNNData(target_desc);
   // If the weight array already uses the target layout, simply return it directly.
   if (mem) return mem;
   mem = GetWeights(arr, num_groups);
-  if (mem == nullptr) mem = arr.GetMKLDNNDataReorder(target_pd);
-  if (mem->get_primitive_desc() == target_pd) return mem;
+  if (mem == nullptr) mem = arr.GetMKLDNNDataReorder(target_desc);
+  if (mem->get_desc() == target_desc) return mem;
 
-  auto ret = TmpMemMgr::Get()->Alloc(target_pd);
-  MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(*mem, *ret));
+  auto ret = TmpMemMgr::Get()->Alloc(target_desc);
+  std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, *mem},
+                                                {MKLDNN_ARG_TO, *ret}});
+  MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(*mem, *ret), args);
  return ret;
 }
 
-mkldnn_memory_format_t GetDefaultFormat(int num_dims) {
+
+// default: a plain blocked layout with no inner blocks, whose strides do not
+// increase from the outermost to the innermost dimension.
+// mkldnn: everything else, i.e. 1. winograd 2. rnn packed 3. a blocked layout
+// with inner blocks or with strides that increase along the dimensions.
+bool IsMKLDNN(const mkldnn::memory::desc &desc) {
+  bool rslt = true;
+  if (desc.data.format_kind == mkldnn_blocked) {
+    if (desc.data.format_desc.blocking.inner_nblks == 0) {
+      int i = 0;
+      for (i = 0; i < desc.data.ndims-1; i++) {
+        if (desc.data.format_desc.blocking.strides[i]
+            < desc.data.format_desc.blocking.strides[i + 1]) {
+          break;
+        }
+      }
+      if (i == desc.data.ndims-1) {
+        rslt = false;
+      }
+    }
+  }
+  return rslt;
+}
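What the monotone-stride test distinguishes can be seen directly on two descriptors: a plain nchw desc has no inner blocks and strictly non-increasing strides, while a blocked tag such as nChw16c carries an inner block over the channel dimension. A small sketch relying only on the public desc fields used above:

```cpp
#include <mkldnn.hpp>
#include <iostream>

int main() {
  mkldnn::memory::dims dims = {1, 32, 4, 4};
  mkldnn::memory::desc plain(dims, mkldnn::memory::data_type::f32,
                             mkldnn::memory::format_tag::nchw);
  mkldnn::memory::desc blocked(dims, mkldnn::memory::data_type::f32,
                               mkldnn::memory::format_tag::nChw16c);
  // Plain layout: no inner blocks, strides only decrease outer-to-inner.
  std::cout << plain.data.format_desc.blocking.inner_nblks << std::endl;    // 0
  // Blocked layout: one inner block of 16 channels.
  std::cout << blocked.data.format_desc.blocking.inner_nblks << std::endl;  // 1
  return 0;
}
```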
+mkldnn_format_tag_t GetDefaultFormat(int num_dims) {
   switch (num_dims) {
-    case 1: return mkldnn_x;
-    case 2: return mkldnn_nc;
-    case 3: return mkldnn_ncw;
-    case 4: return mkldnn_nchw;
-    case 5: return mkldnn_goihw;
+    case 1: return mkldnn_a;
+    case 2: return mkldnn_ab;
+    case 3: return mkldnn_abc;
+    case 4: return mkldnn_abcd;
+    case 5: return mkldnn_abcde;
     default:
-      LOG(FATAL) << "Unsupported MKLDNN dimensions: " << num_dims;
-      return mkldnn_format_undef;
+      LOG(FATAL) << "Not implemented dimension (" << num_dims << ") for MKLDNN";
+      return mkldnn_format_tag_undef;
   }
 }
 
-mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) {
-  if (desc.data.ndims == 1) {
-    return desc.data.format;
-  } else if (desc.data.ndims == 2) {
-    if (desc.data.format == mkldnn_io)
-      return mkldnn_oi;
-    else
-      return desc.data.format;
-  } else if (desc.data.ndims == 3) {
-    switch (desc.data.format) {
-      case mkldnn_ncw:
-      case mkldnn_nwc:
-      case mkldnn_nCw8c:
-      case mkldnn_nCw16c:
-        return mkldnn_ncw;
-      case mkldnn_oiw:
-      case mkldnn_wio:
-      case mkldnn_Owi8o:
-      case mkldnn_OIw8i8o:
-      case mkldnn_OIw8o8i:
-      case mkldnn_OIw16i16o:
-      case mkldnn_OIw16o16i:
-      case mkldnn_Oiw16o:
-      case mkldnn_Owi16o:
-      case mkldnn_OIw8i16o2i:
-      case mkldnn_OIw8o16i2o:
-      case mkldnn_IOw16o16i:
-        return mkldnn_oiw;
-      default:
-        LOG(FATAL) << "Unknown MKLDNN format for 3 dimensions: " << desc.data.format;
-        return mkldnn_format_undef;
-    }
-  } else if (desc.data.ndims == 4) {
-    switch (desc.data.format) {
-      case mkldnn_nchw:
-      case mkldnn_nhwc:
-      case mkldnn_chwn:
-      case mkldnn_nChw4c:
-      case mkldnn_nChw8c:
-      case mkldnn_nChw16c:
-        return mkldnn_nchw;
-      case mkldnn_oihw:
-      case mkldnn_ihwo:
-      case mkldnn_hwio:
-      case mkldnn_iohw:
-      case mkldnn_oIhw8i:
-      case mkldnn_oIhw16i:
-      case mkldnn_OIhw4i4o:
-      case mkldnn_OIhw8i8o:
-      case mkldnn_hwio_s8s8:
-      case mkldnn_OIhw16i16o:
-      case mkldnn_OIhw4i16o4i:
-      case mkldnn_OIhw4i16o4i_s8s8:
-      case mkldnn_OIhw8i16o2i:
-      case mkldnn_OIhw8o16i2o:
-      case mkldnn_OIhw8o8i:
-      case mkldnn_OIhw16o16i:
-      case mkldnn_IOhw16o16i:
-      case mkldnn_Oihw8o:
-      case mkldnn_Oihw16o:
-      case mkldnn_Ohwi8o:
-      case mkldnn_Ohwi16o:
-      case mkldnn_OhIw16o4i:
-        return mkldnn_oihw;
-      case mkldnn_goiw:
-      case mkldnn_gOwi8o:
-      case mkldnn_gOIw8o8i:
-      case mkldnn_gOIw8i8o:
-      case mkldnn_gOIw16i16o:
-      case mkldnn_gOIw16o16i:
-      case mkldnn_gOiw16o:
-      case mkldnn_gOwi16o:
-      case mkldnn_gOIw8i16o2i:
-      case mkldnn_gOIw8o16i2o:
-      case mkldnn_gIOw16o16i:
-        return mkldnn_goiw;
-      default:
-        LOG(FATAL) << "Unknown MKLDNN format for 4 dimensions: " << desc.data.format;
-        return mkldnn_format_undef;
-    }
-  } else if (desc.data.ndims == 5) {
-    switch (desc.data.format) {
-      case mkldnn_goihw:
-      case mkldnn_giohw:
-      case mkldnn_hwigo:
-      case mkldnn_hwigo_s8s8:
-      case mkldnn_gOIhw4i4o:
-      case mkldnn_gOIhw8i8o:
-      case mkldnn_gOIhw16i16o:
-      case mkldnn_gOIhw4i16o4i:
-      case mkldnn_gOIhw4i16o4i_s8s8:
-      case mkldnn_gOIhw8i16o2i:
-      case mkldnn_gOIhw8o16i2o:
-      case mkldnn_gOIhw8o8i:
-      case mkldnn_gOIhw4o4i:
-      case mkldnn_gOIhw16o16i:
-      case mkldnn_gIOhw16o16i:
-      case mkldnn_gOihw8o:
-      case mkldnn_Goihw8g:
-      case mkldnn_gOihw16o:
-      case mkldnn_Goihw16g:
-      case mkldnn_gOhwi8o:
-      case mkldnn_gOhwi16o:
-      case mkldnn_gOhIw16o4i:
-      case mkldnn_Goihw16g_s8s8:
-        return mkldnn_goihw;
-      default:
-        LOG(FATAL) << "Unknown MKLDNN format for 5 dimensions: " << desc.data.format;
-        return mkldnn_format_undef;
-    }
-  } else {
-    LOG(FATAL) << "Unsupported dimensions: " << desc.data.ndims;
-    return mkldnn_format_undef;
-  }
+mkldnn_format_tag_t GetDefaultFormat(const mkldnn::memory::desc &desc) {
+  return GetDefaultFormat(desc.data.ndims);
+}
+
+bool IsDefaultFormat(const mkldnn::memory::desc &desc) {
+  bool rslt = false;
+  if (desc.data.format_kind == mkldnn_blocked) {
+    if (desc.data.format_desc.blocking.inner_nblks == 0) {
+      int i = 0;
+      for (i = 0; i < desc.data.ndims-1; i++) {
+        if (desc.data.format_desc.blocking.strides[i]
+            < desc.data.format_desc.blocking.strides[i + 1]) {
+          break;
+        }
+      }
+      if (i == desc.data.ndims-1) {
+        rslt = true;
+      }
+    }
+  }
+  return rslt;
 }
 
-mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd,
-                                                mkldnn_memory_format_t format) {
-  mkldnn::memory::dims dims(pd.desc().data.ndims);
+mkldnn::memory::desc GetDesc(const mkldnn::memory::desc &desc,
+                             const mkldnn_format_tag_t &format) {
+  mkldnn::memory::dims dims(desc.data.ndims);
   for (size_t i = 0; i < dims.size(); i++)
-    dims[i] = pd.desc().data.dims[i];
-  mkldnn::memory::format cpp_format = static_cast<mkldnn::memory::format>(format);
+    dims[i] = desc.data.dims[i];
+  mkldnn::memory::format_tag cpp_format = static_cast<mkldnn::memory::format_tag>(format);
   mkldnn::memory::data_type cpp_type = static_cast<mkldnn::memory::data_type>(
-      pd.desc().data.data_type);
+      desc.data.data_type);
   mkldnn::memory::desc data_md(dims, cpp_type, cpp_format);
-  return mkldnn::memory::primitive_desc(data_md, pd.get_engine());
+  return mkldnn::memory::desc(dims, cpp_type, cpp_format);
 }
 
 void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
@@ -508,10 +441,11 @@ static bool SimilarArray(const mxnet::NDArray &arr1, const mxnet::NDArray &arr2,
   std::atomic<bool> success(true);
 #pragma omp parallel for
 #ifdef _MSC_VER
-  for (int64_t i = 0; i < arr1.shape().Size(); i++) {
+  for (int64_t i = 0; i < arr1.shape().Size(); i++)
 #else
-  for (size_t i = 0; i < arr1.shape().Size(); i++) {
+  for (size_t i = 0; i < arr1.shape().Size(); i++)
 #endif
+  {
     if (std::abs(data1[i] - data2[i]) > atol + rtol * std::abs(data2[i]))
       success.store(false);
   }
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 502abff6231b..122ad9fd0686 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -26,7 +26,6 @@
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_OPS_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_OPS_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
 #include
 #include
@@ -36,11 +35,15 @@
 #include
 #include
 #include
+
+#if MXNET_USE_MKLDNN == 100
 #include <mkldnn.hpp>
+#endif
 
 namespace mxnet {
 namespace op {
 
+#if MXNET_USE_MKLDNN == 1
 /* For fully connected. */
 void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                      const std::vector<NDArray> &in_data,
@@ -110,9 +113,6 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
                               const NDArray &out_grad, const NDArray &in_data,
                               const OpReqType &req, const NDArray &in_grad);
 
-void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
-               const mkldnn::memory &out);
-
 void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
                             const OpContext &ctx,
                             const NDArray &data,
@@ -130,8 +130,14 @@ void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
                           const NDArray &input,
                           const OpReqType &req,
                           const NDArray &output);
+#endif
+
+#if MXNET_USE_MKLDNN == 100
+void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
+               const mkldnn::memory &out);
+#endif
+
 }  // namespace op
 }  // namespace mxnet
-#endif  // MXNET_USE_MKLDNN == 1
 #endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_OPS_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_sum.cc b/src/operator/nn/mkldnn/mkldnn_sum.cc
index 724b8a2613d6..69b6728fc0b5 100644
--- a/src/operator/nn/mkldnn/mkldnn_sum.cc
+++ b/src/operator/nn/mkldnn/mkldnn_sum.cc
@@ -28,37 +28,43 @@
 #include "./mkldnn_ops-inl.h"
 #include "./mkldnn_base-inl.h"
 
-#if MXNET_USE_MKLDNN == 1
 namespace mxnet {
 namespace op {
 
-void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
-               const mkldnn::memory &out) {
-  std::vector<mkldnn::memory::primitive_desc> input_pds(2);
+#if MXNET_USE_MKLDNN == 100
+void MKLDNNSum(const mkldnn::memory &arr1,
+               const mkldnn::memory &arr2,
+               const mkldnn::memory &out) {
+  std::vector<mkldnn::memory::desc> input_pds(2);
   std::vector<float> scales(2, 1);
-  std::vector<mkldnn::primitive::at> inputs;
-  input_pds[0] = arr1.get_primitive_desc();
-  input_pds[1] = arr2.get_primitive_desc();
+  input_pds[0] = arr1.get_desc();
+  input_pds[1] = arr2.get_desc();
   CHECK(input_pds[0] == input_pds[0]);
   const mkldnn::memory *in_mem1 = &arr1;
   const mkldnn::memory *in_mem2 = &arr2;
-  auto output_pd = out.get_primitive_desc();
+  auto output_pd = out.get_desc();
   if (input_pds[0] != output_pd) {
     auto tmp_memory1 = TmpMemMgr::Get()->Alloc(output_pd);
     auto tmp_memory2 = TmpMemMgr::Get()->Alloc(output_pd);
     mxnet::MKLDNNCopy(arr1, tmp_memory1);
     mxnet::MKLDNNCopy(arr2, tmp_memory2);
-    input_pds[0] = tmp_memory1->get_primitive_desc();
-    input_pds[1] = tmp_memory2->get_primitive_desc();
+    input_pds[0] = tmp_memory1->get_desc();
+    input_pds[1] = tmp_memory2->get_desc();
     in_mem1 = tmp_memory1;
     in_mem2 = tmp_memory2;
   }
-  inputs.push_back(*in_mem1);
-  inputs.push_back(*in_mem2);
-  mkldnn::sum::primitive_desc sum_pd(scales, input_pds);
-  MKLDNNStream::Get()->RegisterPrim(mkldnn::sum(sum_pd, inputs, out));
+  mkldnn::sum::primitive_desc sum_pd(output_pd, scales, input_pds, CpuEngine::Get()->get_engine());
+  std::unordered_map<int, mkldnn::memory> args = {
+    { MKLDNN_ARG_MULTIPLE_SRC, *in_mem1 },
+    { MKLDNN_ARG_MULTIPLE_SRC + 1, *in_mem2 },
+    { MKLDNN_ARG_DST, out },
+  };
+  MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::sum(sum_pd), args);
 }
+#endif
+
+#if MXNET_USE_MKLDNN == 1
 class MKLDNNSumFwd {
  public:
   mkldnn::sum::primitive_desc fwd_pd;
@@ -159,7 +165,7 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   CommitOutput(out_data, out_mem);
   MKLDNNStream::Get()->Submit();
 }
+#endif
 
 }  // namespace op
 }  // namespace mxnet
-#endif
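The new MKLDNNSum follows the 1.0 sum API: the primitive_desc takes the output desc, the scales, the source descs and the engine up front, and the sources are bound at execution time via MKLDNN_ARG_MULTIPLE_SRC. A self-contained sketch of the same pattern with illustrative sizes:

```cpp
#include <mkldnn.hpp>
#include <unordered_map>
#include <vector>

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream s(eng);
  memory::desc md({2, 3}, memory::data_type::f32, memory::format_tag::ab);
  memory a(md, eng), b(md, eng), out(md, eng);
  std::vector<float> scales(2, 1.f);
  std::vector<memory::desc> srcs{md, md};
  // 1.0 style: output desc, scales, source descs and engine in the pd ctor.
  sum::primitive_desc sum_pd(md, scales, srcs, eng);
  std::unordered_map<int, memory> args = {
      {MKLDNN_ARG_MULTIPLE_SRC, a},
      {MKLDNN_ARG_MULTIPLE_SRC + 1, b},
      {MKLDNN_ARG_DST, out},
  };
  sum(sum_pd).execute(s, args);  // out = a + b
  s.wait();
  return 0;
}
```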