Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[mkldnn-v1.0] Enable base code with new APIs. (#16064)
Browse files Browse the repository at this point in the history
* fix comments (#8)

* add base code for mkldnn 1.0

* fix comments

* Update mkldnn.mk

* add base code for mkldnn 1.0

* fix build

* fix lint

* fix lint
  • Loading branch information
TaoLv authored and pengzhao-intel committed Sep 10, 2019
1 parent 249f375 commit e0cbc92
Show file tree
Hide file tree
Showing 8 changed files with 363 additions and 436 deletions.
2 changes: 1 addition & 1 deletion include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ inline std::ostream& operator<<(std::ostream &out, const Context &ctx) {
#define ADD_FILELINE "\n\nDefined in " __FILE__ ":L" STRINGIZE(__LINE__)


#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
constexpr size_t kMKLDNNAlign = 64;
#endif

Expand Down
30 changes: 14 additions & 16 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
#include <algorithm>
#include <memory>
#include <algorithm>
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include <mkldnn.hpp>
#endif
#include "./base.h"
Expand Down Expand Up @@ -699,7 +699,7 @@ class NDArray {
ptr_->CheckAndAllocAuxData(i, aux_shape);
}

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
/*
* Create NDArray from mkldnn memory.
* mkldnn_mem The mkldnn memory to be managed.
Expand All @@ -709,7 +709,7 @@ class NDArray {
* Create NDArray from mkldnn memory descriptor.
* mem_pd The mkldnn memory descriptor to be created.
*/
explicit NDArray(mkldnn::memory::primitive_desc mem_pd);
explicit NDArray(const mkldnn::memory::desc &md);
/*
* Test if the data is stored in one of special MKLDNN format.
*/
Expand Down Expand Up @@ -737,15 +737,14 @@ class NDArray {
* This function returns mkldnn::memory with the given primitive_desc
* as long as the array size meets the required size in the given primitive_desc.
*/
const mkldnn::memory *GetMKLDNNData(
const mkldnn::memory::primitive_desc &desc) const;
const mkldnn::memory *GetMKLDNNData(const mkldnn::memory::desc &md) const;
/*
* This function returns mkldnn::memory with the given primitive_desc.
* The returned mkldnn::memory will have the same physical layout as
* the given primitive_desc.
*/
const mkldnn::memory *GetMKLDNNDataReorder(
const mkldnn::memory::primitive_desc &desc) const;
const mkldnn::memory::desc &md) const;

/*
* This function copies data from mkldnn memory.
Expand All @@ -755,16 +754,15 @@ class NDArray {
* This function allocates memory for array and creates mkldnn memory
* with the specified format.
*/
mkldnn::memory *CreateMKLDNNData(
const mkldnn::memory::primitive_desc &desc);
mkldnn::memory *CreateMKLDNNData(const mkldnn::memory::desc &md);

/*
* These are the async version of the methods above.
* It changes the layout of this NDArray, but it happens after all accesses to
* the array are complete.
*/
void Reorder2DefaultAsync();
void MKLDNNDataReorderAsync(const mkldnn::memory::primitive_desc &desc);
void MKLDNNDataReorderAsync(const mkldnn::memory::desc &md);

/*
* This creates a new NDArray with the reordered data.
Expand All @@ -789,7 +787,7 @@ class NDArray {
/*!
* \ Fix mkldnn memory descriptor mismatch from NDArray.
*/
void UpdateMKLDNNMemDesc(mkldnn::memory::format format);
void UpdateMKLDNNMemDesc(mkldnn::memory::format_tag format);
#endif

/*!
Expand Down Expand Up @@ -827,7 +825,7 @@ class NDArray {
*/
std::vector<Storage::Handle> aux_handles;

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
/*! This is created when data is stored in MKLDNN format.
*/
std::shared_ptr<MKLDNNMemory> mkl_mem_;
Expand Down Expand Up @@ -986,7 +984,7 @@ class NDArray {
inline void CheckAndAlloc(void) {
if (delay_alloc) {
shandle = Storage::Get()->Alloc(shandle.size, shandle.ctx);
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
mkl_mem_ = nullptr;
#endif
delay_alloc = false;
Expand All @@ -1001,7 +999,7 @@ class NDArray {
dbytes = std::max(dbytes, static_cast<uint64_t>(shandle.size));
if (delay_alloc) {
shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
mkl_mem_ = nullptr;
#endif
delay_alloc = false;
Expand All @@ -1010,7 +1008,7 @@ class NDArray {
Storage::Get()->Free(shandle);
// init storage
shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
mkl_mem_ = nullptr;
#endif
}
Expand Down Expand Up @@ -1046,15 +1044,15 @@ class NDArray {
// and allocate new storage
void CheckAndAllocData(const mxnet::TShape &shape, int dtype);

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
// Have MKL memory reference to the data in the default storage
// or create memory for MKLDNN.
void SetMKLMem(const mxnet::TShape &shape, int dtype);
// If the data is stored in MKLDNN layout, we reorder data in mkl_mem_ and
// save the result in shandle.
void Reorder2Default();
// Reroder data to a specified layout.
void MKLDNNDataReorder(const mkldnn::memory::primitive_desc &desc);
void MKLDNNDataReorder(const mkldnn::memory::desc &md);
bool IsMKLDNN() const;
bool IsDefault() const;
#endif
Expand Down
2 changes: 2 additions & 0 deletions mkldnn.mk
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ mkldnn_build: $(MKLDNN_LIBFILE)

$(MKLDNN_LIBFILE):
mkdir -p $(MKLDNNROOT)
mkdir -p $(MKLDNNROOT)/lib
cmake $(MKLDNN_SUBMODDIR) -DCMAKE_INSTALL_PREFIX=$(MKLDNNROOT) -B$(MKLDNN_BUILDDIR) -DMKLDNN_ARCH_OPT_FLAGS="" -DMKLDNN_BUILD_TESTS=OFF -DMKLDNN_BUILD_EXAMPLES=OFF -DMKLDNN_ENABLE_JIT_PROFILING=OFF
$(MAKE) -C $(MKLDNN_BUILDDIR) VERBOSE=1
$(MAKE) -C $(MKLDNN_BUILDDIR) install
mkdir -p $(MXNET_LIBDIR)
if [ -f "$(MKLDNN_LIB64FILE)" ]; then \
cp $(MKLDNNROOT)/lib64/libmkldnn* $(MXNET_LIBDIR); \
cp $(MKLDNNROOT)/lib64/libmkldnn* $(MKLDNNROOT)/lib/; \
else \
cp $(MKLDNNROOT)/lib/libmkldnn* $(MXNET_LIBDIR); \
fi
Expand Down
Loading

0 comments on commit e0cbc92

Please sign in to comment.