diff --git a/.gitmodules b/.gitmodules
index 170c105a6f48..0266fd2448e9 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -22,3 +22,7 @@
 [submodule "3rdparty/googletest"]
   path = 3rdparty/googletest
   url = https://github.com/google/googletest.git
+[submodule "3rdparty/mkldnn"]
+  path = 3rdparty/mkldnn
+  url = https://github.com/ashokei/mkl-dnn.git
+  branch = master
diff --git a/3rdparty/mkldnn b/3rdparty/mkldnn
new file mode 160000
index 000000000000..e9ef04c277c7
--- /dev/null
+++ b/3rdparty/mkldnn
@@ -0,0 +1 @@
+Subproject commit e9ef04c277c7ceebb09efd8739368e7bf02023ca
diff --git a/Jenkinsfile b/Jenkinsfile
index b7a8f60cb9b6..4fc12f3dab6f 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -4,6 +4,7 @@
 // mxnet libraries
 mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, dmlc-core/libdmlc.a, nnvm/lib/libnnvm.a'
+mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libiomp5.so, lib/libmklml_gnu.so, lib/libmkldnn.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, dmlc-core/libdmlc.a, nnvm/lib/libnnvm.a'
 // command to start a docker container
 docker_run = 'tests/ci_build/ci_build.sh'
 // timeout in minutes
@@ -122,7 +123,7 @@ def python3_gpu_ut(docker_type) {
 }

 // Python 2
-def python2_mklml_ut(docker_type) {
+def python2_mkldnn_ut(docker_type) {
   timeout(time: max_time, unit: 'MINUTES') {
     sh "${docker_run} ${docker_type} find . -name '*.pyc' -type f -delete"
     sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/cpu"
@@ -130,7 +131,7 @@ }
 }

 // Python 3
-def python3_mklml_ut(docker_type) {
+def python3_mkldnn_ut(docker_type) {
   timeout(time: max_time, unit: 'MINUTES') {
     sh "${docker_run} ${docker_type} find . -name '*.pyc' -type f -delete"
     sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/cpu"
   }
 }
@@ -204,42 +205,40 @@ try {
        }
      }
    },
-   'CPU: MKLML': {
+   'CPU: MKLDNN': {
     node('mxnetlinux-cpu') {
-      ws('workspace/build-mklml-cpu') {
+      ws('workspace/build-mkldnn-cpu') {
         init_git()
         def flag = """ \
           DEV=1 \
           USE_PROFILER=1 \
           USE_CPP_PACKAGE=1 \
           USE_BLAS=openblas \
-          USE_MKL2017=1 \
-          USE_MKL2017_EXPERIMENTAL=1 \
+          USE_MKLDNN=1 \
           -j\$(nproc)
           """
         make("cpu_mklml", flag)
-        pack_lib('mklml_cpu')
+        pack_lib('mkldnn_cpu', mx_mkldnn_lib)
       }
     }
   },
-  'GPU: MKLML': {
+  'GPU: MKLDNN': {
     node('mxnetlinux-cpu') {
-      ws('workspace/build-mklml-gpu') {
+      ws('workspace/build-mkldnn-gpu') {
         init_git()
         def flag = """ \
           DEV=1 \
           USE_PROFILER=1 \
           USE_CPP_PACKAGE=1 \
           USE_BLAS=openblas \
-          USE_MKL2017=1 \
-          USE_MKL2017_EXPERIMENTAL=1 \
+          USE_MKLDNN=1 \
           USE_CUDA=1 \
           USE_CUDA_PATH=/usr/local/cuda \
           USE_CUDNN=1 \
           -j\$(nproc)
           """
         make("build_cuda", flag)
-        pack_lib('mklml_gpu')
+        pack_lib('mkldnn_gpu', mx_mkldnn_lib)
       }
     }
   },
@@ -386,43 +385,43 @@ try {
       }
     }
   },
-  'Python2: MKLML-CPU': {
+  'Python2: MKLDNN-CPU': {
     node('mxnetlinux-cpu') {
-      ws('workspace/ut-python2-mklml-cpu') {
+      ws('workspace/ut-python2-mkldnn-cpu') {
         init_git()
-        unpack_lib('mklml_cpu')
+        unpack_lib('mkldnn_cpu', mx_mkldnn_lib)
         python2_ut('cpu_mklml')
-        python2_mklml_ut('cpu_mklml')
+        python2_mkldnn_ut('cpu_mklml')
       }
     }
   },
-  'Python2: MKLML-GPU': {
+  'Python2: MKLDNN-GPU': {
     node('mxnetlinux-gpu') {
-      ws('workspace/ut-python2-mklml-gpu') {
+      ws('workspace/ut-python2-mkldnn-gpu') {
         init_git()
-        unpack_lib('mklml_gpu')
+        unpack_lib('mkldnn_gpu', mx_mkldnn_lib)
         python2_gpu_ut('gpu_mklml')
-        python2_mklml_ut('gpu_mklml')
+        python2_mkldnn_ut('gpu_mklml')
      }
    }
  },
-  'Python3: MKLML-CPU': {
+  'Python3: MKLDNN-CPU': {
    node('mxnetlinux-cpu') {
-      ws('workspace/ut-python3-mklml-cpu') {
+      ws('workspace/ut-python3-mkldnn-cpu') {
        init_git()
-        unpack_lib('mklml_cpu')
+        unpack_lib('mkldnn_cpu', mx_mkldnn_lib)
        python3_ut('cpu_mklml')
-        python3_mklml_ut('cpu_mklml')
+        python3_mkldnn_ut('cpu_mklml')
      }
    }
  },
-  'Python3: MKLML-GPU': {
+  'Python3: MKLDNN-GPU': {
    node('mxnetlinux-gpu') {
-      ws('workspace/ut-python3-mklml-gpu') {
+      ws('workspace/ut-python3-mkldnn-gpu') {
        init_git()
-        unpack_lib('mklml_gpu')
+        unpack_lib('mkldnn_gpu', mx_mkldnn_lib)
        python3_gpu_ut('gpu_mklml')
-        python3_mklml_ut('gpu_mklml')
+        python3_mkldnn_ut('gpu_mklml')
      }
    }
  },
diff --git a/Makefile b/Makefile
index aae0ba91a75f..2737596cec5c 100644
--- a/Makefile
+++ b/Makefile
@@ -42,11 +42,11 @@ endif
 # use customized config file
 include $(config)

-ifeq ($(USE_MKL2017), 1)
-# must run ./prepare_mkl before including mshadow.mk
-	RETURN_STRING := $(shell ./prepare_mkl.sh $(MKLML_ROOT))
-	MKLROOT := $(firstword $(RETURN_STRING))
-	export USE_MKLML = $(lastword $(RETURN_STRING))
+ifeq ($(USE_MKLDNN), 1)
+	RETURN_STRING := $(shell ./prepare_mkldnn.sh $(MKLDNN_ROOT))
+	MKLDNNROOT := $(firstword $(RETURN_STRING))
+	MKLROOT := $(lastword $(RETURN_STRING))
+	export USE_MKLML = 1
 endif

 include mshadow/make/mshadow.mk
@@ -114,23 +114,20 @@ ifeq ($(USE_NNPACK), 1)
 	LDFLAGS += -lnnpack
 endif

-ifeq ($(USE_MKL2017), 1)
-	CFLAGS += -DMXNET_USE_MKL2017=1
+ifeq ($(USE_MKLDNN), 1)
+	CFLAGS += -DMXNET_USE_MKLDNN=1
 	CFLAGS += -DUSE_MKL=1
-	CFLAGS += -I$(ROOTDIR)/src/operator/mkl/
-	CFLAGS += -I$(MKLML_ROOT)/include
-	LDFLAGS += -L$(MKLML_ROOT)/lib
-	ifeq ($(USE_MKL2017_EXPERIMENTAL), 1)
-		CFLAGS += -DMKL_EXPERIMENTAL=1
-	else
-		CFLAGS += -DMKL_EXPERIMENTAL=0
-	endif
-	ifeq ($(UNAME_S), Darwin)
-		LDFLAGS += -lmklml
-	else
-		LDFLAGS += -Wl,--as-needed -lmklml_intel -lmklml_gnu
+	CFLAGS += -I$(ROOTDIR)/src/operator/nn/mkldnn/
+	ifneq ($(MKLDNNROOT), $(MKLROOT))
+		CFLAGS += -I$(MKLROOT)/include
+		LDFLAGS += -L$(MKLROOT)/lib
 	endif
-	LDFLAGS += -liomp5
+	CFLAGS += -I$(MKLDNNROOT)/include
+	LDFLAGS += -L$(MKLDNNROOT)/lib -lmkldnn -Wl,-rpath,'$${ORIGIN}'
+endif
+
+ifeq ($(BN_DEBUG), 1)
+	CFLAGS += -DMXNET_BN_DEBUG=1
 endif

 ifeq ($(USE_OPERATOR_TUNING), 1)
@@ -144,7 +141,7 @@ endif
 # - for Ubuntu, installing atlas will not automatically install the atlas provided lapack library
 # silently switching lapack off instead of letting the build fail because of backward compatibility
 ifeq ($(USE_LAPACK), 1)
-ifeq ($(USE_BLAS),$(filter $(USE_BLAS),blas openblas atlas))
+ifeq ($(USE_BLAS),$(filter $(USE_BLAS),blas openblas atlas mkl))
 ifeq (,$(wildcard /lib/liblapack.a))
 ifeq (,$(wildcard /usr/lib/liblapack.a))
 ifeq (,$(wildcard /usr/lib64/liblapack.a))
@@ -162,7 +159,7 @@ ifeq ($(USE_LAPACK), 1)
 	ifneq ($(USE_LAPACK_PATH), )
 		LDFLAGS += -L$(USE_LAPACK_PATH)
 	endif
-	ifeq ($(USE_BLAS),$(filter $(USE_BLAS),blas openblas atlas))
+	ifeq ($(USE_BLAS),$(filter $(USE_BLAS),blas openblas atlas mkl))
 		LDFLAGS += -llapack
 	endif
 	CFLAGS += -DMXNET_USE_LAPACK
@@ -548,7 +545,8 @@ clean: cyclean $(EXTRA_PACKAGES_CLEAN)
 else
 clean: cyclean testclean $(EXTRA_PACKAGES_CLEAN)
 	$(RM) -r build lib bin *~ */*~ */*/*~ */*/*/*~ R-package/NAMESPACE R-package/man R-package/R/mxnet_generated.R \
-		R-package/inst R-package/src/image_recordio.h R-package/src/*.o R-package/src/*.so mxnet_*.tar.gz
+		R-package/inst R-package/src/image_recordio.h R-package/src/*.o R-package/src/*.so mxnet_*.tar.gz \
+		external/mkldnn/install/*
 	cd $(DMLC_CORE); $(MAKE) clean; cd -
 	cd $(PS_PATH); $(MAKE) clean; cd -
 	cd $(NNVM_PATH); $(MAKE) clean; cd -
diff --git a/amalgamation/mxnet_predict0.cc b/amalgamation/mxnet_predict0.cc
index f35591d82b22..cfee60559501 100644
--- a/amalgamation/mxnet_predict0.cc
+++ b/amalgamation/mxnet_predict0.cc
@@ -66,7 +66,7 @@
 #include "src/operator/operator_util.cc"
 #include "src/operator/nn/activation.cc"
 #include "src/operator/nn/batch_norm.cc"
-#include "src/operator/concat.cc"
+#include "src/operator/nn/concat.cc"
 #include "src/operator/nn/convolution.cc"
 #include "src/operator/nn/deconvolution.cc"
 #include "src/operator/nn/dropout.cc"
diff --git a/example/image-classification/common/data.py b/example/image-classification/common/data.py
index dc8915cda4c8..05f5ddc4506e 100755
--- a/example/image-classification/common/data.py
+++ b/example/image-classification/common/data.py
@@ -112,7 +112,8 @@ def get_rec_iter(args, kv=None):
     image_shape = tuple([int(l) for l in args.image_shape.split(',')])
     if 'benchmark' in args and args.benchmark:
         data_shape = (args.batch_size,) + image_shape
-        train = SyntheticDataIter(args.num_classes, data_shape, 500, np.float32)
+        train = SyntheticDataIter(args.num_classes, data_shape,
+                                  args.num_examples / args.batch_size, np.float32)
         return (train, None)
     if kv:
         (rank, nworker) = (kv.rank, kv.num_workers)
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index a18d2daec8c3..eee200022873 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -35,12 +35,13 @@
 #include
 #include
 #include
+#include
+#if MXNET_USE_MKLDNN == 1
+#include
+#endif
 #include "./base.h"
 #include "./storage.h"
 #include "./engine.h"
-#if MKL_EXPERIMENTAL == 1
-#include
-#endif
 // check c++11
 #if DMLC_USE_CXX11 == 0
 #error "cxx11 was required for ndarray module"
#endif
@@ -61,6 +62,9 @@ enum NDArrayStorageType {
   kDefaultStorage,         // dense
   kRowSparseStorage,       // row sparse
   kCSRStorage,             // csr
+#if MXNET_USE_MKLDNN == 1
+  kMKLDNNStorage,          // MKLDNN
+#endif
 };

 enum NDArrayFormatErr {
@@ -72,6 +76,7 @@
   kRSPIdxErr,       // indices error for row sparse
 };

+class MKLDNNMemory;

 /*!
  * \brief ndarray interface
  */
 class NDArray {
  public:
   /*! \brief default constructor */
   NDArray() {
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = MKLMemHolder::create();
-#endif
   }
   /*!
    * \brief constructs a new dynamic NDArray
    */
   NDArray(const TShape &shape, Context ctx,
           bool delay_alloc = false, int dtype = mshadow::default_type_flag)
       : ptr_(std::make_shared<Chunk>(shape, ctx, delay_alloc, dtype)),
         shape_(shape), dtype_(dtype), storage_type_(kDefaultStorage),
         entry_({nullptr, 0, 0}) {
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = std::make_shared<MKLMemHolder>();
-#endif
   }
   /*! \brief constructor for NDArray with storage type */
   NDArray(const NDArrayStorageType stype, const TShape &shape, Context ctx,
           bool delay_alloc = true, int dtype = mshadow::default_type_flag,
           std::vector<int> aux_types = {}, std::vector<TShape> aux_shapes = {},
-          TShape storage_shape = TShape(mshadow::Shape1(0)))
-      : shape_(shape), dtype_(dtype), storage_type_(stype),
-        entry_({nullptr, 0, 0}) {
-    // Assign default aux types if not given
-    if (aux_types.size() == 0) {
-      if (stype == kRowSparseStorage) {
-        aux_types = {mshadow::kInt64};
-      } else if (stype == kCSRStorage) {
-        aux_types = {mshadow::kInt64, mshadow::kInt64};
-      } else {
-        LOG(FATAL) << "Unknown storage type " << stype;
-      }
-    }
-    // Assign default shapes if not given
-    // unknown shapes are intialized as {0} such that Size() would return 0
-    if (aux_shapes.size() == 0) {
-      if (stype == kRowSparseStorage) {
-        aux_shapes = {TShape(mshadow::Shape1(0))};
-      } else if (stype == kCSRStorage) {
-        // aux shapes for indptr and indices
-        aux_shapes = {TShape(mshadow::Shape1(0)), TShape(mshadow::Shape1(0))};
-      } else {
-        LOG(FATAL) << "Unknown storage type " << stype;
-      }
-    }
-    if (storage_shape.Size() == 0) {
-      if (stype == kRowSparseStorage) {
-        storage_shape = shape;
-        storage_shape[0] = aux_shapes[rowsparse::kIdx][0];
-      } else if (stype == kCSRStorage) {
-        storage_shape = aux_shapes[csr::kIdx];
-      } else {
-        LOG(FATAL) << "Unknown storage type " << stype;
-      }
-    }
-    ptr_ = std::make_shared<Chunk>(stype, storage_shape, ctx, delay_alloc,
-                                   dtype, aux_types, aux_shapes);
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = std::make_shared<MKLMemHolder>();
-#endif
-  }
+          TShape storage_shape = TShape(mshadow::Shape1(0)));
+
   /*!
    * \brief constructing a static NDArray that shares data with TBlob
    * Use with caution: allocate ONLY ONE NDArray for each TBlob,
    */
   NDArray(const TBlob &data, int dev_id)
       : ptr_(std::make_shared<Chunk>(data, dev_id)), shape_(data.shape_),
         dtype_(data.type_flag_), storage_type_(kDefaultStorage),
         entry_({nullptr, 0, 0}) {
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = std::make_shared<MKLMemHolder>();
-#endif
   }

   /*! \brief create ndarray from shared memory */
   NDArray(int shared_pid, int shared_id, const TShape& shape, int dtype)
       : ptr_(std::make_shared<Chunk>(shared_pid, shared_id, shape, dtype)),
         shape_(shape), dtype_(dtype), storage_type_(kDefaultStorage),
         entry_({nullptr, 0, 0}) {
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = std::make_shared<MKLMemHolder>();
-#endif
   }

   /*!
@@ -184,11 +138,17 @@
           const TBlob &data, const std::vector<TBlob> &aux_data, int dev_id)
       : ptr_(std::make_shared<Chunk>(stype, data, aux_data, dev_id)), shape_(shape),
         dtype_(data.type_flag_), storage_type_(stype), entry_({nullptr, 0, 0}) {
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = std::make_shared<MKLMemHolder>();
-#endif
   }

+  inline bool IsView() const {
+    // Sparse arrays don't have a view.
+    if (storage_type() == kRowSparseStorage || storage_type() == kCSRStorage)
+      return false;
+    // If the array reuses memory, it's not a view.
+    if (reuse_)
+      return false;
+    return byte_offset_ > 0 || shape() != ptr_->storage_shape;
+  }

   /*!
    * \return the shape of current NDArray.
@@ -271,9 +231,6 @@
           << "Unexpected storage type: " << stype;
       res = TBlob(dptr, shape, ptr_->aux_handles[i].ctx.dev_mask(), type);
     });
-#if MKL_EXPERIMENTAL == 1
-    res.Mkl_mem_ = Mkl_mem_;
-#endif
     return res;
   }
   /*!
@@ -534,15 +491,12 @@ class NDArray {
     CHECK_GE(ptr_->shandle.size,
              shape.Size() * mshadow::mshadow_sizeof(dtype))
         << "NDArray.AsArray: target memory size is bigger";
-#if MKL_EXPERIMENTAL == 1
-    if (Mkl_mem_ != nullptr) {
-      // convert prv to cpu
-      Mkl_mem_->check_and_prv_to_cpu(ptr_->shandle.dptr);
-    }
-#endif
+    // We can't reuse memory in a view.
+    CHECK(!IsView());
     NDArray ret = *this;
     ret.shape_ = shape;
     ret.dtype_ = dtype;
+    ret.reuse_ = true;
     return ret;
   }
   /*!
@@ -611,6 +565,64 @@
         << "CheckAndAllocAuxData is not intended for kDefaultStorage";
     ptr_->CheckAndAllocAuxData(i, aux_shape);
   }
+
+#if MXNET_USE_MKLDNN == 1
+  bool IsMKLDNN() const {
+    return ptr_->IsMKLDNN();
+  }
+  bool IsDefault() const {
+    return ptr_->IsDefault();
+  }
+  /*
+   * All functions below return a raw pointer to mkldnn memory. Actually there
+   * is a shared pointer that holds the memory either in NDArray or in the MKLDNN
+   * stream. As long as we call these functions inside an operator, the returned
+   * memory is always valid.
+   */
+
+  /*
+   * This function returns mkldnn::memory with the default primitive_desc.
+   */
+  const mkldnn::memory *GetMKLDNNData() const;
+  /*
+   * This function returns mkldnn::memory with the given primitive_desc
+   * as long as the array size meets the required size in the given primitive_desc.
+   */
+  const mkldnn::memory *GetMKLDNNData(
+      const mkldnn::memory::primitive_desc &desc) const;
+  /*
+   * This function returns mkldnn::memory with the given primitive_desc.
+   * The returned mkldnn::memory will have the same physical layout as
+   * the given primitive_desc.
+   */
+  const mkldnn::memory *GetMKLDNNDataReorder(
+      const mkldnn::memory::primitive_desc &desc) const;
+
+  void CopyFrom(const mkldnn::memory &mem);
+  mkldnn::memory *CreateMKLDNNData(
+      const mkldnn::memory::primitive_desc &desc);
+
+  /*
+   * Reorder the memory to the specified layout.
+   */
+  void Reorder(const mkldnn::memory::primitive_desc &desc);
+  void Reorder2Default() {
+    CHECK_EQ(storage_type(), kDefaultStorage);
+    ptr_->Reorder2Default();
+  }
+
+  void InvalidateData() {
+    // When we invalidate data, we don't need to care about the MKLDNN format.
+    ptr_->Mkl_mem_ = nullptr;
+  }
+
+  /*
+   * This function is used inside operators to reshape an array.
+   * It's used by FullyConnected right now.
+   */
+  NDArray ReshapeMKLDNN(const TShape &shape) const;
+#endif
+
   /*!
    * \brief Save list of ndarray into the Stream.x
    * \param fo The stream of output.
@@ -645,6 +657,12 @@
        for csr, aux_handles[0] = indptr, aux_handles[1] = indices
     */
     std::vector<Storage::Handle> aux_handles;
+
+#if MXNET_USE_MKLDNN == 1
+    /*! This is created when data is stored in MKLDNN format.
+     */
+    std::shared_ptr<mkldnn::memory> Mkl_mem_;
+#endif
     /*! \brief variable from engine */
     Engine::VarHandle var;
     /*!
@@ -706,7 +724,7 @@ class NDArray {
         : static_data(false), delay_alloc(false) {
       var = Engine::Get()->NewVariable();
       ctx = Context::CPUShared(0);
-      shandle.size = shape.Size() * mshadow::mshadow_sizeof(dtype);;
+      shandle.size = shape.Size() * mshadow::mshadow_sizeof(dtype);
       shandle.ctx = ctx;
       shandle.shared_pid = shared_pid;
       shandle.shared_id = shared_id;
@@ -781,6 +799,9 @@
     inline void CheckAndAlloc(void) {
       if (delay_alloc) {
         shandle = Storage::Get()->Alloc(shandle.size, shandle.ctx);
+#if MXNET_USE_MKLDNN == 1
+        Mkl_mem_ = nullptr;
+#endif
         delay_alloc = false;
       }
     }
@@ -789,15 +810,22 @@
     // size is the number of bytes
     void CheckAndAlloc(uint64_t dbytes) {
       CHECK_EQ(kDefaultStorage, storage_type)
-          << "CheckAndAlloc(dbytes) is not intended for kDefaultStorage";
+          << "CheckAndAlloc(dbytes) is not intended for kDefaultStorage";
+      dbytes = std::max(dbytes, shandle.size);
       if (delay_alloc) {
         shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
+#if MXNET_USE_MKLDNN == 1
+        Mkl_mem_ = nullptr;
+#endif
         delay_alloc = false;
       } else if (shandle.size < dbytes) {
         // free storage if necessary and alloc again
         if (shandle.size > 0) Storage::Get()->Free(shandle);
         // init storage
         shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
+#if MXNET_USE_MKLDNN == 1
+        Mkl_mem_ = nullptr;
+#endif
       }
     }
@@ -823,20 +851,24 @@
     // storage shape is also updated
    // if data is already allocated, try reuse the storage. Otherwise, free the current one
    // and allocate new storage
-    inline void CheckAndAllocData(const TShape &shape, int dtype) {
-      CHECK_NE(aux_shapes.size(), 0) << "data is expected to be allocated after aux_data";
-      auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype);
-      if (shandle.size < dbytes) {
-        // free storage if necessary and alloc again
-        if (shandle.size > 0) Storage::Get()->Free(shandle);
-        // init storage
-        shandle = Storage::Get()->Alloc(dbytes, ctx);
-      }
-      // init shape
-      storage_shape = shape;
-      // delay_alloc is only set when data storage handle is present
-      delay_alloc = false;
+    void CheckAndAllocData(const TShape &shape, int dtype);
+
+#if MXNET_USE_MKLDNN == 1
+    // Have MKL memory reference to the data in the default storage
+    // or create memory for MKLDNN.
+    void SetMKLMem(const TShape &shape, int dtype);
+    void ResetMKLMem() {
+      // If Mkl_mem_ isn't referencing shandle, we need to reset Mkl_mem_.
+      if (Mkl_mem_ && Mkl_mem_->get_data_handle() != shandle.dptr)
+        Mkl_mem_ = nullptr;
     }
+    // If the data is stored in MKLDNN layout, we reorder the data in Mkl_mem_ and
+    // save the result in shandle.
+    void Reorder2Default();
+    bool IsMKLDNN() const;
+    bool IsDefault() const;
+#endif
+
    // create storage handle for aux data based on shape
    // this function assumes ctx, aux shapes and aux types are set
    // aux shape is also updated
@@ -862,45 +894,11 @@
       set_aux_shape(i, shape);
     }
     /*! \brief destructor */
-    ~Chunk() {
-      bool skip_free = static_data || delay_alloc;
-      Storage::Handle h = this->shandle;
-      std::vector<Storage::Handle> aux_h = this->aux_handles;
-      Engine::Get()->DeleteVariable([h, aux_h, skip_free](RunContext s) {
-        if (skip_free == false) {
-          Storage::Get()->Free(h);
-          for (size_t i = 0; i < aux_h.size(); i++) {
-            if (aux_h[i].size > 0) Storage::Get()->Free(aux_h[i]);
-          }
-        }
-      }, shandle.ctx, var);
-    }
+    ~Chunk();
   };  // struct Chunk

-  void SetTBlob() const {
-    CHECK(ptr_ != nullptr);
-    TShape shape = shape_;
-    char *dptr = static_cast<char*>(ptr_->shandle.dptr);
-    auto stype = storage_type();
-    if (stype == kDefaultStorage) {
-      dptr += byte_offset_;
-    } else if (stype == kCSRStorage || stype == kRowSparseStorage) {
-      shape = storage_shape();
-    } else {
-      LOG(FATAL) << "unknown storage type " << stype;
-    }
-    tblob_.dptr_ = dptr;
-    tblob_.shape_ = shape;
-    tblob_.type_flag_ = dtype_;
-    tblob_.SetDLTensor(ptr_->shandle.ctx.dev_mask(), ptr_->shandle.ctx.dev_id);
-#if MKL_EXPERIMENTAL == 1
-    tblob_.Mkl_mem_ = Mkl_mem_;
-#endif
-  }
+  void SetTBlob() const;

-#if MKL_EXPERIMENTAL == 1
-  std::shared_ptr<MKLMemHolder> Mkl_mem_;
-#endif
   /*! \brief internal data of NDArray */
   std::shared_ptr<Chunk> ptr_{nullptr};
   /*! \brief shape of current NDArray */
@@ -909,6 +907,8 @@
   size_t byte_offset_ = 0;
   /*! \brief type of data */
   int dtype_ = -1;
+  /*! \brief whether the NDArray uses memory of another NDArray. */
+  bool reuse_ = false;
   /*! \brief storage type of data */
   NDArrayStorageType storage_type_ = kUndefinedStorage;
   /*! \brief node entry for autograd */
diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h
index b65cd2b434e4..168ddcca24b7 100755
--- a/include/mxnet/tensor_blob.h
+++ b/include/mxnet/tensor_blob.h
@@ -36,9 +36,6 @@
 #include
 #include
 #include "./base.h"
-#if MXNET_USE_MKL2017 == 1
-#include
-#endif
 namespace mxnet {

 /* Forward declaration for friend declaration in TBlob */
@@ -66,17 +63,10 @@ class TBlob {
   /*! \brief type flag of the tensor blob */
   int type_flag_;

-  /*! \brief storing mkl chunk buffer blob, use for experimental only */
-#if MKL_EXPERIMENTAL == 1
-  std::shared_ptr<MKLMemHolder> Mkl_mem_;
-#endif
   /*! \brief default constructor, default copy assign will work */
   TBlob(void)
       : dptr_(NULL),
         type_flag_(mshadow::DataType<real_t>::kFlag) {
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = NULL;
-#endif
     SetDLTensor(cpu::kDevMask, 0);
   }
   /*!
@@ -90,9 +80,6 @@
   TBlob(DType *dptr, const TShape &shape, int dev_mask, int dev_id = -1)
       : dptr_(dptr), shape_(shape),
         type_flag_(mshadow::DataType<DType>::kFlag) {
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = NULL;
-#endif
     SetDLTensor(dev_mask, dev_id);
   }
   /*!
@@ -105,9 +92,6 @@
   */
   TBlob(void *dptr, const TShape &shape, int dev_mask, int type_flag, int dev_id = -1)
       : dptr_(dptr), shape_(shape), type_flag_(type_flag) {
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = NULL;
-#endif
     SetDLTensor(dev_mask, dev_id);
   }
   /*!
@@ -135,9 +119,6 @@
     shape_ = src.shape_;
     type_flag_ = mshadow::DataType<DType>::kFlag;
     SetDLTensor(Device::kDevMask, -1);
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = NULL;
-#endif
     return *this;
   }
   /*!
@@ -172,11 +153,6 @@
     CHECK(mshadow::DataType<DType>::kFlag == type_flag_)
       << "TBlob.get_with_shape: data type do not match specified type."
      << "Expected: " << type_flag_ << " v.s. given " << mshadow::DataType<DType>::kFlag;
given " << mshadow::DataType::kFlag; -#if MKL_EXPERIMENTAL == 1 - if (Mkl_mem_ != nullptr) { - Mkl_mem_->check_and_prv_to_cpu(dptr_); - } -#endif return mshadow::Tensor(static_cast(dptr_), shape_.FlatTo2D(), shape_[shape_.ndim() - 1], @@ -217,11 +193,6 @@ class TBlob { CHECK(mshadow::DataType::kFlag == type_flag_) << "TBlob.get_with_shape: data type do not match specified type." << "Expected: " << type_flag_ << " v.s. given " << mshadow::DataType::kFlag; -#if MKL_EXPERIMENTAL == 1 - if (Mkl_mem_ != nullptr) { - Mkl_mem_->check_and_prv_to_cpu(dptr_); - } -#endif return static_cast(dptr_); } /*! \brief device mask of the corresponding device */ diff --git a/prepare_mkldnn.sh b/prepare_mkldnn.sh new file mode 100755 index 000000000000..525ee14775cf --- /dev/null +++ b/prepare_mkldnn.sh @@ -0,0 +1,114 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# set -ex +# +# All modification made by Intel Corporation: © 2016 Intel Corporation +# +# All contributions by the University of California: +# Copyright (c) 2014, 2015, The Regents of the University of California (Regents) +# All rights reserved. +# +# All other contributions: +# Copyright (c) 2014, 2015, the respective contributors +# All rights reserved. +# For the list of contributors go to https://github.com/BVLC/caffe/blob/master/CONTRIBUTORS.md +# +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+MXNET_ROOTDIR="$(pwd)"
+MKLDNN_ROOTDIR="$MXNET_ROOTDIR/3rdparty/mkldnn/"
+MKLDNN_SRCDIR="$MKLDNN_ROOTDIR/src"
+MKLDNN_BUILDDIR="$MKLDNN_ROOTDIR/build"
+MKLDNN_INSTALLDIR="$MKLDNN_ROOTDIR/install"
+MKLDNN_LIBDIR="$MXNET_ROOTDIR/lib"
+
+# MKLDNN install destination
+HOME_MKLDNN=$1
+if [ ! -z "$HOME_MKLDNN" ]; then
+  mkdir -p $HOME_MKLDNN
+  if [ ! -w $HOME_MKLDNN ]; then
+    echo "MKLDNN install to $HOME_MKLDNN failed, please try with sudo" >&2
+    exit 1
+  fi
+fi
+
+if [ -z $MKLDNNROOT ]; then
+if [ ! -f "$MKLDNN_INSTALLDIR/lib/libmkldnn.so" ]; then
+    mkdir -p $MKLDNN_INSTALLDIR
+    cd $MKLDNN_ROOTDIR
+    if [ -z $MKLROOT ] && [ ! -f $MKLDNN_INSTALLDIR/include/mkl_cblas.h ]; then
+        rm -rf external && cd scripts && ./prepare_mkl.sh && cd ..
+        cp -a external/*/* $MKLDNN_INSTALLDIR/.
+    fi
+    echo "Building MKLDNN ..." >&2
+    cd $MXNET_ROOTDIR
+    g++ --version >&2
+    cmake $MKLDNN_ROOTDIR -DCMAKE_INSTALL_PREFIX=$MKLDNN_INSTALLDIR -B$MKLDNN_BUILDDIR
+    make -C $MKLDNN_BUILDDIR -j$(cat /proc/cpuinfo | grep processor | wc -l) VERBOSE=1 >&2
+    make -C $MKLDNN_BUILDDIR install
+    rm -rf $MKLDNN_BUILDDIR
+    mkdir -p $MKLDNN_LIBDIR
+    cp $MKLDNN_INSTALLDIR/lib/* $MKLDNN_LIBDIR
+fi
+MKLDNNROOT=$MKLDNN_INSTALLDIR
+fi
+
+if [ -z $MKLROOT ] && [ -f $MKLDNNROOT/include/mkl_cblas.h ]; then
+  MKLROOT=$MKLDNNROOT;
+fi
+
+# user specified MKLDNN install folder
+if [ -d "$HOME_MKLDNN" ]; then
+  # skip if user specified MKLDNNROOT
+  [ "$MKLDNNROOT" != "$HOME_MKLDNN" ] && rsync -a $MKLDNNROOT/include $MKLDNNROOT/lib $HOME_MKLDNN/.
+  [ "$MKLROOT" != "$HOME_MKLDNN" ] && rsync -a $MKLROOT/include $MKLROOT/lib $HOME_MKLDNN/.
+  # update ldconfig if possible
+  if [ -w /etc/ld.so.conf.d ]; then
+    echo "$HOME_MKLDNN/lib" > /etc/ld.so.conf.d/mxnmkldnn.conf && ldconfig
+  fi
+# return value to calling script (Makefile,cmake)
+  echo $HOME_MKLDNN $HOME_MKLDNN
+else
+  echo $MKLDNNROOT $MKLROOT
+fi
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 58bc8d38f685..e48c61e8de10 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -1287,6 +1287,10 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write',
                 arr[:] = arg_params[name]
             for name, arr in exe.aux_dict.items():
                 arr[:] = aux_params[name]
+            # We need to initialize the gradient arrays if grad_req is 'add'.
+            if (grad_req == "add"):
+                for arr in exe.grad_arrays:
+                    arr[:] = np.zeros(arr.shape, dtype=arr.dtype)

     dtypes = [np.dtype(exe.outputs[0].dtype) for exe in exe_list]
     max_idx = np.argmax(dtypes)
diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h
index dcd1504fb88e..2862c2cc4795 100644
--- a/src/common/exec_utils.h
+++ b/src/common/exec_utils.h
@@ -43,19 +43,61 @@ namespace common {
            indices are not recorded
 * \return true if any source NDArray need to cast storage
 */
-inline bool SetupDefaultBlobs(const std::vector<NDArray>& src,
-                              std::vector<TBlob> *blobs,
-                              std::vector<NDArray> *temp_src,
-                              std::vector<NDArray> *temp_dst,
-                              std::unordered_map<uint32_t, uint32_t> *idx_map = nullptr) {
+inline bool SetupDefaultBlobsIn(const std::vector<NDArray>& src,
+                                const std::vector<NDArray> *bufs,
+                                std::vector<TBlob> *blobs,
+                                std::vector<NDArray> *temp_src,
+                                std::vector<NDArray> *temp_dst,
+                                std::unordered_map<uint32_t, uint32_t> *idx_map) {
   bool require_cast = false;
   for (size_t i = 0; i < src.size(); i++) {
     auto& nd = src[i];
-    if (nd.storage_type() != kDefaultStorage) {
-      if (idx_map != nullptr) {
-        (*idx_map)[i] = temp_dst->size();
-      }
-      NDArray temp(nd.shape(), nd.ctx(), false, nd.dtype());
+    bool is_default = nd.storage_type() == kDefaultStorage;
+#if MXNET_USE_MKLDNN == 1
+    // We have to make sure it's default storage and default layout.
+    is_default = nd.IsDefault();
+#endif
+    if (!is_default) {
+      (*idx_map)[i] = temp_dst->size();
+      NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(),
+                                                             true, nd.dtype());
+#if MXNET_USE_MKLDNN == 1
+      CHECK(temp.IsDefault());
+#endif
+      temp_src->emplace_back(nd);
+      temp_dst->emplace_back(temp);
+      blobs->emplace_back(temp.data());
+      require_cast = true;
+    } else {
+      blobs->push_back(nd.data());
+    }
+  }
+  return require_cast;
+}
+
+inline bool SetupDefaultBlobsOut(const std::vector<NDArray>& src,
+                                 const std::vector<OpReqType> &req,
+                                 const std::vector<NDArray> *bufs,
+                                 std::vector<TBlob> *blobs,
+                                 std::vector<NDArray> *temp_src,
+                                 std::vector<NDArray> *temp_dst) {
+  bool require_cast = false;
+  for (size_t i = 0; i < src.size(); i++) {
+    auto& nd = src[i];
+    bool is_default = nd.storage_type() == kDefaultStorage;
+#if MXNET_USE_MKLDNN == 1
+    // If it's writeTo, we don't need to worry whether it contains valid data.
+    if (req[i] == kWriteTo && is_default)
+      const_cast<NDArray &>(nd).InvalidateData();
+    // We have to make sure it's default storage and default layout.
+    is_default = nd.IsDefault();
+#endif
+    if (!is_default) {
+      NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(),
+                                                             true, nd.dtype());
+#if MXNET_USE_MKLDNN == 1
+      CHECK(temp.IsDefault());
+#endif
       temp_src->emplace_back(nd);
       temp_dst->emplace_back(temp);
       blobs->emplace_back(temp.data());
@@ -76,6 +118,9 @@ inline bool SetupDefaultBlobs(const std::vector<NDArray>& src,
  */
 inline void SetupDefaultBlobsInOut(const std::vector<NDArray> &ndinputs,
                                    const std::vector<NDArray> &ndoutputs,
+                                   const std::vector<OpReqType> &req,
+                                   const std::vector<NDArray> *in_bufs,
+                                   const std::vector<NDArray> *out_bufs,
                                    std::vector<TBlob> *input_blobs,
                                    std::vector<TBlob> *output_blobs,
                                    std::vector<NDArray> *pre_temp_src,
@@ -85,9 +130,11 @@ inline void SetupDefaultBlobsInOut(const std::vector<NDArray> &ndinputs,
                                    std::unordered_map<uint32_t, uint32_t> *in_temp_idx_map,
                                    const std::vector<uint32_t> &mutate_idx) {
   // populate input blobs
-  SetupDefaultBlobs(ndinputs, input_blobs, pre_temp_src, pre_temp_dst, in_temp_idx_map);
+  SetupDefaultBlobsIn(ndinputs, in_bufs, input_blobs, pre_temp_src, pre_temp_dst,
+                      in_temp_idx_map);
   // populate output blobs
-  SetupDefaultBlobs(ndoutputs, output_blobs, post_temp_dst, post_temp_src);
+  SetupDefaultBlobsOut(ndoutputs, req, out_bufs, output_blobs, post_temp_dst,
+                       post_temp_src);
   // add mutable inputs to post temp list
   for (const auto idx : mutate_idx) {
     auto map_iter = in_temp_idx_map->find(idx);
diff --git a/src/common/utils.cc b/src/common/utils.cc
index 784fcf8651ae..939b3e8d0a1b 100644
--- a/src/common/utils.cc
+++ b/src/common/utils.cc
@@ -41,5 +41,21 @@ void CastStorageDispatch(const OpContext& ctx,
   mxnet::op::CastStorageComputeImpl(ctx, input, output);
 }

+std::string stype_string(const int x) {
+  switch (x) {
+    case kDefaultStorage:
+      return "default";
+    case kCSRStorage:
+      return "csr";
+    case kRowSparseStorage:
+      return "row_sparse";
+#if MXNET_USE_MKLDNN == 1
+    case kMKLDNNStorage:
+      return "mkldnn";
+#endif
+  }
+  return "unknown";
+}
+
 }  // namespace common
 }  // namespace mxnet
diff --git a/src/common/utils.h b/src/common/utils.h
index 6f7e452565c2..8ccf48083634 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -217,6 +217,28 @@ void CheckFormatImpl(const RunContext &rctx, const NDArray &input,
 template<typename xpu>
 void CastStorageDispatch(const OpContext& ctx, const NDArray& input, const NDArray& output);

+/*! \brief returns true if one of storage types in `inputs` is the same as target `stype`.
+ */
+inline bool ContainsStorage(const std::vector<NDArray>& inputs,
+                            NDArrayStorageType type) {
+  for (const auto &i : inputs) {
+    if (i.storage_type() == type)
+      return true;
+  }
+  return false;
+}
+
+/*! \brief returns true if one of storage types in `vstorage` is the same as target `stype`.
+ */
+inline bool ContainsStorage(const std::vector<int> &vstorages,
+                            NDArrayStorageType type) {
+  for (const auto& i : vstorages) {
+    if (i == type)
+      return true;
+  }
+  return false;
+}
+
 /*! \brief returns true if all storage types in `vstorage` are the same as target `stype`.
  *         false is returned for empty inputs.
  */
@@ -326,17 +348,7 @@ inline std::string dispatch_mode_string(const DispatchMode x) {

 /*! \brief get string representation of storage_type */
-inline std::string stype_string(const int x) {
-  switch (x) {
-    case kDefaultStorage:
-      return "default";
-    case kCSRStorage:
-      return "csr";
-    case kRowSparseStorage:
-      return "row_sparse";
-  }
-  return "unknown";
-}
+std::string stype_string(const int x);
 /*! \brief get string representation of device type */
 inline std::string dev_type_string(const int dev_type) {
diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index 1bcc40a894dd..e4d49554620f 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -30,11 +30,8 @@
 #include "../common/utils.h"
 #include "../common/exec_utils.h"
 #include "./exec_pass.h"
-#if MXNET_USE_MKL2017 == 1
-#include
-#include "../operator/mkl/mkl_memory-inl.h"
-#include "../operator/mkl/mkl_util-inl.h"
-#endif
+#include "../operator/nn/mkldnn/mkldnn_base-inl.h"
+
 namespace mxnet {
 namespace op {
@@ -58,23 +55,34 @@ class StorageFallbackOpExecutor : public OpExecutor {
  protected:
  // initialize the data blobs
  void InitBlobs() {
-    using namespace common;
    if (!init_) {
-      in_data_.clear(); out_data_.clear();
-      pre_temp_src_.clear(); pre_temp_dst_.clear();
-      post_temp_src_.clear(); post_temp_dst_.clear();
-      in_temp_idx_map_.clear();
-      SetupDefaultBlobsInOut(in_array, out_array, &in_data_, &out_data_,
-                             &pre_temp_src_, &pre_temp_dst_,
-                             &post_temp_src_, &post_temp_dst_,
-                             &in_temp_idx_map_, mutate_idx_);
+      pre_temp_buf_.clear();
+      post_temp_buf_.clear();
+      for (size_t i = 0; i < in_array.size(); i++) {
+        auto &nd = in_array[i];
+        pre_temp_buf_.emplace_back(nd.shape(), nd.ctx(), true, nd.dtype());
+      }
+      for (size_t i = 0; i < out_array.size(); i++) {
+        auto &nd = out_array[i];
+        post_temp_buf_.emplace_back(nd.shape(), nd.ctx(), true, nd.dtype());
+      }
      init_ = true;
    }
  }

  // storage fallback before fcompute is launched
  void PreFCompute(bool is_gpu) {
+    using namespace common;
    InitBlobs();
+    in_data_.clear(); out_data_.clear();
+    pre_temp_src_.clear(); pre_temp_dst_.clear();
+    post_temp_src_.clear(); post_temp_dst_.clear();
+    in_temp_idx_map_.clear();
+    SetupDefaultBlobsInOut(in_array, out_array, req, &pre_temp_buf_, &post_temp_buf_,
+                           &in_data_, &out_data_,
+                           &pre_temp_src_, &pre_temp_dst_,
+                           &post_temp_src_, &post_temp_dst_,
+                           &in_temp_idx_map_, mutate_idx_);
    common::CastNonDefaultStorage(pre_temp_src_, pre_temp_dst_, op_ctx, is_gpu);
  }
@@ -85,6 +93,8 @@ class StorageFallbackOpExecutor : public OpExecutor {
  // default storage tensor blobs for fcompute
  std::vector<TBlob> in_data_, out_data_;
+  // These are NDArray buffers for cast storage.
+  std::vector<NDArray> pre_temp_buf_, post_temp_buf_;
  // source NDArray for cast storage
  std::vector<NDArray> pre_temp_src_, post_temp_src_;
  // destination NDArray for cast storage
@@ -106,10 +116,6 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
    PreFCompute(is_gpu);
    fcompute_(state_, op_ctx, in_data_, req, out_data_);
    PostFCompute(is_gpu);
-#if MKL_EXPERIMENTAL == 1
-    mkl_tblobs_prv_to_cpu(in_data_);
-    mkl_tblobs_prv_to_cpu(out_data_);
-#endif
  }

  ExecType exec_type() const override {
@@ -175,10 +181,6 @@ class FComputeExecutor : public StorageFallbackOpExecutor {
    PreFCompute(is_gpu);
    fcompute_(attrs_, op_ctx, in_data_, req, out_data_);
    PostFCompute(is_gpu);
-#if MKL_EXPERIMENTAL == 1
-    mkl_tblobs_prv_to_cpu(in_data_);
-    mkl_tblobs_prv_to_cpu(out_data_);
-#endif
  }

  ExecType exec_type() const override {
@@ -202,6 +204,9 @@ class FComputeExExecutor : public OpExecutor {
 public:
  void Run(RunContext rctx, bool is_gpu) override {
    op_ctx.run_ctx = rctx;
+#if MXNET_USE_MKLDNN == 1
+    InvalidateOutputs(out_array, req);
+#endif
    fcompute_(attrs_, op_ctx, in_array, req, out_array);
  }
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 42508b1bad46..80ee03808558 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -54,6 +54,14 @@ GraphExecutor::~GraphExecutor() {
  }
 }

+inline bool SharableStorage(NDArrayStorageType stype) {
+  bool ret = stype == kDefaultStorage;
+#if MXNET_USE_MKLDNN == 1
+  ret = ret || stype == kMKLDNNStorage;
+#endif
+  return ret;
+}
+
 inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape,
                         const Context &ctx, const int dtype) {
  // NDArray with default storage
@@ -693,7 +701,7 @@ static NDArray ReshapeOrCreate(const std::string& name,
                               const Context& ctx,
                               std::unordered_map<std::string, NDArray>* shared_buffer,
                               bool enable_row_sparse_sharing) {
-  bool stype_shareable = dest_arg_stype == kDefaultStorage;
+  bool stype_shareable = SharableStorage(dest_arg_stype);
  if (enable_row_sparse_sharing) {
    stype_shareable = stype_shareable || dest_arg_stype == kRowSparseStorage;
  }
@@ -793,7 +801,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
        const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name);
        auto arg_nd_stype = in_arg_nd.storage_type();
        // for model parameter, both default storage and row_sparse storage can be shared
-        bool shareable_arg_stype = inferred_stype == kDefaultStorage ||
+        bool shareable_arg_stype = SharableStorage(inferred_stype) ||
                                   inferred_stype == kRowSparseStorage;
        // try to reuse memory from shared_exec
        CHECK(shareable_arg_stype) << "Inferred storage type "
@@ -827,8 +835,8 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
          auto grad_oid = grad_store_.size() + num_forward_outputs_;
          auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]);
          auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid];
-          if (nullptr != shared_exec && grad_stype == kDefaultStorage &&
-              shared_exec->arg_grad_map().at(arg_name).storage_type() == kDefaultStorage) {
+          if (nullptr != shared_exec && SharableStorage(grad_stype) &&
+              shared_exec->arg_grad_map().at(arg_name).storage_type() == grad_stype) {
            // try to reuse memory from shared_exec
            arg_grad_vec->emplace_back(shared_exec->arg_grad_map().at(arg_name));
          } else {
@@ -1209,7 +1217,8 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
      const NDArray& src = data_pool_.at(storage_id);
      data_entry_[i] = src.AsArray(vshape[i], vdtype[i]);
    } else {
-      data_entry_[i] = NDArray(storage_type, vshape[i], data_context[i]);
+      data_entry_[i] = NDArray(storage_type, vshape[i], data_context[i],
+                               true, vdtype[i]);
    }
    if (log_verbose_) {
      LOG(INFO) << "\tinit data entry\t" << i << "\tas " << common::stype_string(storage_type);
diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc
index 67e61aa357c2..3bf3c9b8545a 100644
--- a/src/executor/infer_graph_attr_pass.cc
+++ b/src/executor/infer_graph_attr_pass.cc
@@ -416,11 +416,6 @@ nnvm::Graph InferStorageType(nnvm::Graph&& graph,
    DispatchModeVector dispatch_modes(graph.indexed_graph().num_nodes(), DispatchMode::kUndefined);
    graph.attrs["dispatch_mode"] = std::make_shared<dmlc::any>(std::move(dispatch_modes));
  }
-  // initialize unknown values for dispatch modes
-  if (graph.attrs.count("dispatch_mode") == 0) {
-    DispatchModeVector dispatch_modes(graph.indexed_graph().num_nodes(), DispatchMode::kUndefined);
-    graph.attrs["dispatch_mode"] = std::make_shared<dmlc::any>(std::move(dispatch_modes));
-  }
  // initialize the dev_mask vector from the context vector
  if (graph.attrs.count("dev_mask") == 0) {
    CHECK_GT(graph.attrs.count("context"), 0);
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index eaa95a5f2418..93a8bc6c54b2 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -214,6 +214,12 @@ nnvm::Graph Imperative::CachedOp::GetForwardGraph(
    StorageVector storage(idx.num_node_entries(), exec::kBadStorageID);
    for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
+    const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+    CHECK_EQ(stypes.size(), storage.size());
+    for (size_t i = 0; i < stypes.size(); i++) {
+      if (stypes[i] != kDefaultStorage)
+        storage[i] = exec::kDynamicStorageID;
+    }
    auto mem_plan = PlanMemory(
        &g, std::move(storage), g.GetAttr<std::vector<uint32_t> >(
@@ -320,6 +326,10 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
    for (size_t i = 0; i < num_forward_entries; ++i) storage[i] = exec::kExternalStorageID;
    for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
    for (const auto i : idx.outputs()) storage[idx.entry_id(i)] = exec::kExternalStorageID;
+    for (size_t i = 0; i < stypes.size(); i++) {
+      if (stypes[i] != kDefaultStorage)
+        storage[i] = exec::kDynamicStorageID;
+    }
    auto mem_plan = PlanMemory(
        &g, std::move(storage), g.GetAttr<std::vector<uint32_t> >("backward_ref_count"),
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index add568d6c04c..3bbd52fa0102 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -356,9 +356,9 @@ inline void PushFCompute(const FCompute& fn,
    // mapping from index in input_blobs to index in pre_temp_dst
    std::unordered_map<uint32_t, uint32_t> in_temp_idx_map;
    // setup blobs
-    SetupDefaultBlobsInOut(inputs, outputs, &input_blobs, &output_blobs,
-                           &pre_temp_src, &pre_temp_dst, &post_temp_src,
-                           &post_temp_dst, &in_temp_idx_map, mutate_idx);
+    SetupDefaultBlobsInOut(inputs, outputs, req, nullptr, nullptr,
+                           &input_blobs, &output_blobs, &pre_temp_src, &pre_temp_dst,
+                           &post_temp_src, &post_temp_dst, &in_temp_idx_map, mutate_idx);
    // setup context
    OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
    bool is_gpu = ctx.dev_mask() == gpu::kDevMask;
@@ -454,9 +454,9 @@ inline void PushOperator(const OpStatePtr& state,
    // mapping from index in input_blobs to index in pre_temp_dst
    std::unordered_map<uint32_t, uint32_t> in_temp_idx_map;
    // populate input blobs and output blobs
-    SetupDefaultBlobsInOut(inputs, outputs, &input_blobs, &output_blobs,
-                           &pre_temp_src, &pre_temp_dst, &post_temp_src, &post_temp_dst,
-                           &in_temp_idx_map, mutate_idx);
+    SetupDefaultBlobsInOut(inputs, outputs, req, nullptr, nullptr,
+                           &input_blobs, &output_blobs, &pre_temp_src, &pre_temp_dst,
+                           &post_temp_src, &post_temp_dst, &in_temp_idx_map, mutate_idx);
    // setup contexts
    bool is_gpu = rctx.get_ctx().dev_mask() == gpu::kDevMask;
    // pre-fcompute fallback
@@ -601,6 +601,7 @@ inline bool CheckAndInferStorageType(nnvm::Graph* p_g, exec::DevMaskVector&& dev
      }
      if (match) return true;
    }
+    g.attrs.erase("dispatch_mode");
    g.attrs.erase("storage_type");
    g.attrs.erase("storage_type_inputs");
    if (node_range.second > node_range.first) {
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index b00d0de935f7..d0a968154afb 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -32,11 +32,6 @@
 #include "mxnet/engine.h"
 #include "ps/ps.h"
 #include "./kvstore_dist_server.h"
-#if MKL_EXPERIMENTAL == 1
-#include
-#include "../operator/mkl/mkl_memory-inl.h"
-#include "../operator/mkl/mkl_util-inl.h"
-#endif
 namespace mxnet {
 namespace kvstore {
@@ -228,9 +223,6 @@ class KVStoreDist : public KVStoreLocal {
        PSKV& pskv = (gradient_compression_->get_type() == CompressionType::kNone) ?
                      EncodeDefaultKey(key, size, false) :
                      EncodeCompressedKey(key, size, false);
-#if MKL_EXPERIMENTAL == 1
-        mkl_set_tblob_eager_mode(recv_buf.data());
-#endif
        real_t* data = recv_buf.data().dptr<real_t>();
        // false means not to delete data when SArray is deleted
        auto vals = new ps::SArray<real_t>(data, size, false);
@@ -380,9 +372,6 @@ class KVStoreDist : public KVStoreLocal {
          [this, key, pskv, small_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
            size_t size = small_buf.shape().Size();
            real_t* data = small_buf.data().dptr<real_t>();
-#if MKL_EXPERIMENTAL == 1
-            mkl_set_tblob_eager_mode(small_buf.data());
-#endif
            // do push. false means no delete
            ps::SArray<real_t> vals(data, size, false);
            CHECK_NOTNULL(ps_worker_)->ZPush(
@@ -407,9 +396,6 @@ class KVStoreDist : public KVStoreLocal {
          // convert to ps keys
          size_t size = send_buf.shape().Size();
          real_t* data = send_buf.data().dptr<real_t>();
-#if MKL_EXPERIMENTAL == 1
-          mkl_set_tblob_eager_mode(send_buf.data());
-#endif
          // do push. false means no delete
          ps::SArray<real_t> vals(data, size, false);
          CHECK_NOTNULL(ps_worker_)->ZPush(
@@ -431,9 +417,6 @@ class KVStoreDist : public KVStoreLocal {
    using namespace rowsparse;
    auto push_to_servers = [this, key, send_buf]
                           (RunContext rctx, Engine::CallbackOnComplete cb) {
-#if MKL_EXPERIMENTAL == 1
-      mkl_set_tblob_eager_mode(send_buf.data());
-#endif
      real_t* data = send_buf.data().dptr<real_t>();
      const int64_t num_rows = send_buf.aux_shape(kIdx)[0];
      const auto offsets = send_buf.aux_data(kIdx).dptr<int64_t>();
@@ -472,9 +455,6 @@ class KVStoreDist : public KVStoreLocal {
      // allocate memory for the buffer
      size_t num_rows = indices.shape().Size();
      recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)});
-#if MKL_EXPERIMENTAL == 1
-      mkl_set_tblob_eager_mode(recv_buf.data());
-#endif
      real_t* data = recv_buf.data().dptr<real_t>();
      const auto offsets = indices.data().dptr<int64_t>();
      const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim());
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 1bb84fdc1114..5646d9eef866 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -256,7 +256,13 @@ class KVStoreLocal : public KVStore {
    auto validator = [this](const int key, const NDArray& nd) -> bool {
      auto stype = nd.storage_type();
      // valid NDArray
-      if (stype == kDefaultStorage || stype == kRowSparseStorage) return true;
+      auto valid_stype = stype == kDefaultStorage || stype == kRowSparseStorage;
+#if MXNET_USE_MKLDNN == 1
+      // When it's kMKLDNNStorage, it'll be converted to a data layout
+      // compatible with the default storage.
+      valid_stype = valid_stype || stype == kMKLDNNStorage;
+#endif
+      if (valid_stype) return true;
      // invalid NDArray, abort
      LOG(FATAL) << "Unexpected storage type detected during kvstore push: " << stype;
      return false;
@@ -272,8 +278,15 @@ class KVStoreLocal : public KVStore {
                           std::vector<std::vector<NDArray*>> *grouped_vals) {
    // check if the storage type of a value is valid
    auto validator = [this](const int key, const NDArray* nd) -> bool {
+      auto stype = nd->storage_type();
      // valid
-      if (nd->storage_type() == kDefaultStorage) return true;
+      auto valid_stype = stype == kDefaultStorage;
+#if MXNET_USE_MKLDNN == 1
+      // When it's kMKLDNNStorage, it'll be converted to a data layout
+      // compatible with the default storage.
+      valid_stype = valid_stype || stype == kMKLDNNStorage;
+#endif
+      if (valid_stype) return true;
      // invalid, print warning messages once
      if (this->warnings_printed_.find(key) == this->warnings_printed_.end()) {
        LOG(INFO) << "Warning: non-default weights detected during kvstore pull. "
" diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 8a3bb8d59b0f..0def300271d1 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -31,10 +31,14 @@ #include #include #include +#if MXNET_USE_MKLDNN == 1 +#include +#endif #include "./ndarray_function.h" #include "../common/utils.h" #include "../operator/tensor/matrix_op-inl.h" #include "../operator/tensor/init_op.h" +#include "../operator/nn/mkldnn/mkldnn_base-inl.h" #if MXNET_USE_OPENCV #include @@ -46,6 +50,104 @@ DMLC_REGISTRY_ENABLE(::mxnet::NDArrayFunctionReg); namespace mxnet { +NDArray::NDArray(const NDArrayStorageType stype, const TShape &shape, Context ctx, + bool delay_alloc, int dtype, std::vector aux_types, + std::vector aux_shapes, TShape storage_shape) : shape_(shape), + dtype_(dtype), storage_type_(stype), entry_({nullptr, 0, 0}) { + // Assign default aux types if not given + if (aux_types.size() == 0 + && stype != kDefaultStorage) { + if (stype == kRowSparseStorage) { + aux_types = {mshadow::kInt64}; + } else if (stype == kCSRStorage) { + aux_types = {mshadow::kInt64, mshadow::kInt64}; + } else { + LOG(FATAL) << "Unknown storage type " << stype; + } + } + // Assign default shapes if not given + // unknown shapes are intialized as {0} such that Size() would return 0 + if (aux_shapes.size() == 0 + && stype != kDefaultStorage) { + if (stype == kRowSparseStorage) { + aux_shapes = {TShape(mshadow::Shape1(0))}; + } else if (stype == kCSRStorage) { + // aux shapes for indptr and indices + aux_shapes = {TShape(mshadow::Shape1(0)), TShape(mshadow::Shape1(0))}; + } else { + LOG(FATAL) << "Unknown storage type " << stype; + } + } + if (storage_shape.Size() == 0 + && stype != kDefaultStorage) { + if (stype == kRowSparseStorage) { + storage_shape = shape; + storage_shape[0] = aux_shapes[rowsparse::kIdx][0]; + } else if (stype == kCSRStorage) { + storage_shape = aux_shapes[csr::kIdx]; + } else { + LOG(FATAL) << "Unknown storage type " << stype; + } + } + if (stype == kDefaultStorage) + ptr_ = std::make_shared(shape, ctx, delay_alloc, dtype); + else + ptr_ = std::make_shared(stype, storage_shape, ctx, delay_alloc, + dtype, aux_types, aux_shapes); +} + +struct ChunkMem { + Storage::Handle h; + std::vector aux_h; +#if MXNET_USE_MKLDNN == 1 + std::shared_ptr mem; +#endif +}; + +NDArray::Chunk::~Chunk() { + bool skip_free = static_data || delay_alloc; + ChunkMem mem; + mem.h = this->shandle; + mem.aux_h = this->aux_handles; +#if MXNET_USE_MKLDNN == 1 + // We want to delete mkldnn memory after deleting the variable. 
+  mem.mem = this->Mkl_mem_;
+#endif
+  Engine::Get()->DeleteVariable([mem, skip_free](RunContext s) {
+    if (skip_free == false) {
+#if MXNET_USE_MKLDNN == 1
+      if (mem.mem) {
+        CHECK_LE(mem.mem->get_primitive_desc().get_size(), mem.h.size);
+        CHECK_EQ(mem.mem->get_data_handle(), mem.h.dptr);
+      }
+#endif
+      if (mem.h.size > 0) Storage::Get()->Free(mem.h);
+      for (size_t i = 0; i < mem.aux_h.size(); i++) {
+        if (mem.aux_h[i].size > 0) Storage::Get()->Free(mem.aux_h[i]);
+      }
+    }
+  }, shandle.ctx, var);
+}
+
+void NDArray::Chunk::CheckAndAllocData(const TShape &shape, int dtype) {
+  CHECK_NE(aux_shapes.size(), 0)
+      << "data is expected to be allocated after aux_data";
+  auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype);
+  if (shandle.size < dbytes) {
+    // free storage if necessary and alloc again
+    if (shandle.size > 0) Storage::Get()->Free(shandle);
+    // init storage
+    shandle = Storage::Get()->Alloc(dbytes, ctx);
+#if MXNET_USE_MKLDNN == 1
+    Mkl_mem_ = nullptr;
+#endif
+  }
+  // init shape
+  storage_shape = shape;
+  // delay_alloc is only set when data storage handle is present
+  delay_alloc = false;
+}
+
 NDArray NDArray::grad() const {
  if (Imperative::AGInfo::IsNone(*this)) return NDArray();
  Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);
@@ -64,14 +166,55 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {
  return ret;
 }

+#if MXNET_USE_MKLDNN == 1
+
+struct EmptyMKLDNNDeleter {
+  void operator()(mkldnn::memory *mem) {
+  }
+};
+
+NDArray NDArray::ReshapeMKLDNN(const TShape &shape) const {
+  CHECK(!is_none()) << "NDArray is not initialized";
+  CHECK_GE(shape_.Size(), shape.Size())
+      << "NDArray.Reshape: target shape size is larger than current shape";
+  CHECK_EQ(storage_type(), kDefaultStorage);
+  if (!IsMKLDNN()) {
+    NDArray ret = this->Detach();
+    ret.shape_ = shape;
+    return ret;
+  } else {
+    NDArray ret(shape, ctx(), true, dtype());
+    // We shouldn't submit the reorder primitive here because submit will
+    // be called in operators.
+    auto format = GetDefaultFormat(ptr_->Mkl_mem_->get_primitive_desc().desc());
+    CHECK_NE(format, ptr_->Mkl_mem_->get_primitive_desc().desc().data.format);
+    auto def_pd = GetPrimitiveDesc(ptr_->Mkl_mem_->get_primitive_desc(), format);
+    auto def_mem = TmpMemMgr::Get()->Alloc(def_pd);
+    MKLDNNStream *stream = MKLDNNStream::Get();
+    stream->RegisterMem(ptr_->Mkl_mem_);
+    stream->RegisterPrim(mkldnn::reorder(*ptr_->Mkl_mem_, *def_mem));
+    // def_mem points to a memory region in the temp space. It's only valid
+    // inside an operator. As such, the returned NDArray can only be valid
+    // inside an operator and the shared pointer doesn't need to do anything
+    // when it's destroyed.
+    ret.ptr_->Mkl_mem_ = std::shared_ptr<mkldnn::memory>(def_mem,
+                                                         EmptyMKLDNNDeleter());
+    ret.ptr_->shandle.dptr = def_mem->get_data_handle();
+    ret.ptr_->shandle.size = def_mem->get_primitive_desc().get_size();
+    ret.ptr_->delay_alloc = false;
+    ret.ptr_->static_data = true;
+    ret.byte_offset_ = byte_offset_;
+    return ret;
+  }
+}
+
+#endif
+
 NDArray NDArray::Reshape(const TShape &shape) const {
  CHECK(!is_none()) << "NDArray is not initialized";
-  auto stype = storage_type();
-  // reshape is not supported for non-default ndarray with dismatching shapes
-  CHECK((shape_ == shape) || stype == kDefaultStorage)
-    << "Reshape for storage type " << stype << " is not implemented yet";
  CHECK_GE(shape_.Size(), shape.Size())
    << "NDArray.Reshape: target shape size is larger current shape";
+  CHECK_EQ(storage_type(), kDefaultStorage);
  NDArray ret = this->Detach();
  ret.shape_ = shape;
  return ret;
@@ -95,7 +238,6 @@ NDArray NDArray::ReshapeWithRecord(const TShape &shape) {
  return ret;
 }

-
 NDArray NDArray::Slice(index_t begin, index_t end) const {
  CHECK(!is_none()) << "NDArray is empty";
  CHECK_LE(begin, end)
@@ -127,8 +269,8 @@ NDArray NDArray::SliceWithRecord(index_t begin, index_t end) {
 }

 NDArray NDArray::At(index_t idx) const {
-  CHECK(storage_type() == kDefaultStorage) << "Storage type "
-                                           << storage_type() << " doesn't support At()";
+  CHECK(storage_type() == kDefaultStorage)
+      << "Storage type " << storage_type() << " doesn't support At()";
  NDArray ret = this->Slice(idx, idx+1);
  if (shape_.ndim() > 1) {
    return ret.Reshape(TShape(shape_.data()+1, shape_.data()+shape_.ndim()));
@@ -181,6 +323,409 @@ void NDArray::set_fresh_out_grad(bool state) const {
  info.fresh_out_grad = state;
 }

+#if MXNET_USE_MKLDNN == 1
+static inline bool same_shape(const TShape &shape, mkldnn_dims_t dims, int ndims) {
+  if (shape.ndim() != (size_t)ndims)
+    return false;
+  for (int i = 0; i < ndims; i++)
+    if (shape[i] != dims[i])
+      return false;
+  return true;
+}
+
+static inline bool same_shape(const TShape &shape, int dtype, mkldnn::memory::desc desc) {
+  return same_shape(shape, desc.data.dims, desc.data.ndims)
+         && get_mkldnn_type(dtype) == desc.data.data_type;
+}
+
+bool NDArray::Chunk::IsMKLDNN() const {
+  if (storage_type != kDefaultStorage)
+    return false;
+  if (Mkl_mem_ == nullptr)
+    return false;
+  auto desc = Mkl_mem_->get_primitive_desc().desc();
+  return desc.data.format != GetDefaultFormat(desc);
+}
+
+bool NDArray::Chunk::IsDefault() const {
+  if (storage_type != kDefaultStorage)
+    return false;
+  // If we don't have mkldnn memory yet, we just assume it's the default
+  // format.
+  if (Mkl_mem_ == nullptr)
+    return true;
+  auto desc = Mkl_mem_->get_primitive_desc().desc();
+  return desc.data.format == GetDefaultFormat(desc);
+}
+
+void NDArray::Chunk::Reorder2Default() {
+  if (Mkl_mem_ == nullptr)
+    return;
+
+  auto format = GetDefaultFormat(Mkl_mem_->get_primitive_desc().desc());
+  CHECK(format != Mkl_mem_->get_primitive_desc().desc().data.format);
+
+  auto def_pd = GetPrimitiveDesc(Mkl_mem_->get_primitive_desc(), format);
+  mkldnn_mem_ptr def_mem(new mkldnn::memory(def_pd));
+  // This may be called in MKLDNN operators. We can't use MKLDNNStream here.
+  std::vector<mkldnn::primitive> net;
+  net.push_back(mkldnn::reorder(*Mkl_mem_, *def_mem));
+  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
+
+  CHECK(shandle.size >= def_pd.get_size());
+  CheckAndAlloc(def_pd.get_size());
+  // TODO(zhengda) We need to avoid memory copy here.
+  memcpy(shandle.dptr, def_mem->get_data_handle(), def_pd.get_size());
+  Mkl_mem_.reset(new mkldnn::memory(def_pd, shandle.dptr));
+}
+
+void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
+  // The shape of the array and the one of the MKL memory may mismatch.
+  // For example, if the array stores parameters, the MKL memory may store data
+  // in 5 dimensions while the NDArray stores data in 4 dimensions.
+  if (Mkl_mem_ && Mkl_mem_->get_data_handle() == shandle.dptr
+      && same_shape(shape, dtype, Mkl_mem_->get_primitive_desc().desc())) {
+    return;
+  }
+
+  mkldnn::memory::dims dims;
+  // These are shapes supported by MKLDNN.
+  if (shape.ndim() == 1 || shape.ndim() == 2 || shape.ndim() == 4
+      || shape.ndim() == 5) {
+    dims.resize(shape.ndim());
+    for (size_t i = 0; i < dims.size(); i++)
+      dims[i] = shape[i];
+  } else if (shape.ndim() == 3) {
+    // If there are 3 dimensions, we'll force it to 4 dimensions.
+    dims.resize(shape.ndim() + 1);
+    dims[0] = 1;
+    for (size_t i = 0; i < shape.ndim(); i++)
+      dims[i + 1] = shape[i];
+  } else {
+    LOG(FATAL) << "MKLDNN doesn't support " << shape.ndim() << " dimensions";
+  }
+  mkldnn::memory::format layout = mkldnn::memory::format::format_undef;
+  switch (dims.size()) {
+    case 1: layout = mkldnn::memory::format::x; break;
+    case 2: layout = mkldnn::memory::format::nc; break;
+    case 4: layout = mkldnn::memory::format::nchw; break;
+    // This isn't the right layout when the data has 5 dimensions in MXNet.
+    // MXNet interprets 5 dimensions as ncdhw, but MKLDNN doesn't have
+    // a corresponding format.
+    case 5: layout = mkldnn::memory::format::goihw; break;
+  }
+  mkldnn::memory::desc data_md{dims, get_mkldnn_type(dtype), layout};
+  auto cpu_engine = CpuEngine::Get()->get_engine();
+  if (shandle.dptr == nullptr) {
+    CHECK(delay_alloc);
+    CheckAndAlloc();
+  }
+  mkldnn::memory::primitive_desc pd(data_md, cpu_engine);
+  CHECK(shandle.size >= pd.get_size());
+  Mkl_mem_.reset(new mkldnn::memory(pd, shandle.dptr));
+}
+
+/*
+ * Here we want to get MKLDNN memory whose primitive desc is exactly the same as
+ * the given one. operator== can't guarantee that. == can return true even if
+ * the formats are different. I need to double check its format.
+ */
+static inline mkldnn::memory *GetMKLDNNExact(
+    const mkldnn::memory *mem, mkldnn::memory::primitive_desc desc) {
+  auto src_desc = mem->get_primitive_desc();
+  if (desc == src_desc && desc.desc().data.format == src_desc.desc().data.format) {
+    return const_cast<mkldnn::memory *>(mem);
+  } else {
+    std::shared_ptr<mkldnn::memory> ret(new mkldnn::memory(
+        desc, mem->get_data_handle()));
+    MKLDNNStream::Get()->RegisterMem(ret);
+    return ret.get();
+  }
+}
+
+const mkldnn::memory *NDArray::GetMKLDNNData(
+    const mkldnn::memory::primitive_desc &desc) const {
+  if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
+    LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
+    return nullptr;
+  }
+  auto mem = GetMKLDNNData();
+  mkldnn::memory::primitive_desc _desc = desc;
+  auto desc1 = mem->get_primitive_desc().desc();
+  auto desc2 = _desc.desc();
+  // The MKL memory has the same format and shape as required,
+  // or both use the default format, we can return the MKL memory.
+ if (mem->get_primitive_desc() == desc + || (desc1.data.format == GetDefaultFormat(desc1) + && desc2.data.format == GetDefaultFormat(desc2))) { + return GetMKLDNNExact(ptr_->Mkl_mem_.get(), desc); + } else { + return nullptr; + } +} + +const mkldnn::memory *NDArray::GetMKLDNNDataReorder( + const mkldnn::memory::primitive_desc &desc) const { + if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) { + LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc"; + return nullptr; + } + CHECK(storage_type() == kDefaultStorage); + + auto mem = GetMKLDNNData(); + // If the memory descriptor matches, it's easy. + MKLDNNStream *stream = MKLDNNStream::Get(); + if (mem->get_primitive_desc() == desc) { + return GetMKLDNNExact(mem, desc); + } + + mkldnn::memory::primitive_desc _desc = desc; + // Now we need to determine if we should reorder the memory. + // If both use the default formats, we think we don't need to reorder. + auto desc1 = mem->get_primitive_desc().desc(); + auto desc2 = _desc.desc(); + if (desc1.data.format == GetDefaultFormat(desc1) && + desc2.data.format == GetDefaultFormat(desc2)) { + mkldnn_mem_ptr ret(new mkldnn::memory(desc, mem->get_data_handle())); + stream->RegisterMem(ret); + return ret.get(); + } else { + auto ret = TmpMemMgr::Get()->Alloc(desc); + stream->RegisterPrim(mkldnn::reorder(*mem, *ret)); + return ret; + } +} + +const mkldnn::memory *NDArray::GetMKLDNNData() const { + CHECK(storage_type() == kDefaultStorage); + // If this array uses MKLDNN layout and it's a view, we have to change its + // layout to the default layout. + if (IsMKLDNN() && IsView()) + ptr_->Reorder2Default(); + ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_, dtype_); + // If shandle has data, the data in shandle and Mkl_mem_ should match. + if (ptr_->shandle.dptr) + CHECK(ptr_->shandle.dptr == ptr_->Mkl_mem_->get_data_handle()); + MKLDNNStream::Get()->RegisterMem(ptr_->Mkl_mem_); + auto pd = ptr_->Mkl_mem_->get_primitive_desc(); + if (IsView()) { + // Sliced array must use the default layout. + CHECK_EQ(GetDefaultFormat(pd.desc()), pd.desc().data.format); + } + if (IsView()) { + void *off_addr = static_cast(ptr_->Mkl_mem_->get_data_handle()) + + byte_offset_; + + // Create the primitive desc for the new mkldnn memory. + mkldnn::memory::dims dims(pd.desc().data.ndims); + // The first dimension has been sliced. + dims[0] = shape()[0]; + for (size_t i = 1; i < dims.size(); i++) + dims[i] = pd.desc().data.dims[i]; + mkldnn::memory::format cpp_format = static_cast( + pd.desc().data.format); + mkldnn::memory::data_type cpp_type = static_cast( + pd.desc().data.data_type); + mkldnn::memory::desc data_md(dims, cpp_type, cpp_format); + mkldnn::memory::primitive_desc new_pd(data_md, pd.get_engine()); + + std::shared_ptr ret(new mkldnn::memory(new_pd, off_addr)); + MKLDNNStream::Get()->RegisterMem(ret); + return ret.get(); + } else { + return ptr_->Mkl_mem_.get(); + } +} + +void NDArray::Reorder(const mkldnn::memory::primitive_desc &pd) { + CHECK_EQ(storage_type(), kDefaultStorage); + // If the memory already uses the specified layout, don't do anything. + if (ptr_->Mkl_mem_ != nullptr && ptr_->Mkl_mem_->get_primitive_desc() == pd) + return; + auto _pd = pd; + auto _desc = _pd.desc(); + auto def_format = GetDefaultFormat(_desc); + // If the memory is default, don't do anything. + if (def_format == _desc.data.format && ptr_->IsDefault()) + return; + // If the specified layout is default, we should use Reorder2Default. 
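GetMKLDNNDataReorder above either hands back the existing memory (when the requested primitive_desc already matches, or both sides use the default format) or queues a reorder into a temporary buffer. A standalone sketch of that decision with the raw MKLDNN 0.x API; the reorder runs eagerly here, whereas MXNet defers it on its MKLDNNStream and takes the temporary from TmpMemMgr:

```c++
#include <mkldnn.hpp>
#include <vector>

int main() {
  using namespace mkldnn;
  engine eng(engine::cpu, 0);
  memory::dims dims = {1, 8, 4, 4};
  memory src(memory::primitive_desc(
      {dims, memory::data_type::f32, memory::format::nchw}, eng));
  memory::primitive_desc want(
      {dims, memory::data_type::f32, memory::format::nChw8c}, eng);

  if (src.get_primitive_desc() == want) {
    // Layouts already match: nothing to do, use src directly.
  } else {
    memory tmp(want);  // MXNet allocates this via TmpMemMgr::Get()->Alloc(want)
    std::vector<primitive> net;
    net.push_back(reorder(src, tmp));
    stream(stream::kind::eager).submit(net).wait();
  }
  return 0;
}
```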
+ if (def_format == _desc.data.format) { + ptr_->Reorder2Default(); + return; + } + + std::shared_ptr new_mem(new mkldnn::memory(pd)); + ptr_->SetMKLMem(shape_, dtype_); + auto old_mem = ptr_->Mkl_mem_; + // It's possible that the specified layout has a different number of dimensions. + if (old_mem->get_primitive_desc().desc().data.ndims != _desc.data.ndims) { + // For now, we only support reorder from the default layout. + CHECK(ptr_->IsDefault()); + auto def_pd = GetPrimitiveDesc(pd, def_format); + old_mem.reset(new mkldnn::memory(def_pd, old_mem->get_data_handle())); + } + // This may be called in MKLDNN operators. We can't use MKLDNNStream here. + std::vector net; + net.push_back(mkldnn::reorder(*old_mem, *new_mem)); + mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait(); + + CHECK(ptr_->shandle.size >= pd.get_size()); + ptr_->CheckAndAlloc(pd.get_size()); + // TODO(zhengda) We need to avoid memory copy here. + memcpy(ptr_->shandle.dptr, new_mem->get_data_handle(), pd.get_size()); + ptr_->Mkl_mem_.reset(new mkldnn::memory(pd, ptr_->shandle.dptr)); +} + +void NDArray::CopyFrom(const mkldnn::memory &mem) { + if (ptr_ == nullptr) { + LOG(FATAL) << "The NDArray hasn't been initialized"; + return; + } + if (ptr_->Mkl_mem_.get() == &mem) + return; + + if (mem.get_primitive_desc().get_size() != shape().Size() * GetTypeSize(dtype_)) { + LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc"; + return; + } + + MKLDNNStream *stream = MKLDNNStream::Get(); + // If this array uses MKLDNN layout and it's a view, we have to change its + // layout to the default layout. + if (IsMKLDNN() && IsView()) + ptr_->Reorder2Default(); + ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_, + dtype_); + stream->RegisterMem(ptr_->Mkl_mem_); + auto from_desc = mem.get_primitive_desc().desc(); + auto this_desc = ptr_->Mkl_mem_->get_primitive_desc().desc(); + auto from_def_format = GetDefaultFormat(from_desc); + if (IsView()) { + // Sliced array must use the default layout. + CHECK_EQ(GetDefaultFormat(this_desc), this_desc.data.format); + } + // It's possible that the memory and the NDArray don't have the same shape. + if (!same_shape(shape_, from_desc.data.dims, from_desc.data.ndims) + // If the source memory uses the default layout, we can reshape directly. + && from_def_format == from_desc.data.format) { + // In this case, we can simply create a new MKLDNN memory for the required + // shape. + mkldnn::memory::dims dims(this_desc.data.dims, + this_desc.data.dims + this_desc.data.ndims); + auto this_dtype = static_cast(this_desc.data.data_type); + auto this_format = static_cast(GetDefaultFormat(this_desc)); + mkldnn::memory::desc data_md(dims, this_dtype, this_format); + mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine()); + mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle())); + stream->RegisterMem(tmp_mem); + stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->Mkl_mem_)); + } else if (!same_shape(shape_, from_desc.data.dims, from_desc.data.ndims)) { + // In this case, the source memory stores data in a customized layout. We + // need to reorganize the data in memory before we can reshape. 
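The view branch in GetMKLDNNData above rebases the MKLDNN memory at `shandle.dptr + byte_offset_`, which is plain row-major arithmetic once the sliced array is in the default layout. A self-contained sketch of that offset computation; `SliceByteOffset` is a hypothetical helper, not MXNet code:

```c++
#include <cstdio>
#include <vector>

size_t SliceByteOffset(const std::vector<size_t> &shape, size_t begin,
                       size_t elem_size) {
  size_t stride = 1;  // elements covered by one index of dimension 0
  for (size_t i = 1; i < shape.size(); i++)
    stride *= shape[i];
  return begin * stride * elem_size;
}

int main() {
  // Slice(2, 5) on a float32 array of shape (10, 3, 224, 224) starts at
  // 2 * 3 * 224 * 224 * 4 bytes from the base pointer.
  printf("%zu\n", SliceByteOffset({10, 3, 224, 224}, 2, sizeof(float)));
  return 0;
}
```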
+    auto def_pd = GetPrimitiveDesc(mem.get_primitive_desc(), from_def_format);
+    auto def_mem = TmpMemMgr::Get()->Alloc(def_pd);
+    stream->RegisterPrim(mkldnn::reorder(mem, *def_mem));
+    // Now we can reshape it
+    mkldnn::memory::dims dims(this_desc.data.dims,
+                              this_desc.data.dims + this_desc.data.ndims);
+    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
+    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
+    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
+    mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine());
+    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle()));
+    stream->RegisterMem(tmp_mem);
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->Mkl_mem_));
+  } else if (mem.get_primitive_desc() == ptr_->Mkl_mem_->get_primitive_desc()) {
+    // If the layout is the same, we can just copy data.
+    stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->Mkl_mem_));
+  } else {
+    auto src_def = GetDefaultFormat(mem.get_primitive_desc().desc());
+    auto dst_def = GetDefaultFormat(ptr_->Mkl_mem_->get_primitive_desc().desc());
+    // If neither uses the default layout, there isn't much we can do other
+    // than reorder the data layout directly.
+    if (dst_def != ptr_->Mkl_mem_->get_primitive_desc().desc().data.format
+        && src_def != mem.get_primitive_desc().desc().data.format) {
+      stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->Mkl_mem_));
+    } else if (dst_def == ptr_->Mkl_mem_->get_primitive_desc().desc().data.format) {
+      // If the dest mem uses the default memory layout, we can simply use
+      // the default format of the source memory to improve perf of reorder.
+      auto pd = GetPrimitiveDesc(ptr_->Mkl_mem_->get_primitive_desc(), src_def);
+      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, ptr_->Mkl_mem_->get_data_handle()));
+      stream->RegisterMem(tmp_mem);
+      stream->RegisterPrim(mkldnn::reorder(mem, *tmp_mem));
+    } else {
+      // If the src mem uses the default memory layout, we can use
+      // the default format of the destination memory to improve perf.
+      auto pd = GetPrimitiveDesc(mem.get_primitive_desc(), dst_def);
+      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
+      stream->RegisterMem(tmp_mem);
+      stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->Mkl_mem_));
+    }
+  }
+}
+
+mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd,
+                                                mkldnn_memory_format_t format);
+
+mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &desc) {
+  // This array shouldn't be a view.
+  CHECK(!IsView());
+
+  if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
+    LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
+    return nullptr;
+  }
+
+  mkldnn::memory::primitive_desc _desc = desc;
+  auto required_format = _desc.desc().data.format;
+  auto def_format = GetDefaultFormat(_desc.desc());
+  // If the required format is a default format, we don't need to worry about the shape.
+  // If the shape isn't the same, it actually implicitly reshapes data.
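Several branches of CopyFrom above wrap an existing buffer in a second `mkldnn::memory` with a different primitive_desc so a reorder can reinterpret it without copying. A standalone illustration of that zero-copy trick, assuming the MKLDNN 0.x API:

```c++
#include <mkldnn.hpp>

int main() {
  using namespace mkldnn;
  engine eng(engine::cpu, 0);
  memory::dims dims = {2, 3, 4, 4};
  memory::primitive_desc pd(
      {dims, memory::data_type::f32, memory::format::nchw}, eng);
  memory src(pd);  // owns its buffer

  // A second memory object over the same bytes: no allocation, no copy.
  memory view(pd, src.get_data_handle());
  return view.get_data_handle() == src.get_data_handle() ? 0 : 1;
}
```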
+ if (required_format == def_format) { + ptr_->SetMKLMem(shape_, dtype_); + MKLDNNStream::Get()->RegisterMem(ptr_->Mkl_mem_); + return GetMKLDNNExact(ptr_->Mkl_mem_.get(), desc); + } + + if (ptr_->Mkl_mem_) + CHECK(ptr_->Mkl_mem_->get_data_handle() == ptr_->shandle.dptr); + ptr_->ResetMKLMem(); + if (ptr_->Mkl_mem_ && ptr_->Mkl_mem_->get_primitive_desc() == desc) { + MKLDNNStream::Get()->RegisterMem(ptr_->Mkl_mem_); + return GetMKLDNNExact(ptr_->Mkl_mem_.get(), desc); + } + + CHECK(ptr_->shandle.size >= desc.get_size()); + ptr_->CheckAndAlloc(desc.get_size()); + ptr_->Mkl_mem_.reset(new mkldnn::memory(desc, ptr_->shandle.dptr)); + MKLDNNStream::Get()->RegisterMem(ptr_->Mkl_mem_); + return ptr_->Mkl_mem_.get(); +} +#endif + +void NDArray::SetTBlob() const { + CHECK(ptr_ != nullptr); + TShape shape = shape_; + char *dptr = static_cast(ptr_->shandle.dptr); + auto stype = storage_type(); + if (stype == kDefaultStorage) { +#if MXNET_USE_MKLDNN == 1 + if (IsMKLDNN()) { + ptr_->Reorder2Default(); + dptr = static_cast(ptr_->shandle.dptr); + } +#endif + dptr += byte_offset_; + } else if (stype == kCSRStorage || stype == kRowSparseStorage) { + CHECK_EQ(byte_offset_, 0); + shape = storage_shape(); + } else { + LOG(FATAL) << "unknown storage type " << stype; + } + tblob_.dptr_ = dptr; + tblob_.shape_ = shape; + tblob_.type_flag_ = dtype_; + tblob_.SetDLTensor(ptr_->shandle.ctx.dev_mask(), ptr_->shandle.ctx.dev_id); +} /*! * \brief run a ternary operation @@ -449,11 +994,49 @@ inline void CopyFromToRspImpl(const NDArray& from, const NDArray& to, RunContext // Make a copy of a dense NDArray template inline void CopyFromToDnsImpl(const NDArray& from, const NDArray& to, RunContext ctx) { - using namespace mshadow; - CHECK_EQ(from.storage_type(), to.storage_type()) << "Copying with different storage type"; - TBlob tmp = to.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), to.ctx(), ctx); +#if MXNET_USE_MKLDNN == 1 + // If neither is MKLDNN, we can copy data normally. + if (!from.IsMKLDNN() && !to.IsMKLDNN()) { +#endif + using namespace mshadow; + CHECK_EQ(from.storage_type(), to.storage_type()) << "Copying with different storage type"; + TBlob tmp = to.data(); + ndarray::Copy(from.data(), &tmp, + from.ctx(), to.ctx(), ctx); +#if MXNET_USE_MKLDNN == 1 + } else if (SupportMKLDNN(from.dtype(), from.shape()) + && SupportMKLDNN(to.dtype(), to.shape())) { + // If we copy data directly, we need to make sure both NDArrays are supported + // by MKLDNN. + auto from_mem = from.GetMKLDNNData(); + auto to_mem = to.GetMKLDNNData(); + if (from_mem->get_primitive_desc() == to_mem->get_primitive_desc()) { + size_t size = std::min(from_mem->get_primitive_desc().get_size(), + to_mem->get_primitive_desc().get_size()); + memcpy(to_mem->get_data_handle(), from_mem->get_data_handle(), size); + } else { + std::vector net; + net.push_back(mkldnn::reorder(*from_mem, *to_mem)); + mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait(); + } + } else { + // In this case, one of the NDArray isn't supported by MKLDNN, we need + // to convert the MKLDNN array to the default format first and copy data + // with Copy(). 
+    NDArray tmp_from = from;
+    if (tmp_from.IsMKLDNN()) {
+      tmp_from = NDArray(from.shape(), from.ctx(), false, from.dtype());
+      auto tmp_mem = from.GetMKLDNNData();
+      tmp_from.CopyFrom(*tmp_mem);
+      MKLDNNStream::Get()->Submit();
+    }
+    CHECK(tmp_from.IsDefault());
+    CHECK(to.IsDefault());
+    TBlob tmp = to.data();
+    ndarray::Copy(tmp_from.data(), &tmp,
+                  from.ctx(), to.ctx(), ctx);
+  }
+#endif
 }
 
 // Make a copy of an NDArray based on storage type
diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h
index d8da30b7263a..31b417a7f321 100644
--- a/src/operator/nn/activation-inl.h
+++ b/src/operator/nn/activation-inl.h
@@ -46,6 +46,7 @@ namespace op {
 namespace activation {
 enum ActivationOpInputs {kData};
 enum ActivationOpOutputs {kOut};
+enum ActivationOpResource {kTempSpace};
 enum ActivationOpType {kReLU, kSigmoid, kTanh, kSoftReLU};
 }  // activation
 
@@ -60,6 +61,16 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
     .add_enum("softrelu", activation::kSoftReLU)
     .describe("Activation function to be applied.");
   }
+
+  bool operator==(const ActivationParam& other) const {
+    return this->act_type == other.act_type;
+  }
+
+#if MXNET_USE_MKLDNN == 1
+  uint64_t GetHash() const {
+    return act_type;
+  }
+#endif
 };
 
 template
@@ -100,31 +111,25 @@ void ActivationBackward(const OpContext &ctx, const TBlob &out_grad,
 }
 
 template<typename xpu>
-void ActivationCompute(const nnvm::NodeAttrs& attrs,
-                       const OpContext& ctx,
-                       const std::vector<TBlob>& inputs,
-                       const std::vector<OpReqType>& req,
-                       const std::vector<TBlob>& outputs) {
-  CHECK_EQ(inputs.size(), 1U);
-  CHECK_EQ(outputs.size(), 1U);
-  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
-  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+void ActivationComputeImpl(const ActivationParam &param, const OpContext &ctx,
+                           const TBlob &input, OpReqType req, const TBlob &output) {
+  MSHADOW_REAL_TYPE_SWITCH(input.type_flag_, DType, {
     switch (param.act_type) {
       case activation::kReLU:
         ActivationForward(
-            ctx, inputs[0], req[0], outputs[0]);
+            ctx, input, req, output);
         break;
       case activation::kSigmoid:
        ActivationForward(
-            ctx, inputs[0], req[0], outputs[0]);
+            ctx, input, req, output);
        break;
      case activation::kTanh:
        ActivationForward(
-            ctx, inputs[0], req[0], outputs[0]);
+            ctx, input, req, output);
        break;
      case activation::kSoftReLU:
        ActivationForward(
-            ctx, inputs[0], req[0], outputs[0]);
+            ctx, input, req, output);
        break;
      default:
        LOG(FATAL) << "unknown activation type";
@@ -133,36 +138,26 @@ void ActivationCompute(const nnvm::NodeAttrs& attrs,
 }
 
 template<typename xpu>
-void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
-                           const OpContext& ctx,
-                           const std::vector<TBlob>& inputs,
-                           const std::vector<OpReqType>& req,
-                           const std::vector<TBlob>& outputs) {
-#if MXNET_USE_CUDNN == 1
-  CHECK_EQ(inputs.size(), 3U);
-#else
-  CHECK_EQ(inputs.size(), 2U);
-#endif
-  CHECK_EQ(outputs.size(), 1U);
-  CHECK_EQ(req.size(), 1U);
-  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
-  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+void ActivationGradComputeImpl(const ActivationParam &param, const OpContext &ctx,
+                               const TBlob &out_grad, const TBlob &out_data,
+                               OpReqType req, const TBlob &output) {
+  MSHADOW_REAL_TYPE_SWITCH(out_grad.type_flag_, DType, {
    switch (param.act_type) {
      case activation::kReLU:
        ActivationBackward(
-            ctx, inputs[0], inputs[1], req[0], outputs[0]);
+            ctx, out_grad, out_data, req, output);
        break;
      case activation::kSigmoid:
        ActivationBackward(
-            ctx, inputs[0], inputs[1], req[0], outputs[0]);
+            ctx, out_grad, out_data, req, output);
        break;
      case activation::kTanh:
        ActivationBackward(
- ctx, inputs[0], inputs[1], req[0], outputs[0]); + ctx, out_grad, out_data, req, output); break; case activation::kSoftReLU: ActivationBackward( - ctx, inputs[0], inputs[1], req[0], outputs[0]); + ctx, out_grad, out_data, req, output); break; default: LOG(FATAL) << "unknown activation type"; @@ -170,6 +165,35 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs, }); } +template +void ActivationCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + const ActivationParam& param = nnvm::get(attrs.parsed); + ActivationComputeImpl(param, ctx, inputs[0], req[0], outputs[0]); +} + +template +void ActivationGradCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { +#if MXNET_USE_CUDNN == 1 + CHECK_EQ(inputs.size(), 3U); +#else + CHECK_EQ(inputs.size(), 2U); +#endif + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + const ActivationParam& param = nnvm::get(attrs.parsed); + ActivationGradComputeImpl(param, ctx, inputs[0], inputs[1], req[0], outputs[0]); +} + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_NN_ACTIVATION_INL_H_ diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc index c437b685ddc6..1a6974e5fae7 100644 --- a/src/operator/nn/activation.cc +++ b/src/operator/nn/activation.cc @@ -26,11 +26,10 @@ #include "./activation-inl.h" #include "../mshadow_op.h" #include "../tensor/elemwise_unary_op.h" -#if MXNET_USE_MKL2017 == 1 -#include -#include "../mkl/mkl_memory-inl.h" -#include "../mkl/mkl_relu-inl.h" -#endif // MXNET_USE_MKL2017 +#if MXNET_USE_MKLDNN == 1 +#include "./mkldnn/mkldnn_base-inl.h" +#include "./mkldnn/mkldnn_ops-inl.h" +#endif // MXNET_USE_MKLDNN namespace mxnet { namespace op { @@ -51,6 +50,91 @@ struct ActivationGrad { } }; +#if MXNET_USE_MKLDNN == 1 +static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const ActivationParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + if (SupportMKLDNN(inputs[0])) { + MKLDNNActivationForward(attrs, ctx, inputs[0], req[0], outputs[0]); + return; + } + ActivationComputeImpl(param, ctx, inputs[0].data(), req[0], outputs[0].data()); +} + +void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { +#if MXNET_USE_CUDNN == 1 + CHECK_EQ(inputs.size(), 3U); +#else + CHECK_EQ(inputs.size(), 2U); +#endif + const ActivationParam& param = nnvm::get(attrs.parsed); + if (SupportMKLDNN(inputs[0])) { + MKLDNNActivationBackward(attrs, ctx, inputs[0], inputs[1], req[0], + outputs[0]); + return; + } + ActivationGradComputeImpl(param, ctx, inputs[0].data(), inputs[1].data(), + req[0], outputs[0].data()); +} +#endif + +inline static bool ActivationStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1); + CHECK_EQ(out_attrs->size(), 1); + const ActivationParam& param = nnvm::get(attrs.parsed); + bool ret = ElemwiseStorageType<1, 1, false, false, false>(attrs, dev_mask, + dispatch_mode, + in_attrs, out_attrs); +#if MXNET_USE_MKLDNN == 1 + if (dev_mask 
== mshadow::cpu::kDevMask && SupportMKLDNNAct(param)) { + *dispatch_mode = DispatchMode::kFComputeEx; + } +#endif + return ret; +} + +inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { +#if MXNET_USE_CUDNN == 1 + CHECK_EQ(in_attrs->size(), 3U); +#else + CHECK_EQ(in_attrs->size(), 2U); +#endif + CHECK_EQ(out_attrs->size(), 1U); + const ActivationParam& param = nnvm::get(attrs.parsed); +#if MXNET_USE_CUDNN == 1 + bool ret = ElemwiseStorageType<3, 1, false, false, false>(attrs, dev_mask, + dispatch_mode, + in_attrs, out_attrs); +#else + bool ret = ElemwiseStorageType<2, 1, false, false, false>(attrs, dev_mask, + dispatch_mode, + in_attrs, out_attrs); +#endif +#if MXNET_USE_MKLDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNAct(param)) { + *dispatch_mode = DispatchMode::kFComputeEx; + } +#endif + return ret; +} + MXNET_OPERATOR_REGISTER_UNARY(Activation) .describe(R"code(Applies an activation function element-wise to the input. @@ -63,7 +147,11 @@ The following activation functions are supported: )code" ADD_FILELINE) .set_attr_parser(ParamParser) +.set_attr("FInferStorageType", ActivationStorageType) .set_attr("FCompute", ActivationCompute) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", ActivationComputeExCPU) +#endif .set_attr("FGradient", ActivationGrad{"_backward_Activation"}) .add_arguments(ActivationParam::__FIELDS__()); @@ -71,12 +159,21 @@ NNVM_REGISTER_OP(_backward_Activation) .set_num_inputs(3) .set_num_outputs(1) .set_attr("TIsBackward", true) +.set_attr("FInferStorageType", BackwardActStorageType) .set_attr("FInferShape", ElemwiseShape<3, 1>) .set_attr("FInferType", ElemwiseType<3, 1>) .set_attr("FInplaceOption", [](const NodeAttrs& attrs){ return std::vector >{{0, 0}}; }) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +#endif .set_attr_parser(ParamParser) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", ActivationGradComputeExCPU) +#endif .set_attr("FCompute", ActivationGradCompute); } // namespace op diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h index a6b11fc647f6..7bd13fcff0b9 100644 --- a/src/operator/nn/batch_norm-inl.h +++ b/src/operator/nn/batch_norm-inl.h @@ -50,6 +50,7 @@ namespace batchnorm { enum BatchNormOpInputs {kData, kGamma, kBeta, kInMovingMean, kInMovingVar}; // kGamma: weights, kBeta: biases enum BatchNormOpOutputs {kOut, kMean, kVar}; // req, out_data +enum BatchNormOpResource {kTempSpace}; enum BatchNormOpAuxiliary {kMovingMean, kMovingVar}; // aux_states /*! 
\brief Default channel axis if none specified int he params */ @@ -84,6 +85,28 @@ struct BatchNormParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(cudnn_off).set_default(false) .describe("Do not select CUDNN operator, if available"); } + + bool operator==(const BatchNormParam& other) const { + return this->eps == other.eps && + this->momentum == other.momentum && + this->fix_gamma == other.fix_gamma && + this->use_global_stats == other.use_global_stats && + this->output_mean_var == other.output_mean_var && + this->axis == other.axis && + this->cudnn_off == other.cudnn_off; + } + +#if MXNET_USE_MKLDNN == 1 + uint64_t GetHash() const { + uint64_t hash = 0; + hash = hash * 2 + momentum * 10; + hash = hash * 2 + fix_gamma; + hash = hash * 2 + use_global_stats; + hash = hash * 2 + output_mean_var; + hash = hash * 2 + axis; + return hash; + } +#endif }; static inline bool IsBNWriting(const OpReqType ort) { @@ -91,40 +114,40 @@ static inline bool IsBNWriting(const OpReqType ort) { } template -void DoBNForward(mshadow::Stream *stream, - const OpContext &ctx, const BatchNormParam& param, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_states); +void BatchNormForwardImpl(mshadow::Stream *stream, + const OpContext &ctx, const BatchNormParam& param, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_states); template -void DoBNBackward(mshadow::Stream *stream, - const OpContext &ctx, const BatchNormParam& param, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_states); +void BatchNormBackwardImpl(mshadow::Stream *stream, + const OpContext &ctx, const BatchNormParam& param, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_states); #if MXNET_USE_CUDA template -void DoBNForward(mshadow::Stream *stream, - const OpContext &ctx, const BatchNormParam& param, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_states); +void BatchNormForwardImpl(mshadow::Stream *stream, + const OpContext &ctx, const BatchNormParam& param, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_states); template -void DoBNBackward(mshadow::Stream *stream, - const OpContext &ctx, const BatchNormParam& param, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_states); +void BatchNormBackwardImpl(mshadow::Stream *stream, + const OpContext &ctx, const BatchNormParam& param, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_states); #endif // MXNET_USE_CUDA /*! 
@@ -139,11 +162,11 @@ void DoBNBackward(mshadow::Stream *stream, * \sa OpReqType, OpContext */ template -void BNForward(const OpContext &ctx, const BatchNormParam& param, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_states) { +void BatchNormForward(const OpContext &ctx, const BatchNormParam& param, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_states) { using namespace mshadow; using namespace mshadow::expr; @@ -158,7 +181,8 @@ void BNForward(const OpContext &ctx, const BatchNormParam& param, CHECK_EQ(req[batchnorm::kOut], kWriteTo); } Stream *s = ctx.get_stream(); - DoBNForward(s, ctx, param, in_data, req, out_data, aux_states); + BatchNormForwardImpl(s, ctx, param, in_data, req, + out_data, aux_states); } /*! @@ -190,20 +214,20 @@ void BNForward(const OpContext &ctx, const BatchNormParam& param, * \sa OperatorProperty, OpReqType, OpContext */ template -void BNBackward(const OpContext &ctx, const BatchNormParam& param, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_states) { +void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_states) { CHECK_EQ(out_grad.size(), param.output_mean_var ? 3U : 1U); CHECK_EQ(in_data.size(), 3U); CHECK_EQ(out_data.size(), 3U); CHECK_EQ(in_grad.size(), 3U); mshadow::Stream *s = ctx.get_stream(); - DoBNBackward(s, ctx, param, out_grad, in_data, - out_data, req, in_grad, aux_states); + BatchNormBackwardImpl(s, ctx, param, out_grad, in_data, + out_data, req, in_grad, aux_states); } template @@ -218,7 +242,8 @@ void BatchNormCompute(const nnvm::NodeAttrs& attrs, std::vector aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end()); MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, { - BNForward(ctx, param, in_data, req, outputs, aux_states); + BatchNormForward(ctx, param, in_data, req, outputs, + aux_states); }); } @@ -242,8 +267,8 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, std::vector in_grad(outputs.begin(), outputs.begin() + 3); MSHADOW_REAL_TYPE_SWITCH_EX(out_grad[0].type_flag_, DType, AccReal, { - BNBackward(ctx, param, out_grad, in_data, out_data, req, - in_grad, aux_states); + BatchNormBackward(ctx, param, out_grad, in_data, out_data, req, + in_grad, aux_states); }); } diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index bb5a70658d21..ba74050a81d0 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -32,6 +32,9 @@ #endif // MXNET_USE_MKL2017 #include #include "../elemwise_op_common.h" +#if MXNET_USE_MKLDNN == 1 +#include "./mkldnn/mkldnn_batch_norm-inl.h" +#endif /*! \brief inverse standard deviation <-> variance */ #define VARIANCE_TO_INVSTD(__var$, __eps$) (1.0/sqrt((__var$) + DType(__eps$))) @@ -90,12 +93,12 @@ static inline void ForEachFast(const BNTensor3 &in_data, /*! 
\brief Forward CPU */ template -void DoBNForward(mshadow::Stream *, - const OpContext &ctx, const BatchNormParam& param_, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_states) { +void BatchNormForwardImpl(mshadow::Stream *, + const OpContext &ctx, const BatchNormParam& param_, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_states) { // Input batchnorm::BNTensor3 inputData(in_data[batchnorm::kData], param_.axis); const TBlob &weights = in_data[batchnorm::kGamma]; @@ -190,14 +193,14 @@ void DoBNForward(mshadow::Stream *, } template -void DoBNBackward(mshadow::Stream *, - const OpContext &ctx, const BatchNormParam& param_, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_states) { +void BatchNormBackwardImpl(mshadow::Stream *, + const OpContext &ctx, const BatchNormParam& param_, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_states) { // Input Data batchnorm::BNTensor3 inputData(in_data[batchnorm::kData], param_.axis); const TBlob &weights = in_data[batchnorm::kGamma]; @@ -379,6 +382,130 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs, return true; } +static inline bool similar_array(const mxnet::NDArray &arr1, + const mxnet::NDArray &arr2, + float tol) { + float *data1 = reinterpret_cast(arr1.data().dptr_); + float *data2 = reinterpret_cast(arr2.data().dptr_); + if (arr1.shape().Size() != arr2.shape().Size()) + return false; + for (size_t i = 0; i < arr1.shape().Size(); i++) { + if (std::abs(data1[i] - data2[i]) > tol) { + // printf("similar_array: %.8f, %.8f \n", data1[i], data2[i]); + return false; + } + } + std::cout << "similar_array: passed all check, tol=" << tol << std::endl; + return true; +} + +#if MXNET_USE_MKLDNN == 1 +static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam ¶m) { + TShape shape = input.shape(); + return SupportMKLDNN(input) && shape.ndim() == 4 + && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS + && shape[param.axis] % 8 == 0; +} + +void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(inputs.size(), 5U); + const BatchNormParam ¶m = nnvm::get(attrs.parsed); + // MKLDNN batchnorm only works well on the special MKLDNN layout. + if (SupportMKLDNNBN(inputs[0], param) && inputs[0].IsMKLDNN()) { + std::vector in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean); + std::vector aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end()); + + switch (inputs[0].dtype()) { + case mshadow::kFloat32: + MKLDNNBatchNormForward(ctx, param, in_data, req, outputs, aux_states); + return; + } + } + FallBackCompute(BatchNormCompute, attrs, ctx, inputs, req, outputs); +} + +void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(inputs.size(), 11U); + const BatchNormParam ¶m = nnvm::get(attrs.parsed); + int num_out_grads = param.output_mean_var ? 
3U : 1U; + int in_data_start = 3; + int aux_states_start = in_data_start + batchnorm::kInMovingMean; + int out_data_start = in_data_start + batchnorm::kInMovingVar + 1; + + TShape shape = inputs[0].shape(); + // MKLDNN batchnorm only works well on the special MKLDNN layout. + if (SupportMKLDNNBN(inputs[0], param) + && (inputs[in_data_start].IsMKLDNN() || inputs[0].IsMKLDNN())) { + std::vector out_grad(inputs.begin(), inputs.begin() + num_out_grads); + std::vector in_data(inputs.begin() + in_data_start, + inputs.begin() + aux_states_start); + std::vector aux_states(inputs.begin() + aux_states_start, + inputs.begin() + out_data_start); + std::vector out_data(inputs.begin() + out_data_start, inputs.end()); + std::vector in_grad(outputs.begin(), outputs.begin() + 3); + + if (inputs[0].dtype() == mshadow::kFloat32) { + MKLDNNBatchNormBackward(ctx, param, out_grad, in_data, + out_data, req, in_grad, aux_states); + return; + } + } + FallBackCompute(BatchNormGradCompute, attrs, ctx, inputs, req, outputs); +} +#endif + +static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs, + const int dev_mask, + DispatchMode *dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 5); + CHECK_EQ(out_attrs->size(), 3); +#if MXNET_USE_MKLDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask) + *dispatch_mode = DispatchMode::kFComputeEx; + else +#endif + *dispatch_mode = DispatchMode::kFCompute; + for (int& v : *in_attrs) { + if (v == - 1) v = kDefaultStorage; + } + for (size_t i = 0; i < out_attrs->size(); i++) { + (*out_attrs)[i] = kDefaultStorage; + } + return true; +} + +static inline bool backward_BatchNormStorageType(const nnvm::NodeAttrs &attrs, + const int dev_mask, + DispatchMode *dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 11); + CHECK_EQ(out_attrs->size(), 5); +#if MXNET_USE_MKLDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask) + *dispatch_mode = DispatchMode::kFComputeEx; + else +#endif + *dispatch_mode = DispatchMode::kFCompute; + for (int& v : *in_attrs) { + if (v == - 1) v = kDefaultStorage; + } + for (size_t i = 0; i < out_attrs->size(); i++) { + (*out_attrs)[i] = kDefaultStorage; + } + return true; +} + NNVM_REGISTER_OP(BatchNorm) .describe(R"code(Batch normalization. @@ -446,8 +573,17 @@ then set ``gamma`` to 1 and its gradient to 0. }) .set_attr("FInferShape", BatchNormShape) .set_attr("FInferType", BatchNormType) +.set_attr("FInferStorageType", BatchNormStorageType) .set_attr("FCompute", BatchNormCompute) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", BatchNormComputeExCPU) +#endif .set_attr("FGradient", ElemwiseGradUseInOut{"_backward_BatchNorm"}) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +#endif .add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization") .add_argument("gamma", "NDArray-or-Symbol", "gamma array") .add_argument("beta", "NDArray-or-Symbol", "beta array") @@ -468,7 +604,16 @@ then set ``gamma`` to 1 and its gradient to 0. 
NNVM_REGISTER_OP(_backward_BatchNorm) .set_num_outputs(5) .set_attr("TIsBackward", true) +.set_attr("FInferStorageType", backward_BatchNormStorageType) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +#endif .set_attr_parser(ParamParser) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", BatchNormGradComputeExCPU) +#endif .set_attr("FCompute", BatchNormGradCompute); } // namespace op diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 682c286f4a3a..80c15976b65f 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -593,12 +593,12 @@ static inline uint32_t SetupFlags(const OpContext &ctx, /*! \brief Forward batch-norm pass on GPU */ template -void DoBNForward(mshadow::Stream *stream, - const OpContext &ctx, const BatchNormParam& param_, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_states) { +void BatchNormForwardImpl(mshadow::Stream *stream, + const OpContext &ctx, const BatchNormParam& param_, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_states) { batchnorm::cuda::BatchNormalizationUpdateOutput( stream, ctx, @@ -614,14 +614,14 @@ void DoBNForward(mshadow::Stream *stream, /*! \brief Backward batch-norm pass on GPU */ template -void DoBNBackward(mshadow::Stream *stream, - const OpContext &ctx, const BatchNormParam& param_, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_states) { +void BatchNormBackwardImpl(mshadow::Stream *stream, + const OpContext &ctx, const BatchNormParam& param_, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_states) { batchnorm::cuda::BatchNormalizationBackward( stream, ctx, @@ -671,12 +671,12 @@ void BatchNormCompute(const nnvm::NodeAttrs& attrs, }) } else { MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, { - BNForward(ctx, param, in_data, req, outputs, aux_states); + BatchNormForward(ctx, param, in_data, req, outputs, aux_states); }) } #else MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, { - BNForward(ctx, param, in_data, req, outputs, aux_states); + BatchNormForward(ctx, param, in_data, req, outputs, aux_states); }); #endif } @@ -706,13 +706,13 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, }) } else { MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, { - BNBackward(ctx, param, out_grad, + BatchNormBackward(ctx, param, out_grad, in_data, out_data, req, in_grad, aux_states); }) } #else MSHADOW_REAL_TYPE_SWITCH_EX(out_grad[0].type_flag_, DType, AccReal, { - BNBackward(ctx, param, out_grad, + BatchNormBackward(ctx, param, out_grad, in_data, out_data, req, in_grad, aux_states); }); #endif diff --git a/src/operator/nn/concat-inl.h b/src/operator/nn/concat-inl.h index 4da7e68dd841..a7f1fa85f612 100644 --- a/src/operator/nn/concat-inl.h +++ b/src/operator/nn/concat-inl.h @@ -23,8 +23,8 @@ * \brief * \author Bing Xu */ -#ifndef MXNET_OPERATOR_CONCAT_INL_H_ -#define MXNET_OPERATOR_CONCAT_INL_H_ +#ifndef MXNET_OPERATOR_NN_CONCAT_INL_H_ +#define MXNET_OPERATOR_NN_CONCAT_INL_H_ #include #include #include @@ -42,6 +42,7 @@ namespace op { namespace concat_enum { enum ConcatOpInputs {kData0, kData1, kData2, kData3, kData4}; 
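The FInferStorageType functions used for Activation and BatchNorm above (and for Concat just below) all follow one pattern: force default storage on outputs, fill in undefined input storage, and switch dispatch to FComputeEx on CPU so the operator receives NDArrays and can keep MKLDNN layouts. A condensed, self-contained sketch with simplified stand-in types (not the real MXNet enums):

```c++
#include <vector>

enum { kCPUDevMask = 1 };
enum { kUndefinedStorage = -1, kDefaultStorage = 0 };
enum class DispatchMode { kFCompute, kFComputeEx };

bool MKLDNNStorageType(int dev_mask, bool mkldnn_supported,
                       DispatchMode *dispatch_mode,
                       std::vector<int> *in_attrs,
                       std::vector<int> *out_attrs) {
  // On CPU with MKLDNN support, take the FComputeEx (NDArray) path.
  *dispatch_mode = (dev_mask == kCPUDevMask && mkldnn_supported)
                       ? DispatchMode::kFComputeEx
                       : DispatchMode::kFCompute;
  for (int &v : *in_attrs)
    if (v == kUndefinedStorage) v = kDefaultStorage;
  for (int &v : *out_attrs)
    v = kDefaultStorage;
  return true;
}

int main() {
  std::vector<int> in{kUndefinedStorage}, out{kUndefinedStorage};
  DispatchMode mode;
  MKLDNNStorageType(kCPUDevMask, true, &mode, &in, &out);
  return mode == DispatchMode::kFComputeEx ? 0 : 1;
}
```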
+enum ConcatOpResource {kTempSpace}; enum ConcatOpOutputs {kOut}; } // namespace concat_enum @@ -94,28 +95,26 @@ class ConcatOp { Concatenate(data, &out, 1, req[concat_enum::kOut]); } - void Backward(const OpContext &ctx, - const std::vector &out_grad, + void Backward(const OpContext &ctx, const TBlob &out_grad, const std::vector &req, const std::vector &in_grad) { using namespace mshadow; using namespace mshadow::expr; - CHECK_EQ(out_grad.size(), 1U); CHECK_EQ(in_grad.size(), static_cast(size_)); - int axis = CheckAxis(dimension_, out_grad[concat_enum::kData0].ndim()); + int axis = CheckAxis(dimension_, out_grad.ndim()); Stream *s = ctx.get_stream(); std::vector > grad_in(size_); Tensor grad; size_t leading = 1, trailing = 1; for (int i = 0; i < axis; ++i) { - leading *= out_grad[concat_enum::kOut].shape_[i]; + leading *= out_grad.shape_[i]; } - for (int i = axis + 1; i < out_grad[concat_enum::kOut].ndim(); ++i) { - trailing *= out_grad[concat_enum::kOut].shape_[i]; + for (int i = axis + 1; i < out_grad.ndim(); ++i) { + trailing *= out_grad.shape_[i]; } - size_t mid = out_grad[concat_enum::kOut].shape_[axis]; + size_t mid = out_grad.shape_[axis]; Shape<3> oshape = Shape3(leading, mid, trailing); - grad = out_grad[concat_enum::kOut].get_with_shape(oshape, s); + grad = out_grad.get_with_shape(oshape, s); for (int i = 0; i < size_; ++i) { Shape<3> dshape = Shape3(leading, in_grad[i].shape_[axis], trailing); @@ -151,11 +150,11 @@ void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, MSHADOW_TYPE_SWITCH(inputs[concat_enum::kOut].type_flag_, DType, { ConcatOp op; op.Init(param); - op.Backward(ctx, inputs, req, outputs); + op.Backward(ctx, inputs[concat_enum::kOut], req, outputs); }); } } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_CONCAT_INL_H_ +#endif // MXNET_OPERATOR_NN_CONCAT_INL_H_ diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc index c5b7c6288b16..12b60113e11b 100644 --- a/src/operator/nn/concat.cc +++ b/src/operator/nn/concat.cc @@ -25,6 +25,9 @@ */ #include "./concat-inl.h" +#include "./mkldnn/mkldnn_ops-inl.h" +#include "./mkldnn/mkldnn_base-inl.h" +#include "../../common/utils.h" namespace mxnet { namespace op { @@ -102,12 +105,90 @@ static bool ConcatType(const nnvm::NodeAttrs& attrs, return true; } +inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK(!in_attrs->empty()); + CHECK_EQ(out_attrs->size(), 1U); +#if MXNET_USE_MKLDNN == 1 + const ConcatParam& param = nnvm::get(attrs.parsed); + if (dev_mask == mshadow::cpu::kDevMask + && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) + && param.dim > 0) + *dispatch_mode = DispatchMode::kFComputeEx; + else +#endif + *dispatch_mode = DispatchMode::kFCompute; + (*out_attrs)[0] = kDefaultStorage; + return true; +} + +inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { +#if MXNET_USE_MKLDNN == 1 + const ConcatParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(out_attrs->size(), in_attrs->size() - 1); + if (dev_mask == mshadow::cpu::kDevMask + && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) + && param.dim > 0) + *dispatch_mode = DispatchMode::kFComputeEx; + else +#endif + *dispatch_mode = DispatchMode::kFCompute; + for (size_t i = 0; i < out_attrs->size(); i++) + (*out_attrs)[i] = kDefaultStorage; + 
return true; +} + +#if MXNET_USE_MKLDNN == 1 +static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& op_ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK(!inputs.empty()); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + if (req[0] == kNullOp) return; + // MKLDNN support 2D and 4D concat + if ((inputs[0].shape().ndim() == 2 || inputs[0].shape().ndim() == 4) + && inputs[0].dtype() == mshadow::kFloat32) { + MKLDNNConcatForward(attrs, op_ctx, inputs, req, outputs); + return; + } + FallBackCompute(ConcatCompute, attrs, op_ctx, inputs, req, outputs); +} + +static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if ((inputs[0].shape().ndim() == 2 || inputs[0].shape().ndim() == 4) + && inputs[0].dtype() == mshadow::kFloat32) { + MKLDNNConcatBackward(attrs, ctx, inputs, req, outputs); + return; + } + FallBackCompute(ConcatGradCompute, attrs, ctx, inputs, req, outputs); +} +#endif + struct ConcatGrad { const char *op_name; std::vector operator()(const nnvm::NodePtr& n, const std::vector& ograds) const { - const ConcatParam& param = nnvm::get(n->attrs.parsed); + CHECK_EQ(ograds.size(), 1); std::vector heads(ograds.begin(), ograds.end()); +#if MXNET_USE_MKLDNN == 1 + for (size_t i = 0; i < n->inputs.size(); i++) { + heads.push_back(n->inputs[i]); + } +#endif return MakeGradNode(op_name, n, heads, n->attrs.dict); } }; @@ -162,9 +243,18 @@ Example:: } return ret; }) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +#endif .set_attr("FInferShape", ConcatShape) .set_attr("FInferType", ConcatType) +.set_attr("FInferStorageType", ConcatForwardInferStorageType) .set_attr("FCompute", ConcatCompute) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", ConcatComputeExCPU) +#endif .set_attr("FGradient", ConcatGrad{"_backward_Concat"}) .set_attr("key_var_num_args", "num_args") .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate") @@ -178,7 +268,16 @@ NNVM_REGISTER_OP(_backward_Concat) return params.num_args; }) .set_attr_parser(ParamParser) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +#endif .set_attr("TIsBackward", true) +.set_attr("FInferStorageType", BackwardConcatStorageType) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", ConcatGradComputeExCPU) +#endif .set_attr("FCompute", ConcatGradCompute); } // namespace op diff --git a/src/operator/nn/convolution-inl.h b/src/operator/nn/convolution-inl.h index 6204f75c4697..e6cf1f52387c 100644 --- a/src/operator/nn/convolution-inl.h +++ b/src/operator/nn/convolution-inl.h @@ -118,6 +118,29 @@ struct ConvolutionParam : public dmlc::Parameter { this->cudnn_off == other.cudnn_off && this->layout == other.layout; } +#if MXNET_USE_MKLDNN == 1 + static uint64_t ComputeHash(const TShape &shape) { + uint64_t hash = 0; + for (size_t i = 0; i < shape.ndim(); i++) + hash = hash * 2 + shape[i]; + return hash; + } + + uint64_t GetHash() const { + uint64_t hash = 0; + hash = hash * 2 + ComputeHash(kernel); + hash = hash * 2 + ComputeHash(stride); + hash = hash * 2 + ComputeHash(dilate); + hash = hash * 2 + ComputeHash(pad); + hash = hash * 2 + num_filter; + hash = hash * 2 + num_group; + hash = hash * 2 + workspace; + hash = hash * 2 + no_bias; + if (layout.has_value()) 
+ hash = hash * 2 + layout.value(); + return hash; + } +#endif }; } // namespace op diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index 60c56d69d340..01d876490d8a 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -26,11 +26,8 @@ #include "./convolution-inl.h" #include "../elemwise_op_common.h" -#if MXNET_USE_MKL2017 == 1 -#include -#include "../mkl/mkl_memory-inl.h" -#include "../mkl/mkl_convolution-inl.h" -#endif // MXNET_USE_MKL2017 +#include "./mkldnn/mkldnn_ops-inl.h" +#include "./mkldnn/mkldnn_base-inl.h" #if MXNET_USE_NNPACK == 1 #include "./nnpack/nnpack_convolution-inl.h" #endif // MXNET_USE_NNPACK @@ -51,6 +48,32 @@ static inline std::vector ListArguments(const ConvolutionParam& par } } +#if MXNET_USE_MKLDNN == 1 +static void ConvolutionComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (SupportMKLDNNConv(inputs[0])) { + MKLDNNConvolutionForward(attrs, ctx, inputs, req, outputs); + return; + } + FallBackCompute(ConvolutionCompute, attrs, ctx, inputs, req, outputs); +} + +static void ConvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (SupportMKLDNNConv(inputs[0])) { + MKLDNNConvolutionBackward(attrs, ctx, inputs, req, outputs); + return; + } + FallBackCompute(ConvolutionGradCompute, attrs, ctx, inputs, req, outputs); +} +#endif + static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, std::vector *in_shape, std::vector *out_shape) { @@ -67,50 +90,50 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, if (dshp.ndim() == 0) return false; if (param_.kernel.ndim() == 1) { - // 1d conv - CHECK_EQ(dshp.ndim(), 3U) << "Input data should be 3D in batch-num_filter-x"; - Shape<3> dshape = ConvertLayout(dshp.get<3>(), param_.layout.value(), kNCW); - Shape<3> wshape = Shape3(param_.num_filter / param_.num_group, dshape[1] / param_.num_group, - param_.kernel[0]); - wshape = ConvertLayout(wshape, kNCW, param_.layout.value()); - wshape[0] *= param_.num_group; - SHAPE_ASSIGN_CHECK(*in_shape, conv::kWeight, wshape); - if (!param_.no_bias) { - SHAPE_ASSIGN_CHECK(*in_shape, conv::kBias, Shape1(param_.num_filter)); - } + // 1d conv + CHECK_EQ(dshp.ndim(), 3U) << "Input data should be 3D in batch-num_filter-x"; + Shape<3> dshape = ConvertLayout(dshp.get<3>(), param_.layout.value(), kNCW); + Shape<3> wshape = Shape3(param_.num_filter / param_.num_group, dshape[1] / param_.num_group, + param_.kernel[0]); + wshape = ConvertLayout(wshape, kNCW, param_.layout.value()); + wshape[0] *= param_.num_group; + SHAPE_ASSIGN_CHECK(*in_shape, conv::kWeight, wshape); + if (!param_.no_bias) { + SHAPE_ASSIGN_CHECK(*in_shape, conv::kBias, Shape1(param_.num_filter)); + } - const index_t dilated_ksize_x = param_.DilatedKernelSize(0); - CHECK_EQ(dshape[1] % param_.num_group, 0U) \ - << "input num_filter must divide group size"; - CHECK_EQ(param_.num_filter % param_.num_group, 0U) \ - << "output num_filter must divide group size"; - CHECK_GT(param_.kernel.Size(), 0U) \ - << "incorrect kernel size: " << param_.kernel; - CHECK_GT(param_.stride.Size(), 0U) \ - << "incorrect stride size: " << param_.stride; - CHECK_GT(param_.dilate.Size(), 0U) \ - << "incorrect dilate size: " << param_.dilate; - Shape<3> oshape; - oshape[0] = dshape[0]; - oshape[1] = param_.num_filter; - oshape[2] = dshape[2] ? 
- (AddPad(dshape[2], param_.pad[0]) - dilated_ksize_x) / param_.stride[0] + 1 : 0; - SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCW, param_.layout.value())); - // Perform incomplete shape inference. Fill in the missing values in data shape. - // 1) We can always fill in the batch_size. - // 2) We can back-calculate the input height/width if the corresponding stride is 1. - oshape = ConvertLayout((*out_shape)[0].get<3>(), param_.layout.value(), kNCW); - dshape[0] = oshape[0]; - if (oshape[2] && param_.stride[0] == 1) { - dshape[2] = oshape[2] + dilated_ksize_x - 1 - 2 * param_.pad[0]; - } - SHAPE_ASSIGN_CHECK(*in_shape, conv::kData, - ConvertLayout(dshape, kNCW, param_.layout.value())); - // Check whether the kernel sizes are valid - if (dshape[2] != 0) { - CHECK_LE(dilated_ksize_x, AddPad(dshape[2], param_.pad[0])) << "kernel size exceed input"; - } - return true; + const index_t dilated_ksize_x = param_.DilatedKernelSize(0); + CHECK_EQ(dshape[1] % param_.num_group, 0U) \ + << "input num_filter must divide group size"; + CHECK_EQ(param_.num_filter % param_.num_group, 0U) \ + << "output num_filter must divide group size"; + CHECK_GT(param_.kernel.Size(), 0U) \ + << "incorrect kernel size: " << param_.kernel; + CHECK_GT(param_.stride.Size(), 0U) \ + << "incorrect stride size: " << param_.stride; + CHECK_GT(param_.dilate.Size(), 0U) \ + << "incorrect dilate size: " << param_.dilate; + Shape<3> oshape; + oshape[0] = dshape[0]; + oshape[1] = param_.num_filter; + oshape[2] = dshape[2] ? + (AddPad(dshape[2], param_.pad[0]) - dilated_ksize_x) / param_.stride[0] + 1 : 0; + SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCW, param_.layout.value())); + // Perform incomplete shape inference. Fill in the missing values in data shape. + // 1) We can always fill in the batch_size. + // 2) We can back-calculate the input height/width if the corresponding stride is 1. + oshape = ConvertLayout((*out_shape)[0].get<3>(), param_.layout.value(), kNCW); + dshape[0] = oshape[0]; + if (oshape[2] && param_.stride[0] == 1) { + dshape[2] = oshape[2] + dilated_ksize_x - 1 - 2 * param_.pad[0]; + } + SHAPE_ASSIGN_CHECK(*in_shape, conv::kData, + ConvertLayout(dshape, kNCW, param_.layout.value())); + // Check whether the kernel sizes are valid + if (dshape[2] != 0) { + CHECK_LE(dilated_ksize_x, AddPad(dshape[2], param_.pad[0])) << "kernel size exceed input"; + } + return true; } else if (param_.kernel.ndim() == 2) { // 2d conv CHECK_EQ(dshp.ndim(), 4U) \ @@ -259,6 +282,48 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs, return true; } +inline static bool ConvStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const ConvolutionParam& param = nnvm::get(attrs.parsed); + uint32_t in_expected = param.no_bias ? 2 : 3; + CHECK_EQ(in_attrs->size(), in_expected); + CHECK_EQ(out_attrs->size(), 1); + +#if MXNET_USE_MKLDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask) + *dispatch_mode = DispatchMode::kFComputeEx; + else +#endif + *dispatch_mode = DispatchMode::kFCompute; + (*out_attrs)[0] = kDefaultStorage; + return true; +} + +inline static bool BackwardConvStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const ConvolutionParam& param = nnvm::get(attrs.parsed); + uint32_t in_expected = param.no_bias ? 3 : 4; + uint32_t out_expected = param.no_bias ? 
2 : 3; + CHECK_EQ(in_attrs->size(), in_expected); + CHECK_EQ(out_attrs->size(), out_expected); + +#if MXNET_USE_MKLDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask) + *dispatch_mode = DispatchMode::kFComputeEx; + else +#endif + *dispatch_mode = DispatchMode::kFCompute; + for (size_t i = 0; i < out_attrs->size(); i++) + (*out_attrs)[i] = kDefaultStorage; + return true; +} + static void ConvolutionParamParser(nnvm::NodeAttrs* attrs) { using namespace mshadow; ConvolutionParam param_; @@ -400,18 +465,11 @@ There are other options to tune the performance. }) .set_attr("FInferShape", ConvolutionShape) .set_attr("FInferType", ConvolutionType) -.set_attr("FInferStorageType", [](const nnvm::NodeAttrs& attrs, - const int dev_mask, DispatchMode* dispatch_mode, - std::vector *in_attrs, std::vector *out_attrs) { - const ConvolutionParam& params = nnvm::get(attrs.parsed); - if (params.no_bias) - return ElemwiseStorageType<2, 1, false, false, false>(attrs, dev_mask, - dispatch_mode, in_attrs, out_attrs); - else - return ElemwiseStorageType<3, 1, false, false, false>(attrs, dev_mask, - dispatch_mode, in_attrs, out_attrs); -}) +.set_attr("FInferStorageType", ConvStorageType) .set_attr("FCompute", ConvolutionCompute) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", ConvolutionComputeExCPU) +#endif .set_attr("FGradient", ConvolutionGrad{"_backward_Convolution"}) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; @@ -427,10 +485,14 @@ NNVM_REGISTER_OP(_backward_Convolution) return params.no_bias ? 2 : 3; }) .set_attr("TIsBackward", true) +.set_attr("FInferStorageType", BackwardConvStorageType) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) .set_attr_parser(ConvolutionParamParser) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", ConvolutionGradComputeExCPU) +#endif .set_attr("FCompute", ConvolutionGradCompute); } // namespace op diff --git a/src/operator/nn/cudnn/cudnn_activation-inl.h b/src/operator/nn/cudnn/cudnn_activation-inl.h index 35827917c7d5..a89e7bfaf080 100644 --- a/src/operator/nn/cudnn/cudnn_activation-inl.h +++ b/src/operator/nn/cudnn/cudnn_activation-inl.h @@ -41,6 +41,7 @@ class CuDNNActivationOp { nan_prop_ = CUDNN_NOT_PROPAGATE_NAN; CUDNN_CALL(cudnnCreateActivationDescriptor(&desc_)); #endif + CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc_)); } void Init(const ActivationParam ¶m) { @@ -62,7 +63,6 @@ class CuDNNActivationOp { #if CUDNN_MAJOR >= 5 CUDNN_CALL(cudnnSetActivationDescriptor(desc_, mode_, nan_prop_, relu_ceil_)); #endif - CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc_)); } ~CuDNNActivationOp() { diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h index eee955c96816..3c80cdcba4c2 100644 --- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h @@ -201,7 +201,7 @@ class CuDNNDeconvolutionOp { using namespace mshadow::expr; size_t expected = param_.no_bias == 0 ? 3 : 2; CHECK_EQ(out_grad.size(), 1U); - CHECK_EQ(in_data.size(), 2U); + CHECK_EQ(in_data.size(), param_.no_bias ? 
2U : 3U); CHECK_EQ(in_grad.size(), expected); Stream *s = ctx.get_stream(); @@ -217,6 +217,7 @@ class CuDNNDeconvolutionOp { CHECK_NE(req[deconv::kBias], kWriteInplace); } CHECK_NE(req[deconv::kData], kWriteInplace); + GetTempSize(ctx); Tensor workspace = AllocateTempWorkspace(ctx, backward_workspace_byte_); size_t workspace_size = TensorSizeBytes(workspace); for (uint32_t g = 0; g < param_.num_group; ++g) { diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h index 88875848aef0..352ca66fb690 100644 --- a/src/operator/nn/deconvolution-inl.h +++ b/src/operator/nn/deconvolution-inl.h @@ -163,6 +163,31 @@ struct DeconvolutionParam : public dmlc::Parameter { this->cudnn_off == other.cudnn_off && this->layout == other.layout; } +#if MXNET_USE_MKLDNN == 1 + static uint64_t ComputeHash(const TShape &shape) { + uint64_t hash = 0; + for (size_t i = 0; i < shape.ndim(); i++) + hash = hash * 2 + shape[i]; + return hash; + } + + uint64_t GetHash() const { + uint64_t hash = 0; + hash = hash * 2 + ComputeHash(kernel); + hash = hash * 2 + ComputeHash(stride); + hash = hash * 2 + ComputeHash(dilate); + hash = hash * 2 + ComputeHash(pad); + hash = hash * 2 + ComputeHash(adj); + hash = hash * 2 + ComputeHash(target_shape); + hash = hash * 2 + num_filter; + hash = hash * 2 + num_group; + hash = hash * 2 + workspace; + hash = hash * 2 + no_bias; + if (layout.has_value()) + hash = hash * 2 + layout.value(); + return hash; + } +#endif }; } // namespace op @@ -331,7 +356,7 @@ class DeconvolutionOp { // TODO(bing): check the BLAS Handle, be careful CHECK_EQ(out_grad.size(), 1U); size_t expected = param_.no_bias == 0 ? 3 : 2; - CHECK_EQ(in_data.size(), 2U); + CHECK_EQ(in_data.size(), expected); CHECK_EQ(in_grad.size(), expected); CHECK_EQ(req.size(), expected); CHECK_EQ(in_data[deconv::kWeight].CheckContiguous(), true); diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc index 8e2865e6f729..d5a3d3254748 100644 --- a/src/operator/nn/deconvolution.cc +++ b/src/operator/nn/deconvolution.cc @@ -25,6 +25,8 @@ */ #include "./deconvolution-inl.h" +#include "./mkldnn/mkldnn_ops-inl.h" +#include "./mkldnn/mkldnn_base-inl.h" namespace mxnet { namespace op { @@ -254,6 +256,75 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs, return true; } +inline static bool DeconvStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const DeconvolutionParam& param = nnvm::get(attrs.parsed); + uint32_t in_expected = param.no_bias ? 2 : 3; + CHECK_EQ(in_attrs->size(), in_expected); + CHECK_EQ(out_attrs->size(), 1); + +#if MXNET_USE_MKLDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask) + *dispatch_mode = DispatchMode::kFComputeEx; + else +#endif + *dispatch_mode = DispatchMode::kFCompute; + (*out_attrs)[0] = kDefaultStorage; + return true; +} + +inline static bool BackwardDeconvStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const DeconvolutionParam& param = nnvm::get(attrs.parsed); + uint32_t out_expected = param.no_bias ? 2 : 3; + CHECK_EQ(in_attrs->size(), param.no_bias ? 
3U : 4U); + CHECK_EQ(out_attrs->size(), out_expected); + +#if MXNET_USE_MKLDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask) + *dispatch_mode = DispatchMode::kFComputeEx; + else +#endif + *dispatch_mode = DispatchMode::kFCompute; + for (size_t i = 0; i < out_attrs->size(); i++) + (*out_attrs)[i] = kDefaultStorage; + return true; +} + +#if MXNET_USE_MKLDNN == 1 +static void DeconvolutionComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (SupportMKLDNNConv(inputs[0])) { + MKLDNNDeconvolutionForward(attrs, ctx, inputs, req, outputs); + return; + } + FallBackCompute(DeconvolutionCompute, attrs, ctx, inputs, req, + outputs); +} + +static void DeconvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (SupportMKLDNNConv(inputs[0])) { + MKLDNNDeconvolutionBackward(attrs, ctx, inputs, req, outputs); + return; + } + FallBackCompute(DeconvolutionGradCompute, attrs, ctx, inputs, req, + outputs); +} +#endif + static void DeconvolutionParamParser(nnvm::NodeAttrs* attrs) { using namespace mshadow; DeconvolutionParam param_; @@ -288,6 +359,9 @@ struct DeconvolutionGrad { std::vector heads(ograds.begin(), ograds.end()); heads.push_back(n->inputs[deconv::kData]); heads.push_back(n->inputs[deconv::kWeight]); + const DeconvolutionParam& param = nnvm::get(n->attrs.parsed); + if (!param.no_bias) + heads.push_back(n->inputs[deconv::kBias]); return MakeGradNode(op_name, n, heads, n->attrs.dict); } }; @@ -312,10 +386,14 @@ NNVM_REGISTER_OP(Deconvolution) }) .set_attr("FInferShape", DeconvolutionShape) .set_attr("FInferType", DeconvolutionType) +.set_attr("FInferStorageType", DeconvStorageType) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("FCompute", DeconvolutionCompute) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", DeconvolutionComputeExCPU) +#endif .set_attr("FGradient", DeconvolutionGrad{"_backward_Deconvolution"}) .add_argument("data", "NDArray-or-Symbol", "Input tensor to the deconvolution operation.") .add_argument("weight", "NDArray-or-Symbol", "Weights representing the kernel.") @@ -329,10 +407,14 @@ NNVM_REGISTER_OP(_backward_Deconvolution) return params.no_bias ? 2 : 3; }) .set_attr("TIsBackward", true) +.set_attr("FInferStorageType", BackwardDeconvStorageType) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) .set_attr_parser(DeconvolutionParamParser) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", DeconvolutionGradComputeExCPU) +#endif .set_attr("FCompute", DeconvolutionGradCompute); } // namespace op diff --git a/src/operator/nn/deconvolution.cu b/src/operator/nn/deconvolution.cu index 58bbafcf839e..c7395428c2a0 100644 --- a/src/operator/nn/deconvolution.cu +++ b/src/operator/nn/deconvolution.cu @@ -39,13 +39,9 @@ static CuDNNDeconvolutionOp &GetCuDNNDeconvOp(const DeconvolutionParam& p int backward_compute_type, const std::vector& in_shape, const std::vector& out_shape, - const Context& ctx, bool backward) { - // Convolution forward has to be called before backward for this operator. - // So we can't make this operator thread local. backward might be called - // in another thread. 
- static CuDNNDeconvolutionOp op; - if (!backward) - op.Init(param, forward_compute_type, backward_compute_type, in_shape, out_shape, ctx); + const Context& ctx) { + static thread_local CuDNNDeconvolutionOp op; + op.Init(param, forward_compute_type, backward_compute_type, in_shape, out_shape, ctx); return op; } #endif @@ -82,7 +78,7 @@ void DeconvolutionCompute(const nnvm::NodeAttrs& attrs, in_shape[i] = inputs[i].shape_; } GetCuDNNDeconvOp(param, compute_type, compute_type, - in_shape, out_shape, ctx.run_ctx.ctx, false).Forward(ctx, inputs, req, outputs); + in_shape, out_shape, ctx.run_ctx.ctx).Forward(ctx, inputs, req, outputs); } }) #else @@ -129,7 +125,7 @@ void DeconvolutionGradCompute(const nnvm::NodeAttrs& attrs, in_shape[i] = in_data[i].shape_; } GetCuDNNDeconvOp(param, compute_type, compute_type, - in_shape, out_shape, ctx.run_ctx.ctx, true).Backward(ctx, + in_shape, out_shape, ctx.run_ctx.ctx).Backward(ctx, std::vector{out_grad}, in_data, req, in_grad); } }) diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 343201062dbe..1d2c9eaeb456 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -38,13 +38,6 @@ #include "../operator_common.h" #include "../mshadow_op.h" -#if defined(USE_MKL) && defined(_OPENMP) -#include - -#include -#include -#endif // USE_MKL && _OPENMP - namespace dropout { enum DropoutOpInputs {kData}; enum DropoutOpOutputs {kOut, kMask}; @@ -55,28 +48,6 @@ enum DropoutOpMode {kTraining, kAlways}; namespace mxnet { namespace op { -#if defined(USE_MKL) && defined(_OPENMP) -static void bernoulli_generate(int n, double p, int* r) { - const int seed = 17 + rand() % 4096; // NOLINT(runtime/threadsafe_fn) - const int nthr = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); -# pragma omp parallel num_threads(nthr) - { - const int ithr = omp_get_thread_num(); - const int avg_amount = (n + nthr - 1) / nthr; - const int my_offset = ithr * avg_amount; - const int my_amount = std::min(my_offset + avg_amount, n) - my_offset; - if (my_amount > 0) { - VSLStreamStatePtr stream; - vslNewStream(&stream, VSL_BRNG_MCG31, seed); - vslSkipAheadStream(stream, my_offset); - viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, stream, my_amount, - r + my_offset, p); - vslDeleteStream(&stream); - } - } -} -#endif // USE_MKL && _OPENMP - struct DropoutParam : public dmlc::Parameter { float p; int mode; @@ -109,23 +80,10 @@ void DropoutForward(const OpContext &ctx, const DropoutParam &param, Tensor out = out_data[dropout::kOut].FlatTo2D(s); if (ctx.is_train || mode_ == dropout::kAlways) { Tensor mask = out_data[dropout::kMask].FlatTo2D(s); -#if !defined(__CUDACC__) && defined(USE_MKL) && defined(_OPENMP) - DType* outptr = out.dptr_; - DType* dataptr = data.dptr_; - auto maskptr = reinterpret_cast(mask.dptr_); - int count = mask.shape_[0]*mask.shape_[1]; - bernoulli_generate(count, pkeep_, maskptr); - const float pk_1 = 1.0f / pkeep_; -#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) - for (int i = 0; i < count; ++i) { - outptr[i] = dataptr[i] * maskptr[i] * pk_1; - } -#else Random *prnd = ctx.requested[dropout::kRandom].get_random(s); mask = tcast(F( prnd->uniform(mask.shape_), pkeep_) * (1.0f / pkeep_)); Assign(out, req[dropout::kOut], data * mask); -#endif // USE_MKL && _OPENMP } else { Assign(out, req[dropout::kOut], F(data)); } @@ -143,20 +101,7 @@ void DropoutBackward(const OpContext &ctx, const DropoutParam &param, Tensor mask = out_data_mask.FlatTo2D(s); Tensor gdata = in_grad.FlatTo2D(s); if 
(ctx.is_train || mode_ == dropout::kAlways) { -#if !defined(__CUDACC__) && defined(USE_MKL) && defined(_OPENMP) - real_t pkeep_ = 1.0f - param.p; - DType* ingradptr = gdata.dptr_; - DType* outgradptr = grad.dptr_; - auto maskptr = reinterpret_cast(mask.dptr_); - int count = mask.shape_[0]*mask.shape_[1]; - const float pk_1 = 1.0f / pkeep_; -#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) - for (int i = 0; i < count; ++i) { - ingradptr[i] = outgradptr[i] * maskptr[i] * pk_1; - } -#else // USE_MKL && _OPENMP Assign(gdata, req, grad * mask); -#endif // USE_MKL && _OPENMP } else { Assign(gdata, req, F(grad)); } diff --git a/src/operator/nn/fully_connected-inl.h b/src/operator/nn/fully_connected-inl.h index 4646d3a5e199..e8e95643e647 100644 --- a/src/operator/nn/fully_connected-inl.h +++ b/src/operator/nn/fully_connected-inl.h @@ -43,6 +43,7 @@ namespace op { // These enums are only visible within this header namespace fullc { enum FullyConnectedOpInputs {kData, kWeight, kBias}; +enum FullyConnectedOpResource {kTempSpace}; enum FullyConnectedOpOutputs {kOut}; } // fullc diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index c4edf6dcab9b..70c310306e49 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -23,6 +23,8 @@ * \brief fully connect operator */ #include "./fully_connected-inl.h" +#include "./mkldnn/mkldnn_ops-inl.h" +#include "./mkldnn/mkldnn_base-inl.h" #if MXNET_USE_NNPACK == 1 #include "./nnpack/nnpack_fully_connected-inl.h" #endif // MXNET_USE_NNPACK @@ -71,6 +73,32 @@ static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs, return true; } +#if MXNET_USE_MKLDNN == 1 +void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + if (SupportMKLDNN(inputs[0])) { + MKLDNNFCForward(attrs, ctx, inputs, req, outputs); + return; + } + FallBackCompute(FullyConnectedCompute, attrs, ctx, inputs, req, outputs); +} + +void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + if (SupportMKLDNN(inputs[0])) { + MKLDNNFCBackward(attrs, ctx, inputs, req, outputs); + return; + } + FallBackCompute(FullyConnectedGradCompute, attrs, ctx, inputs, req, outputs); +} +#endif + static bool FullyConnectedType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { CHECK_GE(in_type->size(), 1U); @@ -89,6 +117,49 @@ struct FullyConnectedGrad { } }; +inline static bool FCStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const FullyConnectedParam& param = nnvm::get(attrs.parsed); + uint32_t in_expected = param.no_bias ? 2 : 3; + CHECK_EQ(in_attrs->size(), in_expected); + CHECK_EQ(out_attrs->size(), 1); + +#if MXNET_USE_MKLDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask) + *dispatch_mode = DispatchMode::kFComputeEx; + else +#endif + *dispatch_mode = DispatchMode::kFCompute; + (*out_attrs)[0] = kDefaultStorage; + return true; +} + +inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const FullyConnectedParam& param = nnvm::get(attrs.parsed); + uint32_t out_expected = param.no_bias ? 
2 : 3; + CHECK_EQ(in_attrs->size(), 3U); + CHECK_EQ(out_attrs->size(), out_expected); + +#if 0 + // TODO(zhengda) let's disable MKLDNN for FullyConnected for now. + // It seems there is a bug. + if (dev_mask == mshadow::cpu::kDevMask) + *dispatch_mode = DispatchMode::kFComputeEx; + else +#endif + *dispatch_mode = DispatchMode::kFCompute; + for (size_t i = 0; i < out_attrs->size(); i++) + (*out_attrs)[i] = kDefaultStorage; + return true; +} + DMLC_REGISTER_PARAMETER(FullyConnectedParam); NNVM_REGISTER_OP(FullyConnected) @@ -119,6 +190,7 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored. }) .set_num_outputs(1) .set_attr_parser(ParamParser) +.set_attr("FInferStorageType", FCStorageType) .set_attr("FListInputNames", [](const NodeAttrs& attrs) { const FullyConnectedParam& params = nnvm::get(attrs.parsed); if (!params.no_bias) { @@ -127,9 +199,17 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored. return std::vector{"data", "weight"}; } }) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +#endif .set_attr("FInferShape", FullyConnectedShape) .set_attr("FInferType", FullyConnectedType) .set_attr("FCompute", FullyConnectedCompute) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", FullyConnectedComputeExCPU) +#endif .set_attr("FGradient", FullyConnectedGrad{"_backward_FullyConnected"}) .add_argument("data", "NDArray-or-Symbol", "Input data.") .add_argument("weight", "NDArray-or-Symbol", "Weight matrix.") @@ -137,15 +217,25 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored. .add_arguments(FullyConnectedParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_FullyConnected) +.set_num_inputs(3) .set_num_outputs([](const NodeAttrs& attrs) { const FullyConnectedParam& params = nnvm::get(attrs.parsed); return params.no_bias ? 
2 : 3; }) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +#endif .set_attr("TIsBackward", true) .set_attr("FInplaceOption", [](const NodeAttrs& attrs){ return std::vector >{{1, 0}}; }) +.set_attr("FInferStorageType", BackwardFCStorageType) .set_attr_parser(ParamParser) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", FullyConnectedGradComputeExCPU) +#endif .set_attr("FCompute", FullyConnectedGradCompute); } // namespace op diff --git a/src/operator/nn/lrn-inl.h b/src/operator/nn/lrn-inl.h index 2dfecea0bde1..fdae1eca0aef 100644 --- a/src/operator/nn/lrn-inl.h +++ b/src/operator/nn/lrn-inl.h @@ -23,8 +23,8 @@ * \brief * \author Bing Xu */ -#ifndef MXNET_OPERATOR_LRN_INL_H_ -#define MXNET_OPERATOR_LRN_INL_H_ +#ifndef MXNET_OPERATOR_NN_LRN_INL_H_ +#define MXNET_OPERATOR_NN_LRN_INL_H_ #include #include #include @@ -124,4 +124,4 @@ void LRNGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_LRN_INL_H_ +#endif // MXNET_OPERATOR_NN_LRN_INL_H_ diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc index 21bf457512f2..00cac28f2484 100644 --- a/src/operator/nn/lrn.cc +++ b/src/operator/nn/lrn.cc @@ -25,8 +25,8 @@ */ #include "./lrn-inl.h" -#if MXNET_USE_CUDNN == 1 -#include "./cudnn_lrn-inl.h" +#if MXNET_USE_MKLDNN == 1 +#include "./mkldnn/mkldnn_lrn-inl.h" #endif namespace mxnet { @@ -71,15 +71,82 @@ static bool LRNType(const nnvm::NodeAttrs& attrs, struct LRNGrad { const char *op_name; std::vector operator()(const nnvm::NodePtr& n, - const std::vector& ograds) const { + const std::vector& ograds) const { std::vector heads; - heads.push_back(ograds[0]); // out_grad + heads.push_back(ograds[0]); // out_grad heads.push_back(n->inputs[lrn_enum::kData]); heads.emplace_back(nnvm::NodeEntry{n, lrn_enum::kTmpNorm, 0}); return MakeGradNode(op_name, n, heads, n->attrs.dict); } }; +inline static bool LRNForwardInferStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + *dispatch_mode = DispatchMode::kFCompute; +#if MXNET_USE_MKLDNN == 1 + CHECK(!in_attrs->empty()); + if (dev_mask == mshadow::cpu::kDevMask) { + *dispatch_mode = DispatchMode::kFComputeEx; + } +#endif + for (size_t i = 0; i < out_attrs->size(); i++) + (*out_attrs)[i] = kDefaultStorage; + return true; +} + +inline static bool LRNBackwardInferStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + *dispatch_mode = DispatchMode::kFCompute; +#if MXNET_USE_MKLDNN == 1 + CHECK(!in_attrs->empty()); + if (dev_mask == mshadow::cpu::kDevMask) { + *dispatch_mode = DispatchMode::kFComputeEx; + } +#endif + for (size_t i = 0; i < out_attrs->size(); i++) + (*out_attrs)[i] = kDefaultStorage; + return true; +} + +#if MXNET_USE_MKLDNN == 1 +void LRNComputeExCPU(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const LRNParam ¶m = nnvm::get(attrs.parsed); + if (SupportMKLDNN(inputs[0])) { + MKLDNNLRN_Forward(ctx, param, inputs[0], req[0], outputs[0]); + return; + } + FallBackCompute(LRNCompute, attrs, ctx, inputs, req, outputs); +} + +void LRNGradComputeExCPU(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const 
LRNParam &param = nnvm::get(attrs.parsed); + const NDArray &out_grad = inputs[0]; + const NDArray &in_data = inputs[1]; + const NDArray &in_grad = outputs[0]; + + if (SupportMKLDNN(inputs[0])) { + MKLDNNLRN_Backward(ctx, param, out_grad, in_data, + req[0], in_grad); + return; + } + FallBackCompute(LRNGradCompute, attrs, ctx, inputs, req, outputs); +} +#endif + DMLC_REGISTER_PARAMETER(LRNParam); NNVM_REGISTER_OP(LRN) @@ -106,7 +173,11 @@ number of kernels in the layer. .set_attr_parser(ParamParser) .set_attr("FInferShape", LRNShape) .set_attr("FInferType", LRNType) +.set_attr("FInferStorageType", LRNForwardInferStorageType) .set_attr("FCompute", LRNCompute) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", LRNComputeExCPU) +#endif .set_attr("FGradient", LRNGrad{"_backward_LRN"}) .add_argument("data", "NDArray-or-Symbol", "Input data to LRN") .add_arguments(LRNParam::__FIELDS__()); @@ -114,7 +185,11 @@ number of kernels in the layer. NNVM_REGISTER_OP(_backward_LRN) .set_num_outputs(1) .set_attr_parser(ParamParser) +.set_attr("FInferStorageType", LRNBackwardInferStorageType) .set_attr("TIsBackward", true) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", LRNGradComputeExCPU) +#endif .set_attr("FCompute", LRNGradCompute); } // namespace op diff --git a/src/operator/nn/lrn.cu b/src/operator/nn/lrn.cu index 83dd1d0322ea..4c31ca96025c 100644 --- a/src/operator/nn/lrn.cu +++ b/src/operator/nn/lrn.cu @@ -25,9 +25,6 @@ */ #include "./lrn-inl.h" -#if MXNET_USE_CUDNN == 1 -#include "./cudnn_lrn-inl.h" -#endif namespace mxnet { namespace op { diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc new file mode 100644 index 000000000000..71fdf4ca585b --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_act.cc @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_act.cc + * \brief + * \author Da Zheng +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "../../operator_common.h" +#include "../activation-inl.h" +#include "./mkldnn_base-inl.h" + +#if MXNET_USE_MKLDNN == 1 + +#include + +namespace mxnet { +namespace op { + +bool SupportMKLDNNAct(const ActivationParam& param) { + // We only enable ReLU for now. It seems other activations have some precision + // problems. 
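+ // (GetMKLDNNActAlgo below already maps the other activation types to MKLDNN + // eltwise algorithms, but they stay disabled here until the accuracy issues + // are understood; see the #if 0 branch.)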
+ return param.act_type == activation::kReLU; +#if 0 + || param.act_type == activation::kSigmoid + || param.act_type == activation::kSoftReLU; +#endif +} + +static inline mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) { + switch (param.act_type) { + case activation::kReLU: + return mkldnn::algorithm::eltwise_relu; + case activation::kSigmoid: + return mkldnn::algorithm::eltwise_logistic; + case activation::kTanh: + return mkldnn::algorithm::eltwise_tanh; + case activation::kSoftReLU: + return mkldnn::algorithm::eltwise_soft_relu; + default: + LOG(FATAL) << "unknown activation type"; + return mkldnn::algorithm::eltwise_relu; + } +} + +typedef std::shared_ptr mkldnn_act_pdesc_ptr; + +static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl( + const ActivationParam& param, bool is_train, + const mkldnn::memory &input_mem, int dtype) { + mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc(); + mkldnn::memory::desc data_md = data_mpd.desc(); + auto cpu_engine = data_mpd.get_engine(); + + auto alg = GetMKLDNNActAlgo(param); + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + DType alpha = 0; + mkldnn::eltwise_forward::desc desc = is_train + ? mkldnn::eltwise_forward::desc(mkldnn::prop_kind::forward_training, + alg, data_md, alpha) + : mkldnn::eltwise_forward::desc(mkldnn::prop_kind::forward_scoring, + alg, data_md, alpha); + return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine); + }); + LOG(INFO) << "Unsupported data type for MKLDNN activation"; + mkldnn::eltwise_forward::desc desc = mkldnn::eltwise_forward::desc( + mkldnn::prop_kind::forward_training, alg, data_md, 0.0); + return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine); +} + +typedef MKLDNNParamOpSign MKLDNNActSignature; + +class MKLDNNActForward { + std::shared_ptr fwd; + std::shared_ptr data; + std::shared_ptr out; + + public: + const mkldnn::eltwise_forward::primitive_desc fwd_pd; + + MKLDNNActForward(const ActivationParam& param, bool is_train, + const NDArray &data, const mkldnn::memory &mem): fwd_pd( + GetActFwdDescImpl(param, is_train, mem, data.dtype())) { + } + + void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) { + if (this->data == nullptr) + this->data = std::shared_ptr(new mkldnn::memory( + data.get_primitive_desc(), data.get_data_handle())); + else + this->data->set_data_handle(data.get_data_handle()); + + CHECK(fwd_pd.dst_primitive_desc() == output.get_primitive_desc()); + if (this->out == nullptr) + this->out = std::shared_ptr(new mkldnn::memory( + fwd_pd.dst_primitive_desc(), output.get_data_handle())); + else + this->out->set_data_handle(output.get_data_handle()); + + if (this->fwd == nullptr) { + this->fwd = std::shared_ptr( + new mkldnn::eltwise_forward(fwd_pd, mkldnn::primitive::at(*this->data), + *this->out)); + } + } + + const mkldnn::eltwise_forward &GetFwd() const { + return *fwd; + } +}; + +static MKLDNNActForward &GetActForward(const ActivationParam& param, + const OpContext &ctx, const NDArray &in_data, + const mkldnn::memory &in_mem) { + static thread_local std::unordered_map fwds; + MKLDNNActSignature key(param); + key.AddSign(ctx.is_train); + key.AddSign(param.act_type); + key.AddSign(in_data); + + auto it = fwds.find(key); + if (it == fwds.end()) { + MKLDNNActForward fwd(param, ctx.is_train, in_data, in_mem); + auto ins_ret = fwds.insert(std::pair( + key, fwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + +void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + 
const NDArray &in_data, const OpReqType &req, + const NDArray &out_data) { + const ActivationParam& param = nnvm::get(attrs.parsed); + auto input_mem = in_data.GetMKLDNNData(); + MKLDNNActForward &fwd = GetActForward(param, ctx, in_data, *input_mem); + auto out_mem = const_cast(out_data).CreateMKLDNNData( + fwd.fwd_pd.dst_primitive_desc()); + fwd.SetNewMem(*input_mem, *out_mem); + MKLDNNStream *stream = MKLDNNStream::Get(); + stream->RegisterPrim(fwd.GetFwd()); + stream->Submit(); +} + +void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const NDArray &out_grad, const NDArray &in_data, + const OpReqType &req, const NDArray &in_grad) { + if (req == kNullOp) { + return; + } + + const ActivationParam& param = nnvm::get(attrs.parsed); + TmpMemMgr::Get()->Init(ctx.requested[activation::kTempSpace]); + auto diff_dst_memory = out_grad.GetMKLDNNData(); + auto input_mem = in_data.GetMKLDNNData(); + // We need to make sure the two inputs to eltwise_backward have the same memory + // descriptor. Otherwise, performance will suffer. + if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc()) + input_mem = in_data.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc()); + mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc(); + mkldnn::memory::desc data_md = data_mpd.desc(); + mkldnn::memory::desc diff_md = diff_dst_memory->get_primitive_desc().desc(); + auto cpu_engine = data_mpd.get_engine(); + + MKLDNNStream *stream = MKLDNNStream::Get(); + auto alg = GetMKLDNNActAlgo(param); + mkldnn_output_t diff_src_memory; + + MSHADOW_REAL_TYPE_SWITCH(in_data.dtype(), DType, { + DType alpha = 0; + mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training, + alg, data_md, alpha); + mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine); + mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha); + mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine, + fw_pdesc); + + diff_src_memory = CreateMKLDNNMem(in_grad, + bw_pdesc.diff_src_primitive_desc(), req); + stream->RegisterPrim(mkldnn::eltwise_backward(bw_pdesc, *input_mem, + *diff_dst_memory, + *diff_src_memory.second)); + }); + CommitOutput(in_grad, diff_src_memory); + stream->Submit(); +} + +} // namespace op +} // namespace mxnet + +#endif diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h new file mode 100644 index 000000000000..7fcf1c028d37 --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +/******************************************************************************* +* Copyright 2016-2017 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +* \file mkldnn_base-inl.h +* \brief +* \author young.jin.kim@intel.com +* ashok.emani@intel.com +* deepthi.karkada@intel.com +* louis.feng@intel.com +* adam.d.straw@intel.com +* zhengda1936@gmail.com +* +*******************************************************************************/ + +#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BASE_INL_H_ +#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BASE_INL_H_ + +#if MXNET_USE_MKLDNN == 1 +#include +#include +#include +#include +#include +#include +#include +#include "mkldnn.hpp" +#include "mxnet/ndarray.h" +#include "mxnet/resource.h" +#include "mxnet/op_attr_types.h" +using namespace mkldnn; +namespace mxnet { +extern bool EnableMkldnnWarnGenerated(); +// ===== CpuEngine ======================================= +// cpu_engine singleton +class CpuEngine { + public: + static CpuEngine *Get() { + // It's thread-safe in C++11. + static thread_local CpuEngine myInstance; + return &myInstance; + } + CpuEngine(CpuEngine const &) = delete; // Copy construct + CpuEngine(CpuEngine &&) = delete; // Move construct + CpuEngine &operator=(CpuEngine const &) = delete; // Copy assign + CpuEngine &operator=(CpuEngine &&) = delete; // Move assign + + mkldnn::engine &get_engine() { return _cpu_engine; } + + protected: + CpuEngine() : _cpu_engine(mkldnn::engine::cpu, 0) {} + ~CpuEngine() {} + + private: + mkldnn::engine _cpu_engine; +}; + +// type enumerator +template +struct data_type_enum {}; + +template <> +struct data_type_enum { + enum { type = mkldnn::memory::data_type::f32 }; +}; + +template <> +struct data_type_enum { + enum { type = mkldnn::memory::data_type::s32 }; +}; + +template <> +struct data_type_enum { + enum { type = mkldnn::memory::data_type::s16 }; +}; + +template <> +struct data_type_enum { + enum { type = mkldnn::memory::data_type::s8 }; +}; + +template <> +struct data_type_enum { + enum { type = mkldnn::memory::data_type::u8 }; +}; + +static inline bool SupportMKLDNNArray(int dtype, const TShape &shape) { + int ndim = shape.ndim(); + bool support = ndim == 1 || ndim == 2 || ndim == 4; + support = support && (dtype == mshadow::kFloat32 || dtype == mshadow::kInt32 + || dtype == mshadow::kInt8 || dtype == mshadow::kUint8); + return support; +} + +static inline bool SupportStorageMKLDNN(int stype) { + return stype == kMKLDNNStorage || stype == kDefaultStorage; +} + +static inline bool SupportMKLDNN(int dtype, const TShape &shape) { + int ndim = shape.ndim(); + return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4); +} + +static inline bool SupportMKLDNN(const NDArray &input) { + return SupportMKLDNN(input.dtype(), input.shape()) + && SupportStorageMKLDNN(input.storage_type()); +} + +static inline bool SupportMKLDNNConv(const NDArray &input) { + return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4; +} + +namespace op { +struct ActivationParam; +bool 
SupportMKLDNNAct(const op::ActivationParam& param); +} + +static int GetTypeSize(int dtype) { + int size = -1; + MSHADOW_TYPE_SWITCH(dtype, DType, { + size = sizeof(DType); + }); + return size; +} + +static inline size_t GetArraySize(const NDArray &arr) { + return arr.shape().Size() * GetTypeSize(arr.dtype()); +} + +static inline mkldnn::memory::data_type get_mkldnn_type(int dtype) { + switch (dtype) { + case mshadow::kFloat32: + return mkldnn::memory::data_type::f32; + case mshadow::kInt32: + return mkldnn::memory::data_type::s32; + case mshadow::kInt8: + return mkldnn::memory::data_type::s8; + case mshadow::kUint8: + return mkldnn::memory::data_type::u8; + default: + LOG(FATAL) << "unknown type for MKLDNN"; + return mkldnn::memory::data_type::data_undef; + } +} + +inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr, int ndim) { + mkldnn::memory::dims dims(ndim); + for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i]; + return mkldnn::memory::desc{dims, get_mkldnn_type(arr.dtype()), + mkldnn::memory::format::any}; +} + +inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr) { + return GetMemDesc(arr, arr.shape().ndim()); +} + +inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr, + int num_groups) { + if (num_groups == 1) { + return GetMemDesc(arr); + } else { + CHECK_EQ(arr.shape().ndim(), 4U); + mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups, + static_cast(arr.shape()[0] / num_groups), + static_cast(arr.shape()[1]), + static_cast(arr.shape()[2]), + static_cast(arr.shape()[3])}; + return mkldnn::memory::desc{tz, get_mkldnn_type(arr.dtype()), + mkldnn::memory::format::any}; + } +} + +typedef std::shared_ptr mkldnn_mem_ptr; +typedef std::shared_ptr mkldnn_mem_const_ptr; + +class TmpMemMgr { + // This points to the memory buffer where we can allocate temp memory. + char *curr_mem; + // The total size of the temp memory. + size_t mem_size; + // This contains the current available memory size. + size_t curr_size; + // This estimates the required temp memory size in an operator. + size_t est_size; + const size_t alignment = 4096; + + public: + static TmpMemMgr *Get() { + static thread_local TmpMemMgr mgr; + return &mgr; + } + + TmpMemMgr() { + Reset(); + est_size = 0; + mem_size = 0; + } + + void Reset() { + curr_mem = nullptr; + curr_size = 0; + // We don't reset est_size and mem_size because est_size contains the + // estimated temp memory size from the last run and mem_size contains the + // memory size allocated in the last run. + } + + void Init(const Resource &r) { + // If we estimated last time that we needed more memory, we should use the + // larger memory size. + mem_size = std::max(mem_size, est_size); + if (mem_size > 0) { + // Let's allocate some extra memory. If we don't use some of it all the time, + // the OS won't physically allocate pages for it anyway. + this->curr_size = mem_size * 2; + this->curr_mem = static_cast(r.get_host_space_internal(this->curr_size)); + } + // reset est_size, so we can start to estimate the temp memory size. + this->est_size = 0; + } + + mkldnn::memory *Alloc(const mkldnn::memory::primitive_desc &pd); +}; + +class MKLDNNStream { + std::vector net; + // Here we hold all memory related to the operators in the stream. 
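+ // Everything registered via RegisterMem() stays alive here until Submit() + // has executed the queued primitives and cleared the stream.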
+ std::vector > mem_holder; + + public: + static MKLDNNStream *Get() { + static thread_local MKLDNNStream stream; + return &stream; + } + + void RegisterPrim(const mkldnn::primitive &prim) { net.push_back(prim); } + + void RegisterMem(std::shared_ptr mem) { + mem_holder.push_back(mem); + } + + void Submit() { + if (!net.empty()) + mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait(); + net.clear(); + mem_holder.clear(); + TmpMemMgr::Get()->Reset(); + } +}; + +class MKLDNNOpSignature { + std::vector eles; + uint64_t hash; + + public: + MKLDNNOpSignature() { + hash = 0; + } + + explicit MKLDNNOpSignature(uint64_t hash) { + this->hash = hash; + } + + /* + * We provide different methods to add a signature to an op. + * For operations such as convolution and fully connected, which determine + * the optimal data layout for the op, we only need to use the shape and data + * type to sign the op. For other operations, such as activation, which use + * whatever layout is in the input array, we have to use the shape, the data + * type and the layout to sign the op. + */ + + void AddSign(const mkldnn::memory &mem) { + auto desc = mem.get_primitive_desc().desc(); + hash = hash * 2 + desc.data.format; + eles.push_back(desc.data.format); + hash = hash * 2 + desc.data.data_type; + eles.push_back(desc.data.data_type); + for (int i = 0; i < desc.data.ndims; i++) { + hash = hash * 2 + desc.data.dims[i]; + eles.push_back(desc.data.dims[i]); + } + } + + void AddSign(const std::vector &arrs) { + for (auto &arr : arrs) { + AddSign(arr); + } + } + + void AddSign(const NDArray &arr) { + if (arr.IsMKLDNN()) { + AddSign(*(arr.GetMKLDNNData())); + } else { + hash = hash * 2 + arr.dtype(); + eles.push_back(arr.dtype()); + AddSign(arr.shape()); + } + } + + void AddSign(const TShape &shape) { + for (size_t i = 0; i < shape.ndim(); i++) { + hash = hash * 2 + shape[i]; + eles.push_back(shape[i]); + } + } + + void AddSign(int val) { + hash = hash * 2 + val; + eles.push_back(val); + } + + bool operator==(const MKLDNNOpSignature &sign) const { + if (hash != sign.hash) + return false; + if (eles.size() != sign.eles.size()) + return false; + for (size_t i = 0; i < eles.size(); i++) + if (eles[i] != sign.eles[i]) + return false; + return true; + } + + uint64_t GetHash() const { + return hash; + } +}; + +struct MKLDNNOpHash { + size_t operator()(const MKLDNNOpSignature &sign) const { + return sign.GetHash(); + } +}; + +template +class MKLDNNParamOpSign: public MKLDNNOpSignature { + const ParamType param; + + public: + explicit MKLDNNParamOpSign(const ParamType &_param): MKLDNNOpSignature( + _param.GetHash()), param(_param) { + } + + bool operator==(const MKLDNNParamOpSign &sign) const { + const MKLDNNOpSignature &this_upper = *this; + const MKLDNNOpSignature &other_upper = sign; + return this_upper == other_upper && param == sign.param; + } +}; + +enum OutDataOp { + Noop, + CopyBack, + AddBack, +}; + +typedef std::pair mkldnn_output_t; + +/* + * These two functions try to create MKLDNN memory in an NDArray based on `req'. + * The difference is that the first function can create MKLDNN memory with + * special layouts in an NDArray, while the second one can only create MKLDNN + * memory with default layouts. + * If these two functions are used, we have to call CommitOutput to write + * the output back to the output NDArray. 
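+ * A typical call sequence (a sketch of the pattern the ops in this patch + * follow; `pd' and `prim' stand in for an op's primitive_desc and primitive): + *   mkldnn_output_t out_mem = CreateMKLDNNMem(out_arr, pd, req); + *   MKLDNNStream::Get()->RegisterPrim(prim(..., *out_mem.second)); + *   CommitOutput(out_arr, out_mem); + *   MKLDNNStream::Get()->Submit();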
+ */ +mkldnn_output_t CreateMKLDNNMem(const NDArray &arr, + const mkldnn::memory::primitive_desc &desc, + OpReqType req); +mkldnn_output_t CreateMKLDNNWeightGrad(const NDArray &arr, + const mkldnn::memory::primitive_desc &desc, + OpReqType req); +/* This function has to be used with one of the functions above. */ +void CommitOutput(const NDArray &arr, const mkldnn_output_t &res); + +static inline void InvalidateOutputs(const std::vector &arrs, + const std::vector &reqs) { + for (size_t i = 0; i < arrs.size(); i++) { + if (reqs[i] == kWriteTo || reqs[i] == kNullOp) { + const_cast(arrs[i]).InvalidateData(); + } + } +} + +const mkldnn::memory *GetWeights(const NDArray &arr, + const mkldnn::memory::primitive_desc &target_pd, + int num_groups); + +mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc); +mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd, + mkldnn_memory_format_t format); + +void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); +} // namespace mxnet +#endif +#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BASE_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc new file mode 100644 index 000000000000..947bc5979d85 --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#if MXNET_USE_MKLDNN == 1 + +#include "./mkldnn_base-inl.h" +#include "./mkldnn_ops-inl.h" + +namespace mxnet { + +mkldnn::memory *TmpMemMgr::Alloc(const mkldnn::memory::primitive_desc &pd) { + // We need to include the size of the memory used for alignment. + this->est_size += pd.get_size() + alignment; + void *this_mem = this->curr_mem; + void *mem = std::align(alignment, pd.get_size(), this_mem, this->curr_size); + if (mem) { + // The memory is allocated from the temporary memory space in the + // operator. It'll only become invalid after we exit from the operator. 
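+ // (std::align has advanced this_mem to the aligned address and shrunk + // curr_size by the padding it skipped, so the bookkeeping below only needs + // to advance curr_mem by pd.get_size().)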
+ mkldnn_mem_ptr ret(new mkldnn::memory(pd, this_mem)); + MKLDNNStream::Get()->RegisterMem(ret); + CHECK_EQ(this_mem, mem); + this->curr_size -= pd.get_size(); + this->curr_mem = static_cast(this_mem) + pd.get_size(); + return ret.get(); + } else { + LOG(WARNING) << "Allocate " << pd.get_size() + << " bytes with malloc directly"; + mkldnn_mem_ptr ret(new mkldnn::memory(pd)); + MKLDNNStream::Get()->RegisterMem(ret); + return ret.get(); + } +} + +mkldnn_output_t CreateMKLDNNMem(const NDArray &arr, + const mkldnn::memory::primitive_desc &desc, + OpReqType req) { + if (kAddTo == req) { + auto tmp = TmpMemMgr::Get()->Alloc(desc); + return mkldnn_output_t(OutDataOp::AddBack, tmp); + } else if (kWriteInplace == req) { + // MKLDNN ops may not support the case that the input and the output use + // the same memory. Let's use an extra copy to make sure it always works. + auto tmp = TmpMemMgr::Get()->Alloc(desc); + return mkldnn_output_t(OutDataOp::CopyBack, tmp); + } else { + mkldnn::memory *mem = const_cast(arr).CreateMKLDNNData(desc); + if (mem == nullptr) { + auto tmp = TmpMemMgr::Get()->Alloc(desc); + return mkldnn_output_t(OutDataOp::CopyBack, tmp); + } else { + return mkldnn_output_t(OutDataOp::Noop, mem); + } + } +} + +mkldnn_output_t CreateMKLDNNWeightGrad(const NDArray &arr, + const mkldnn::memory::primitive_desc &desc, + OpReqType req) { + if (kAddTo == req) { + auto tmp = TmpMemMgr::Get()->Alloc(desc); + return mkldnn_output_t(OutDataOp::AddBack, tmp); + } else if (kWriteInplace == req) { + auto tmp = TmpMemMgr::Get()->Alloc(desc); + return mkldnn_output_t(OutDataOp::CopyBack, tmp); + } else { + auto _desc = desc; + auto def_format = GetDefaultFormat(_desc.desc()); + mkldnn::memory *mem = nullptr; + if (def_format == _desc.desc().data.format) { + mem = const_cast(arr).CreateMKLDNNData(desc); + } + if (mem == nullptr) { + auto tmp = TmpMemMgr::Get()->Alloc(desc); + return mkldnn_output_t(OutDataOp::CopyBack, tmp); + } else { + return mkldnn_output_t(OutDataOp::Noop, mem); + } + } +} + +void CommitOutput(const NDArray &arr, const mkldnn_output_t &res) { + if (res.first == CopyBack) { + const_cast(arr).CopyFrom(*res.second); + } else if (res.first == AddBack) { + auto mem = arr.GetMKLDNNData(res.second->get_primitive_desc()); + CHECK(mem != nullptr); + // We have to allocate new memory for the sum result. + auto sum_res = TmpMemMgr::Get()->Alloc( + res.second->get_primitive_desc()); + op::Sum(*res.second, *mem, *sum_res); + const_cast(arr).CopyFrom(*sum_res); + } +} + +const mkldnn::memory *GetWeights(const NDArray &arr, + const mkldnn::memory::primitive_desc &target_pd, + int num_groups) { + const mkldnn::memory *mem = arr.GetMKLDNNData(target_pd); + // If the weight array already uses the target layout, simply return it + // directly. 
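+ // (Otherwise we describe the array in its plain layout below -- oi, oihw or + // goihw, depending on ndim and num_groups -- and register a reorder into + // target_pd.)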
+ if (mem) + return mem; + + mkldnn::memory::data_type type = get_mkldnn_type(arr.dtype()); + auto engine = CpuEngine::Get()->get_engine(); + if (arr.shape().ndim() == 2) { + mkldnn::memory::dims tz = mkldnn::memory::dims{ + static_cast(arr.shape()[0]), static_cast(arr.shape()[1])}; + mkldnn::memory::desc md = + mkldnn::memory::desc{tz, type, mkldnn::memory::format::oi}; + mkldnn::memory::primitive_desc pd = + mkldnn::memory::primitive_desc{md, engine}; + mem = arr.GetMKLDNNData(pd); + } else if (arr.shape().ndim() == 4 && num_groups == 1) { + mkldnn::memory::dims tz = mkldnn::memory::dims{ + static_cast(arr.shape()[0]), static_cast(arr.shape()[1]), + static_cast(arr.shape()[2]), static_cast(arr.shape()[3])}; + mkldnn::memory::desc md = + mkldnn::memory::desc{tz, type, mkldnn::memory::format::oihw}; + mkldnn::memory::primitive_desc pd = + mkldnn::memory::primitive_desc{md, engine}; + mem = arr.GetMKLDNNData(pd); + } else if (arr.shape().ndim() == 4) { + mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups, + static_cast(arr.shape()[0] / num_groups), + static_cast(arr.shape()[1]), + static_cast(arr.shape()[2]), + static_cast(arr.shape()[3])}; + mkldnn::memory::desc md = + mkldnn::memory::desc{tz, type, mkldnn::memory::format::goihw}; + mkldnn::memory::primitive_desc pd = + mkldnn::memory::primitive_desc{md, engine}; + mem = arr.GetMKLDNNData(pd); + } else { + LOG(FATAL) << "The weight array has an unsupported number of dimensions"; + return nullptr; + } + if (mem == nullptr) + mem = arr.GetMKLDNNDataReorder(target_pd); + if (mem->get_primitive_desc() == target_pd) return mem; + + auto ret = TmpMemMgr::Get()->Alloc(target_pd); + MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(*mem, *ret)); + return ret; +} + +mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc) { + if (desc.data.ndims == 1) { + return desc.data.format; + } else if (desc.data.ndims == 2) { + if (desc.data.format == mkldnn_io) + return mkldnn_oi; + else + return desc.data.format; + } else if (desc.data.ndims == 4) { + switch (desc.data.format) { + case mkldnn_nchw: + case mkldnn_nhwc: + case mkldnn_chwn: + case mkldnn_nChw8c: + case mkldnn_nChw16c: + return mkldnn_nchw; + case mkldnn_oihw: + case mkldnn_ihwo: + case mkldnn_hwio: + case mkldnn_OIhw8i8o: + case mkldnn_OIhw16i16o: + case mkldnn_OIhw8i16o2i: + case mkldnn_OIhw8o16i2o: + case mkldnn_OIhw8o8i: + case mkldnn_OIhw16o16i: + case mkldnn_IOhw16o16i: + case mkldnn_Oihw8o: + case mkldnn_Oihw16o: + case mkldnn_Ohwi8o: + case mkldnn_Ohwi16o: + case mkldnn_OhIw16o4i: + return mkldnn_oihw; + default: + LOG(FATAL) << "Unknown MKLDNN format for 4 dimensions: " << desc.data.format; + return mkldnn_format_undef; + } + } else if (desc.data.ndims == 5) { + switch (desc.data.format) { + case mkldnn_goihw: + case mkldnn_gOIhw8i8o: + case mkldnn_gOIhw16i16o: + case mkldnn_gOIhw8i16o2i: + case mkldnn_gOIhw8o16i2o: + case mkldnn_gOIhw8o8i: + case mkldnn_gOIhw16o16i: + case mkldnn_gIOhw16o16i: + case mkldnn_gOihw8o: + case mkldnn_gOihw16o: + case mkldnn_gOhwi8o: + case mkldnn_gOhwi16o: + case mkldnn_gOhIw16o4i: + return mkldnn_goihw; + default: + LOG(FATAL) << "Unknown MKLDNN format for 5 dimensions: " << desc.data.format; + return mkldnn_format_undef; + } + } else { + LOG(FATAL) << "Unsupported dimensions: " << desc.data.ndims; + return mkldnn_format_undef; + } +} + +mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd, + mkldnn_memory_format_t format) { + mkldnn::memory::dims dims(pd.desc().data.ndims); + for (size_t i = 0; i < 
dims.size(); i++) + dims[i] = pd.desc().data.dims[i]; + mkldnn::memory::format cpp_format = static_cast(format); + mkldnn::memory::data_type cpp_type = static_cast( + pd.desc().data.data_type); + mkldnn::memory::desc data_md(dims, cpp_type, cpp_format); + return mkldnn::memory::primitive_desc(data_md, pd.get_engine()); +} + +void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + // TODO(zhengda) We should buffer the NDArrays. + std::vector in_bufs; + std::vector in_blobs(inputs.size()); + for (size_t i = 0; i < in_blobs.size(); i++) { + in_blobs[i] = inputs[i].data(); + } + std::vector out_blobs(outputs.size()); + for (size_t i = 0; i < out_blobs.size(); i++) { + if (req[i] == kWriteTo) + const_cast(outputs[i]).InvalidateData(); + CHECK(outputs[i].IsDefault()); + out_blobs[i] = outputs[i].data(); + } + fn(attrs, ctx, in_blobs, req, out_blobs); +} + +} // namespace mxnet + +#endif diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h new file mode 100644 index 000000000000..a72e7e2e1ba6 --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -0,0 +1,431 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file mkldnn_batch_norm.cc + * \brief + * \author Tao Lv +*/ + +#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_NORM_INL_H_ +#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_NORM_INL_H_ + +#if MXNET_USE_MKLDNN == 1 +#include +#include +#include +#include "../batch_norm-inl.h" +#include "./mkldnn_ops-inl.h" +#include "./mkldnn_base-inl.h" + +#define VARIANCE_TO_INVSTD(__var$, __eps$) (1.0/sqrt((__var$) + DType(__eps$))) +#define INVSTD_TO_VARIANCE(__invstd$, __eps$) ((1.0 / ((__invstd$) * (__invstd$))) - (__eps$)) +namespace mxnet { +namespace op { + +typedef mkldnn::batch_normalization_forward::primitive_desc t_bn_f_pdesc; +typedef mkldnn::batch_normalization_forward::desc t_bn_f_desc; +typedef mkldnn::batch_normalization_backward::primitive_desc t_bn_b_pdesc; +typedef mkldnn::batch_normalization_backward::desc t_bn_b_desc; + +using mkldnn::use_global_stats; +using mkldnn::use_scale_shift; +using mkldnn::forward_training; +using mkldnn::forward_inference; + +inline static unsigned _GetFlags(const std::vector &in_data, + const std::vector &aux_states, + const BatchNormParam &param, bool is_train) { + unsigned flags = 0U; + if (in_data.size() == 3U) { + flags |= use_scale_shift; + } + + // aux_states[0]: inMean + // aux_states[1]: inVariance + if (aux_states.size() == 2U && !is_train) { + flags |= use_global_stats; + } + return flags; +} + +template +inline static t_bn_f_pdesc _GetFwd(const mkldnn::memory &data_mem, + bool is_train, + DType eps, + unsigned flags) { + auto data_mpd = data_mem.get_primitive_desc(); + auto data_md = data_mpd.desc(); + auto engine = CpuEngine::Get()->get_engine(); + + if (is_train) { + t_bn_f_desc bnFwd_desc(forward_training, data_md, eps, flags); + return t_bn_f_pdesc(bnFwd_desc, engine); + } else { + t_bn_f_desc bnFwd_desc(forward_inference, data_md, eps, flags); + return t_bn_f_pdesc(bnFwd_desc, engine); + } +} + +template +inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem, + const mkldnn::memory &diff_mem, + DType eps, + unsigned flags) { + auto data_mpd = data_mem.get_primitive_desc(); + auto data_md = data_mpd.desc(); + auto diff_mpd = diff_mem.get_primitive_desc(); + auto diff_md = diff_mpd.desc(); + auto engine = CpuEngine::Get()->get_engine(); + + t_bn_b_desc bnBwd_desc(mkldnn::prop_kind::backward, diff_md, data_md, eps, flags); + return t_bn_b_pdesc(bnBwd_desc, engine, _GetFwd(data_mem, true, eps, flags)); +} + +typedef MKLDNNParamOpSign MKLDNNBNSignature; + +class MKLDNNBNForward { + std::shared_ptr data_m; + std::shared_ptr weight_m; + std::shared_ptr out_m; + std::shared_ptr mean_m; + std::shared_ptr var_m; + std::shared_ptr fwd; + bool is_train; + t_bn_f_pdesc pd; + + public: + MKLDNNBNForward(const t_bn_f_pdesc &_pd, bool is_train): pd(_pd) { + weight_m.reset(new mkldnn::memory(pd.weights_primitive_desc())); + this->is_train = is_train; + } + + const mkldnn::memory &GetWeight() const { + return *weight_m; + } + + const t_bn_f_pdesc &GetPd() const { + return pd; + } + + const mkldnn::memory &GetMean() const { + return *mean_m; + } + + const mkldnn::memory &GetVar() const { + return *var_m; + } + + void SetDataHandle(const NDArray &data, const NDArray &mean, + const NDArray &var, const mkldnn::memory &out) { + auto _data = data.GetMKLDNNData(); + if (data_m) { + data_m->set_data_handle(_data->get_data_handle()); + } else { + data_m.reset(new mkldnn::memory(_data->get_primitive_desc(), + _data->get_data_handle())); + } + if (out_m) { + out_m->set_data_handle(out.get_data_handle()); + } else { + out_m.reset(new 
mkldnn::memory(out.get_primitive_desc(), + out.get_data_handle())); + } + auto mean_ptr = mean.data().dptr_; + if (mean_m) { + mean_m->set_data_handle(mean_ptr); + } else { + mean_m.reset(new mkldnn::memory(pd.mean_primitive_desc(), + mean_ptr)); + } + auto var_ptr = var.data().dptr_; + if (var_m) { + var_m->set_data_handle(var_ptr); + } else { + var_m.reset(new mkldnn::memory(pd.variance_primitive_desc(), + var_ptr)); + } + + if (fwd == nullptr) { + if (!is_train) + fwd.reset(new mkldnn::batch_normalization_forward( + pd, *data_m, mkldnn::primitive::at(*mean_m), + mkldnn::primitive::at(*var_m), *weight_m, *out_m)); + else + fwd.reset(new mkldnn::batch_normalization_forward( + pd, mkldnn::primitive::at(*data_m), + mkldnn::primitive::at(*weight_m), *out_m, + *mean_m, *var_m)); + } + } + + const mkldnn::batch_normalization_forward &GetFwd() const { + return *fwd; + } +}; + +template +static MKLDNNBNForward &GetBNForward(const BatchNormParam& param, + const OpContext &ctx, const NDArray &in_data, + unsigned flags) { + static thread_local std::unordered_map fwds; + MKLDNNBNSignature key(param); + key.AddSign(ctx.is_train); + key.AddSign(in_data); + + auto it = fwds.find(key); + if (it == fwds.end()) { + auto fwd_pd = _GetFwd(*in_data.GetMKLDNNData(), ctx.is_train, + (DType) param.eps, flags); + MKLDNNBNForward fwd(fwd_pd, ctx.is_train); + auto ins_ret = fwds.insert(std::pair( + key, fwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + +template +void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_states) { + TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]); + unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train); + const NDArray &data = in_data[batchnorm::kData]; + + auto &fwd = GetBNForward(param, ctx, data, flags); + const NDArray &out = out_data[batchnorm::kOut]; + + // for output memory + auto out_mem = const_cast(out).CreateMKLDNNData(fwd.GetPd().dst_primitive_desc()); + + // mxnet will always use scale shift. 
+ // But if fix_gamma is true, then all scale elements will be set to 1.0f + if (flags & use_scale_shift) { + const NDArray &gamma = in_data[batchnorm::kGamma]; + const NDArray &beta = in_data[batchnorm::kBeta]; + CHECK_EQ(gamma.storage_type(), mxnet::kDefaultStorage); + CHECK_EQ(beta.storage_type(), mxnet::kDefaultStorage); + + const mkldnn::memory &weight_mem = fwd.GetWeight(); + DType* weight_buf = reinterpret_cast(weight_mem.get_data_handle()); + + nnvm::dim_t channels_ = data.shape()[1]; + CHECK(weight_mem.get_primitive_desc().get_size() == channels_ * sizeof(DType) * 2); + DType* weight_ptr = gamma.data().dptr(); + DType* bias_ptr = beta.data().dptr(); + if (!param.fix_gamma) { +#pragma omp parallel for simd + for (int i = 0; i < channels_; i++) { + weight_buf[i] = weight_ptr[i]; + weight_buf[channels_ + i] = bias_ptr[i]; // bias + } + } else if (IsBNWriting(req[batchnorm::kGamma])) { +#pragma omp parallel for simd + for (int i = 0; i < channels_; i++) { + weight_buf[i] = (DType)1.0f; + weight_ptr[i] = (DType)1.0f; + weight_buf[channels_ + i] = bias_ptr[i]; // bias + } + } else { +#pragma omp parallel for simd + for (int i = 0; i < channels_; i++) { + weight_buf[i] = (DType)1.0f; + weight_buf[channels_ + i] = bias_ptr[i]; // bias + } + } + + if (!ctx.is_train) { + DType* omean = out_data[batchnorm::kMean].data().dptr(); + DType* ovar = out_data[batchnorm::kVar].data().dptr(); + DType* inmean = aux_states[batchnorm::kMovingMean].data().dptr(); + DType* invar = aux_states[batchnorm::kMovingVar].data().dptr(); + // to align with the original implementation: batch_norm.cc: L164 +#pragma omp parallel for simd + for (int i = 0; i < channels_; i++) { + omean[i] = inmean[i]; + ovar[i] = VARIANCE_TO_INVSTD(invar[i], param.eps); + } + + fwd.SetDataHandle(data, aux_states[batchnorm::kMovingMean], + aux_states[batchnorm::kMovingVar], + *out_mem); + MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + MKLDNNStream::Get()->Submit(); + } else { // training + const NDArray &outMean = out_data[batchnorm::kMean]; + const NDArray &outVar = out_data[batchnorm::kVar]; + DType* omean = outMean.data().dptr(); + DType* ovar = outVar.data().dptr(); + + fwd.SetDataHandle(data, outMean, outVar, *out_mem); + MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + MKLDNNStream::Get()->Submit(); + DType* mean_mem_ptr = reinterpret_cast(fwd.GetMean().get_data_handle()); + DType* var_mem_ptr = reinterpret_cast(fwd.GetVar().get_data_handle()); +#pragma omp parallel for simd + for (int i = 0; i < channels_; i++) { + omean[i] = mean_mem_ptr[i]; + ovar[i] = VARIANCE_TO_INVSTD(var_mem_ptr[i], param.eps); + } + } + } else { // no input gamma and beta + LOG(FATAL) << "MKLDNN batch normalization: should not reach here ..."; + } +} + +template +void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_states) { + TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]); + CHECK_EQ(out_grad.size(), param.output_mean_var ? 
3U : 1U); + CHECK_EQ(in_data.size(), 3U); + CHECK_EQ(out_data.size(), 3U); + CHECK_EQ(in_grad.size(), 3U); + unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train); + + const NDArray &data = in_data[batchnorm::kData]; + const NDArray &diff = out_grad[batchnorm::kOut]; + const NDArray &gradIn = in_grad[batchnorm::kData]; + const NDArray &moving_mean = aux_states[batchnorm::kMovingMean]; + const NDArray &moving_var = aux_states[batchnorm::kMovingVar]; + const NDArray &out_mean = out_data[batchnorm::kMean]; + const NDArray &out_var = out_data[batchnorm::kVar]; + + CHECK(out_mean.IsDefault()); + CHECK(out_var.IsDefault()); + CHECK(moving_mean.IsDefault()); + CHECK(moving_var.IsDefault()); + + auto data_mem = data.GetMKLDNNData(); + auto diff_mem = diff.GetMKLDNNData(); + // MKLDNN batchnorm should run on special layouts. If one of them isn't, we + // should reorder them. + if (data.IsDefault()) + data_mem = data.GetMKLDNNDataReorder(diff_mem->get_primitive_desc()); + else if (diff.IsDefault()) + diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_primitive_desc()); + auto bwd_pd = _GetBwd(*data_mem, *diff_mem, param.eps, flags); + auto gradi_mem = const_cast(gradIn).CreateMKLDNNData(data_mem->get_primitive_desc()); + + if (flags & use_scale_shift) { + const NDArray &gamma = in_data[batchnorm::kGamma]; + const NDArray &beta = in_data[batchnorm::kBeta]; + // TODO(tao): how to reuse this memory? + std::shared_ptr weight_mem( + new mkldnn::memory(bwd_pd.weights_primitive_desc())); + + DType* weight_buf = reinterpret_cast(weight_mem->get_data_handle()); + nnvm::dim_t channels_ = data.shape()[1]; + for (int i = 0; i < channels_; i++) { + if (!param.fix_gamma) + weight_buf[i] = (gamma.data().dptr())[i]; // weight + else + weight_buf[i] = (DType)1.0f; + } + + for (int i = 0; i < channels_; i++) { + weight_buf[channels_ + i] = (beta.data().dptr())[i]; // bias + } + + std::shared_ptr gradw_mem( + new mkldnn::memory(bwd_pd.diff_weights_primitive_desc())); + // training but no input mean and variance + if (ctx.is_train && !param.use_global_stats) { + DType* moving_mean_ptr = reinterpret_cast(moving_mean.data().dptr()); + DType* moving_var_ptr = reinterpret_cast(moving_var.data().dptr()); + DType* out_mean_ptr = reinterpret_cast(out_mean.data().dptr()); + DType* out_var_ptr = reinterpret_cast(out_var.data().dptr()); + mkldnn::memory var_mem(bwd_pd.variance_primitive_desc()); + DType *tmp_var_ptr = reinterpret_cast(var_mem.get_data_handle()); + + DType minus_mom = (1.0f - param.momentum); + for (int i = 0; i < channels_; i++) { + moving_mean_ptr[i] = moving_mean_ptr[i] * param.momentum + + out_mean_ptr[i] * minus_mom; + float variance = INVSTD_TO_VARIANCE(out_var_ptr[i], param.eps); + tmp_var_ptr[i] = variance; + moving_var_ptr[i] = moving_var_ptr[i] * param.momentum + + variance * minus_mom; + } + + std::shared_ptr out_mean_mem( + new mkldnn::memory(bwd_pd.mean_primitive_desc(), out_mean_ptr)); + std::shared_ptr out_var_mem( + new mkldnn::memory(bwd_pd.variance_primitive_desc(), out_var_ptr)); + + auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd, + *data_mem, + mkldnn::primitive::at(*out_mean_mem), + mkldnn::primitive::at(var_mem), + *diff_mem, + *weight_mem, + *gradi_mem, + *gradw_mem); + + MKLDNNStream::Get()->RegisterPrim(bn_bwd); + MKLDNNStream::Get()->Submit(); + } else { + std::shared_ptr imean_mem( + new mkldnn::memory(bwd_pd.mean_primitive_desc(), + moving_mean.data().dptr())); + std::shared_ptr ivar_mem( + new mkldnn::memory(bwd_pd.variance_primitive_desc(), + 
+                             moving_var.data().dptr<DType>()));
+      auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd,
+          *data_mem,
+          mkldnn::primitive::at(*imean_mem),
+          mkldnn::primitive::at(*ivar_mem),
+          *diff_mem,
+          *weight_mem,
+          *gradi_mem,
+          *gradw_mem);
+
+      MKLDNNStream::Get()->RegisterPrim(bn_bwd);
+      MKLDNNStream::Get()->Submit();
+    }
+
+    // copy data from gradw_mem to in_grad[1] and in_grad[2]
+    DType* gw_buf = reinterpret_cast<DType *>(gradw_mem->get_data_handle());
+    for (int i = 0; i < channels_; i++) {
+      if (!param.fix_gamma)
+        (in_grad[1].data().dptr<DType>())[i] = gw_buf[i];
+      else
+        (in_grad[1].data().dptr<DType>())[i] = 0.0f;
+    }
+
+    for (int i = 0; i < channels_; i++) {
+      (in_grad[2].data().dptr<DType>())[i] = gw_buf[i + channels_];
+    }
+  } else {
+    LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ...";
+  }
+}
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_USE_MKLDNN
+#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_NORM_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc
new file mode 100644
index 000000000000..d3e6e775020d
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_concat.cc
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_concat.cc + * \brief + * \author Wenting Jiang +*/ +#include "../concat-inl.h" +#include "./mkldnn_ops-inl.h" +#include "./mkldnn_base-inl.h" + +#if MXNET_USE_MKLDNN == 1 +namespace mxnet { +namespace op { + +void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]); + const ConcatParam& param = nnvm::get(attrs.parsed); + int num_in_data = param.num_args; + int concat_dim = param.dim; + std::vector data_md; + std::vector data_mem; + for (int i =0; i < num_in_data; i++) { + auto tmp_mem = in_data[i].GetMKLDNNData(); + auto tmp_pd = tmp_mem->get_primitive_desc(); + data_md.push_back(tmp_pd); + data_mem.push_back(*tmp_mem); + } + mkldnn::concat::primitive_desc fwd_pd(concat_dim, data_md); + auto engine = CpuEngine::Get()->get_engine(); + auto out_mem = CreateMKLDNNMem(out_data[concat_enum::kOut], + fwd_pd.dst_primitive_desc(), req[concat_enum::kOut]); + MKLDNNStream::Get()->RegisterPrim(mkldnn::concat(fwd_pd, data_mem, *out_mem.second)); + CommitOutput(out_data[concat_enum::kOut], out_mem); + MKLDNNStream::Get()->Submit(); +} + +void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]); + const ConcatParam& param = nnvm::get(attrs.parsed); + int num_in_data = param.num_args; + int axis_ = param.dim; + auto engine = CpuEngine::Get()->get_engine(); + auto gz_mem = inputs[0].GetMKLDNNData(); + mkldnn::memory::primitive_desc gz_pd = gz_mem->get_primitive_desc(); + /* init the offset */ + mkldnn::memory::dims offsets = {0, 0, 0, 0}; + for (int i = 0; i < num_in_data; i++) { + mkldnn::memory::dims diff_src_tz + = {static_cast(inputs[i+1].shape()[0]), + static_cast(inputs[i+1].shape()[1]), + static_cast(inputs[i+1].shape()[2]), + static_cast(inputs[i+1].shape()[3])}; + auto diff_src_mpd = inputs[i+1].GetMKLDNNData()->get_primitive_desc(); + auto gradi_mem_ = CreateMKLDNNMem(outputs[i], diff_src_mpd, req[i]); + // create view from gy to gxs[i] + std::shared_ptr view_pd; + view_pd.reset(new mkldnn::view::primitive_desc(gz_pd, diff_src_tz, offsets)); + // create reorder primitive from gy to gxs[i] + mkldnn::reorder::primitive_desc reorder_pd( + view_pd.get()->dst_primitive_desc(), diff_src_mpd); + offsets[axis_] += diff_src_tz[axis_]; + MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder( + reorder_pd, *gz_mem, *gradi_mem_.second)); + CommitOutput(outputs[i], gradi_mem_); + } + MKLDNNStream::Get()->Submit(); +} + +} // namespace op +} // namespace mxnet +#endif diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc new file mode 100644 index 000000000000..f10ff0f674a2 --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_convolution.cc + * \brief + * \author Da Zheng +*/ + +#include "../convolution-inl.h" +#include "./mkldnn_ops-inl.h" +#include "./mkldnn_base-inl.h" + +#if MXNET_USE_MKLDNN == 1 +namespace mxnet { +namespace op { + +static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( + const ConvolutionParam& param, bool is_train, const NDArray &data, + const NDArray &weights, const NDArray *bias, const NDArray &output) { + auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; + auto data_md = GetMemDesc(data); + auto weight_md = GetWeightDesc(weights, param.num_group); + auto out_md = GetMemDesc(output); + auto engine = CpuEngine::Get()->get_engine(); + mkldnn::memory::dims strides{0, 0}; + if (param.stride.ndim() == 2) { + strides[0] = param.stride[0]; + strides[1] = param.stride[1]; + } + mkldnn::memory::dims padding{0, 0}; + if (param.pad.ndim() == 2) { + padding[0] = param.pad[0]; + padding[1] = param.pad[1]; + } + if (param.dilate.ndim() == 0 && bias == nullptr) { + mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, + data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero); + return mkldnn::convolution_forward::primitive_desc(desc, engine); + } else if (param.dilate.ndim() == 0) { + auto bias_md = GetMemDesc(*bias); + mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, + data_md, weight_md, bias_md, out_md, strides, padding, padding, + mkldnn::padding_kind::zero); + return mkldnn::convolution_forward::primitive_desc(desc, engine); + } else { + mkldnn::memory::dims dilates{0, 0}; + if (param.dilate.ndim() == 2) { + dilates[0] = param.dilate[0] - 1; + dilates[1] = param.dilate[1] - 1; + } + if (bias == nullptr) { + mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, + data_md, weight_md, out_md, strides, dilates, padding, padding, + mkldnn::padding_kind::zero); + return mkldnn::convolution_forward::primitive_desc(desc, engine); + } else { + auto bias_md = GetMemDesc(*bias); + mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, + data_md, weight_md, bias_md, out_md, strides, + dilates, padding, padding, + mkldnn::padding_kind::zero); + return mkldnn::convolution_forward::primitive_desc(desc, engine); + } + } +} + +static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData( + const ConvolutionParam& param, const NDArray &data, const NDArray &weights, + const NDArray &output, const mkldnn::convolution_forward::primitive_desc &fwd_pd) { + auto data_md = GetMemDesc(data); + auto weight_md = GetWeightDesc(weights, param.num_group); + auto out_md = GetMemDesc(output); + auto engine = CpuEngine::Get()->get_engine(); + mkldnn::memory::dims strides{0, 0}; + if (param.stride.ndim() == 2) { + strides[0] = param.stride[0]; + strides[1] = param.stride[1]; + } + mkldnn::memory::dims padding{0, 0}; + if (param.pad.ndim() == 2) { + padding[0] = param.pad[0]; + padding[1] = param.pad[1]; + } + if (param.dilate.ndim() == 0) { + 
mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, + data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero); + return mkldnn::convolution_backward_data::primitive_desc(desc, engine, fwd_pd); + } else { + mkldnn::memory::dims dilates{0, 0}; + if (param.dilate.ndim() == 2) { + dilates[0] = param.dilate[0] - 1; + dilates[1] = param.dilate[1] - 1; + } + mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, + data_md, weight_md, out_md, strides, dilates, padding, padding, + mkldnn::padding_kind::zero); + return mkldnn::convolution_backward_data::primitive_desc(desc, engine, fwd_pd); + } +} + +static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights( + const ConvolutionParam& param, const NDArray &data, + const NDArray &weights, const NDArray *bias, const NDArray &output, + const mkldnn::convolution_forward::primitive_desc &fwd_pd) { + auto data_md = GetMemDesc(data); + auto weight_md = GetWeightDesc(weights, param.num_group); + auto out_md = GetMemDesc(output); + auto engine = CpuEngine::Get()->get_engine(); + mkldnn::memory::dims strides{0, 0}; + if (param.stride.ndim() == 2) { + strides[0] = param.stride[0]; + strides[1] = param.stride[1]; + } + mkldnn::memory::dims padding{0, 0}; + if (param.pad.ndim() == 2) { + padding[0] = param.pad[0]; + padding[1] = param.pad[1]; + } + if (param.dilate.ndim() == 0 && bias == nullptr) { + mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, + data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero); + return mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd); + } else if (param.dilate.ndim() == 0) { + auto bias_md = GetMemDesc(*bias); + mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, + data_md, weight_md, bias_md, out_md, strides, padding, padding, + mkldnn::padding_kind::zero); + return mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd); + } else { + mkldnn::memory::dims dilates{0, 0}; + if (param.dilate.ndim() == 2) { + dilates[0] = param.dilate[0] - 1; + dilates[1] = param.dilate[1] - 1; + } + if (bias == nullptr) { + mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, + data_md, weight_md, out_md, strides, dilates, padding, padding, + mkldnn::padding_kind::zero); + return mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd); + } else { + auto bias_md = GetMemDesc(*bias); + mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, + data_md, weight_md, bias_md, out_md, + strides, dilates, padding, padding, + mkldnn::padding_kind::zero); + return mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd); + } + } +} + +class MKLDNNConvForward { + std::shared_ptr fwd; + std::shared_ptr data; + std::shared_ptr weight; + std::shared_ptr bias; + std::shared_ptr out; + + public: + mkldnn::convolution_forward::primitive_desc fwd_pd; + + MKLDNNConvForward(const ConvolutionParam& param, bool is_train, + const NDArray &data, const NDArray &weights, + const NDArray *bias, const NDArray &output): fwd_pd( + GetConvFwdImpl(param, is_train, data, weights, bias, output)) { + } + + void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, + const mkldnn::memory *bias, const mkldnn::memory &output) { + if (this->data == nullptr) + this->data = std::shared_ptr(new mkldnn::memory( + 
fwd_pd.src_primitive_desc(), data.get_data_handle())); + else + this->data->set_data_handle(data.get_data_handle()); + + if (this->weight == nullptr) + this->weight = std::shared_ptr(new mkldnn::memory( + fwd_pd.weights_primitive_desc(), weight.get_data_handle())); + else + this->weight->set_data_handle(weight.get_data_handle()); + + if (this->out == nullptr) + this->out = std::shared_ptr(new mkldnn::memory( + fwd_pd.dst_primitive_desc(), output.get_data_handle())); + else + this->out->set_data_handle(output.get_data_handle()); + + if (bias != nullptr) { + if (this->bias == nullptr) + this->bias = std::shared_ptr(new mkldnn::memory( + fwd_pd.bias_primitive_desc(), bias->get_data_handle())); + else + this->bias->set_data_handle(bias->get_data_handle()); + if (this->fwd == nullptr) + this->fwd = std::shared_ptr( + new mkldnn::convolution_forward(fwd_pd, mkldnn::primitive::at(*this->data), + mkldnn::primitive::at(*this->weight), + mkldnn::primitive::at(*this->bias), + *this->out)); + } else if (this->fwd == nullptr) { + this->fwd = std::shared_ptr( + new mkldnn::convolution_forward(fwd_pd, mkldnn::primitive::at(*this->data), + mkldnn::primitive::at(*this->weight), + *this->out)); + } + } + + const mkldnn::convolution_forward &GetFwd() const { + return *fwd; + } +}; + +typedef MKLDNNParamOpSign MKLDNNConvSignature; + +static inline MKLDNNConvForward &GetConvFwd( + const nnvm::NodeAttrs& attrs, bool is_train, + const NDArray &data, const NDArray &weights, + const NDArray *bias, const NDArray &output) { + static thread_local std::unordered_map fwds; + const ConvolutionParam& param = nnvm::get(attrs.parsed); + MKLDNNConvSignature key(param); + key.AddSign(is_train); + // Here we can sign the conv op with NDArray because conv primitive will + // decide the right layout for the, so we only need to get the shape and the + // data type of the arrays. + key.AddSign(data); + key.AddSign(weights); + key.AddSign(output); + if (bias) + key.AddSign(*bias); + + auto it = fwds.find(key); + if (it == fwds.end()) { + MKLDNNConvForward fwd(param, is_train, data, weights, bias, output); + auto ins_ret = fwds.insert( + std::pair(key, fwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + +void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]); + const ConvolutionParam& param = nnvm::get(attrs.parsed); + MKLDNNConvForward &fwd = GetConvFwd(attrs, + ctx.is_train, in_data[conv::kData], in_data[conv::kWeight], + param.no_bias ? nullptr : &in_data[conv::kBias], out_data[conv::kOut]); + + auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc()); + const mkldnn::memory *weight_mem; + if (ctx.is_train) { + // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it + // to the default format for now. + if (in_data[conv::kWeight].IsMKLDNN()) + const_cast(in_data[conv::kWeight]).Reorder2Default(); + weight_mem = GetWeights(in_data[conv::kWeight], fwd.fwd_pd.weights_primitive_desc(), + param.num_group); + } else { + // For inference, we want to reorder the weight array so we don't need to + // reorder data every time. 
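Aside: GetConvFwd above caches one MKLDNNConvForward per (param, is_train, shapes, dtypes) signature in a thread-local map, so a repeated call with same-shaped arrays reuses the already-constructed primitive. A minimal standalone sketch of that memoization pattern, with a hypothetical string key and a placeholder value type standing in for MKLDNNConvSignature and MKLDNNConvForward:

#include <string>
#include <unordered_map>

struct CachedPrimitive {   // placeholder for MKLDNNConvForward
  std::string chosen_layout;
};

static CachedPrimitive &GetOrCreate(const std::string &key) {
  // One cache per thread (no locking); entries live as long as the thread,
  // which is the same trade-off the fwds map above makes.
  static thread_local std::unordered_map<std::string, CachedPrimitive> cache;
  auto it = cache.find(key);
  if (it == cache.end())
    it = cache.emplace(key, CachedPrimitive{"nchw"}).first;
  return it->second;
}
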
+ const_cast(in_data[conv::kWeight]).Reorder( + fwd.fwd_pd.weights_primitive_desc()); + weight_mem = in_data[conv::kWeight].GetMKLDNNData(); + } + auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.fwd_pd.dst_primitive_desc(), + req[conv::kOut]); + const mkldnn::memory *bias_mem = nullptr; + if (!param.no_bias) + bias_mem = in_data[conv::kBias].GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc()); + fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); + MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + + CommitOutput(out_data[conv::kOut], out_mem); + MKLDNNStream::Get()->Submit(); +} + +void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]); + const std::vector &in_grad = outputs; + const ConvolutionParam& param = nnvm::get(attrs.parsed); + mkldnn::convolution_forward::primitive_desc fwd_pd = GetConvFwdImpl(param, ctx.is_train, + inputs[conv::kData + 1], inputs[conv::kWeight + 1], + param.no_bias ? nullptr : &inputs[conv::kBias + 1], inputs[conv::kOut]); + + CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace"; + mkldnn::convolution_backward_data::primitive_desc bwdData_pd + = GetConvBwdData(param, inputs[conv::kData + 1], inputs[conv::kWeight + 1], + inputs[conv::kOut], fwd_pd); + auto out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder( + bwdData_pd.diff_dst_primitive_desc()); + if (req[conv::kData]) { + auto weight_mem = GetWeights(inputs[conv::kWeight + 1], + bwdData_pd.weights_primitive_desc(), param.num_group); + auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData], + bwdData_pd.diff_src_primitive_desc(), req[conv::kData]); + MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_data(bwdData_pd, + *out_grad_mem, *weight_mem, *in_grad_mem.second)); + CommitOutput(in_grad[conv::kData], in_grad_mem); + } + if (req[conv::kWeight]) { + mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd + = GetConvBwdWeights(param, inputs[conv::kData + 1], inputs[conv::kWeight + 1], + param.no_bias ? 
+                      nullptr : &inputs[conv::kBias + 1],
+                      inputs[conv::kOut], fwd_pd);
+    if (bwdData_pd.diff_dst_primitive_desc() != bwdWeights_pd.diff_dst_primitive_desc())
+      out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder(
+          bwdWeights_pd.diff_dst_primitive_desc());
+    auto data_mem = inputs[conv::kData + 1].GetMKLDNNDataReorder(
+        bwdWeights_pd.src_primitive_desc());
+    auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[conv::kWeight],
+                                                 bwdWeights_pd.diff_weights_primitive_desc(),
+                                                 req[conv::kWeight]);
+    mkldnn_output_t in_grad_bias;
+    if (param.no_bias) {
+      MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
+          bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
+    } else {
+      in_grad_bias = CreateMKLDNNMem(in_grad[conv::kBias],
+                                     bwdWeights_pd.diff_bias_primitive_desc(),
+                                     req[conv::kBias]);
+      MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
+          bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second,
+          *in_grad_bias.second));
+      CommitOutput(in_grad[conv::kBias], in_grad_bias);
+    }
+    CommitOutput(in_grad[conv::kWeight], in_grad_weight);
+  }
+  MKLDNNStream::Get()->Submit();
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/nn/mkldnn/mkldnn_copy.cc b/src/operator/nn/mkldnn/mkldnn_copy.cc
new file mode 100644
index 000000000000..71d540c969cd
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_copy.cc
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_copy.cc
+ * \brief
+ * \author Da Zheng
+*/
+
+#include "../softmax-inl.h"
+#include "./mkldnn_ops-inl.h"
+#include "./mkldnn_base-inl.h"
+
+#if MXNET_USE_MKLDNN == 1
+namespace mxnet {
+namespace op {
+
+void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                const NDArray &in_data, const OpReqType &req,
+                const NDArray &out_data) {
+  TmpMemMgr::Get()->Init(ctx.requested[0]);
+  auto in_mem = in_data.GetMKLDNNData();
+  if (req == kAddTo) {
+    // We should try to force the output memory to have the same format
+    // as the input memory. If not, we'll have to reorder memory.
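Aside: the kAddTo branch below implements "out += in" by summing into a scratch buffer and copying the result back into the output array. A plain-C++ sketch of that accumulate-via-temporary sequence (hypothetical buffers, not the MKLDNN API):

#include <cstddef>
#include <vector>

// Mirrors the Sum(...) followed by CopyFrom(...) calls used for kAddTo below.
void AccumulateViaTmp(const std::vector<float> &in, std::vector<float> *out) {
  std::vector<float> tmp(out->size());
  for (std::size_t i = 0; i < tmp.size(); ++i)
    tmp[i] = in[i] + (*out)[i];  // Sum(in, out, tmp)
  *out = tmp;                    // out.CopyFrom(tmp)
}
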
+ auto out_mem = out_data.GetMKLDNNData(in_mem->get_primitive_desc()); + if (out_mem == nullptr) + out_mem = out_data.GetMKLDNNData(); + auto sum_res = TmpMemMgr::Get()->Alloc(out_mem->get_primitive_desc()); + Sum(*in_mem, *out_mem, *sum_res); + const_cast(out_data).CopyFrom(*sum_res); + } else { + const_cast(out_data).CopyFrom(*in_mem); + } + MKLDNNStream::Get()->Submit(); +} + +} // namespace op +} // namespace mxnet +#endif diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc new file mode 100644 index 000000000000..db0c90d7f9a8 --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc @@ -0,0 +1,390 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_deconvolution.cc + * \brief + * \author Da Zheng, Rong Zhang (rong.a.zhang@intel.com) +*/ + +#if MXNET_USE_MKLDNN == 1 + +#include "../deconvolution-inl.h" +#include "./mkldnn_ops-inl.h" +#include "./mkldnn_base-inl.h" + +namespace mxnet { +namespace op { + +static inline mkldnn::memory::desc GetBiasDesc(mkldnn::memory::desc md) { + mkldnn::memory::dims dims(1); + // This is convolution on 4D data. The second dimension is the channel. 
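Aside: throughout this file a deconvolution is expressed with convolution primitives run in the opposite direction (the forward pass uses convolution_backward_data, and the backward data pass uses convolution_forward). For intuition, the spatial output size of a 2-D deconvolution, ignoring dilation and output adjustment, is out = stride * (in - 1) + kernel - 2 * pad; a small sketch under that assumption:

#include <cstdio>

// Transposed-convolution output length for one spatial dimension
// (assumed formula: out = stride * (in - 1) + kernel - 2 * pad).
static int DeconvOutSize(int in, int kernel, int stride, int pad) {
  return stride * (in - 1) + kernel - 2 * pad;
}

int main() {
  std::printf("%d\n", DeconvOutSize(4, 3, 2, 1));  // 4x4 -> 7x7: prints 7
  return 0;
}
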
+ dims[0] = md.data.dims[1]; + return mkldnn::memory::desc(dims, + static_cast(md.data.data_type), + mkldnn::memory::format::any); +} + +static mkldnn::convolution_forward::primitive_desc GetDeconvBwd_( + const mkldnn::memory::desc &data_md, const mkldnn::memory::desc &weights_md, + bool has_bias, const mkldnn::memory::desc &out_md, + const mkldnn::engine &engine, const mkldnn::memory::dims &strides, + const mkldnn::memory::dims &padding, const mkldnn::memory::dims &dilates) { + if (!has_bias) { + mkldnn::convolution_forward::desc desc(mkldnn::prop_kind::forward_training, + mkldnn::algorithm::convolution_direct, out_md, weights_md, data_md, strides, + dilates, padding, padding, mkldnn::padding_kind::zero); + return mkldnn::convolution_forward::primitive_desc(desc, engine); + } else { + auto bias_md = GetBiasDesc(data_md); + mkldnn::convolution_forward::desc desc(mkldnn::prop_kind::forward_training, + mkldnn::algorithm::convolution_direct, out_md, weights_md, bias_md, + data_md, strides, dilates, padding, padding, mkldnn::padding_kind::zero); + return mkldnn::convolution_forward::primitive_desc(desc, engine); + } +} + +static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl( + const DeconvolutionParam& param, const NDArray &data, const NDArray &weights, + bool has_bias, const NDArray &output) { + auto data_md = GetMemDesc(data); + auto weight_md = GetWeightDesc(weights, param.num_group); + auto out_md = GetMemDesc(output); + auto engine = CpuEngine::Get()->get_engine(); + mkldnn::memory::dims strides{0, 0}; + if (param.stride.ndim() == 2) { + strides[0] = param.stride[0]; + strides[1] = param.stride[1]; + } else if (param.stride.ndim() == 1) { + strides[0] = param.stride[0]; + strides[1] = param.stride[0]; + } else { + LOG(FATAL) << "Unsupported stride dim"; + } + mkldnn::memory::dims padding{0, 0}; + if (param.pad.ndim() == 2) { + padding[0] = param.pad[0]; + padding[1] = param.pad[1]; + } else if (param.pad.ndim() == 1) { + padding[0] = param.pad[0]; + padding[1] = param.pad[0]; + } else { + LOG(FATAL) << "Unsupported pad dim"; + } + mkldnn::memory::dims dilate{0, 0}; + if (param.dilate.ndim() == 2) { + dilate[0] = param.dilate[0] - 1; + dilate[1] = param.dilate[1] - 1; + } + auto bwd_pd = GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine, + strides, padding, dilate); + mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, + out_md, weight_md, data_md, strides, dilate, padding, padding, + mkldnn::padding_kind::zero); + return mkldnn::convolution_backward_data::primitive_desc(desc, engine, bwd_pd); +} + +static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData( + const DeconvolutionParam ¶m, const NDArray &data, const NDArray &weights, + bool has_bias, const NDArray &output) { + auto data_md = GetMemDesc(data); + auto weight_md = GetWeightDesc(weights, param.num_group); + auto out_md = GetMemDesc(output); + auto engine = CpuEngine::Get()->get_engine(); + mkldnn::memory::dims strides{0, 0}; + if (param.stride.ndim() == 2) { + strides[0] = param.stride[0]; + strides[1] = param.stride[1]; + } else if (param.stride.ndim() == 1) { + strides[0] = param.stride[0]; + strides[1] = param.stride[0]; + } else { + LOG(FATAL) << "Unsupported stride dim"; + } + mkldnn::memory::dims padding{0, 0}; + if (param.pad.ndim() == 2) { + padding[0] = param.pad[0]; + padding[1] = param.pad[1]; + } else if (param.pad.ndim() == 1) { + padding[0] = param.pad[0]; + padding[1] = param.pad[0]; + } else { + LOG(FATAL) << "Unsupported pad dim"; + } + 
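Aside: the "param.dilate[i] - 1" conversions in this file exist because MKLDNN encodes dilation zero-based (0 means a dense kernel) while MXNet's dilate is one-based (1 means dense). Both describe the same effective kernel extent, (kernel - 1) * dilate + 1 in MXNet terms; a small sketch:

#include <cstdio>

// Effective kernel extent under MXNet's one-based dilation.
static int EffectiveKernel(int kernel, int mxnet_dilate) {
  return (kernel - 1) * mxnet_dilate + 1;
}

int main() {
  // A 3-tap kernel with dilate = 2 (MKLDNN dilate = 1) spans 5 input elements.
  std::printf("%d\n", EffectiveKernel(3, 2));  // prints 5
  return 0;
}
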
mkldnn::memory::dims dilate{0, 0}; + if (param.dilate.ndim() == 2) { + dilate[0] = param.dilate[0] - 1; + dilate[1] = param.dilate[1] - 1; + } + return GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine, + strides, padding, dilate); +} + +static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights( + const DeconvolutionParam& param, const NDArray &data, const NDArray &weights, + bool has_bias, const NDArray &output, + const mkldnn::convolution_forward::primitive_desc &fwd_pd) { + auto data_md = GetMemDesc(data); + auto weight_md = GetWeightDesc(weights, param.num_group); + auto out_md = GetMemDesc(output); + auto engine = CpuEngine::Get()->get_engine(); + mkldnn::memory::dims strides{0, 0}; + if (param.stride.ndim() == 2) { + strides[0] = param.stride[0]; + strides[1] = param.stride[1]; + } else if (param.stride.ndim() == 1) { + strides[0] = param.stride[0]; + strides[1] = param.stride[0]; + } else { + LOG(FATAL) << "Unsupported stride dim"; + } + mkldnn::memory::dims padding{0, 0}; + if (param.pad.ndim() == 2) { + padding[0] = param.pad[0]; + padding[1] = param.pad[1]; + } else if (param.pad.ndim() == 1) { + padding[0] = param.pad[0]; + padding[1] = param.pad[0]; + } else { + LOG(FATAL) << "Unsupported pad dim"; + } + mkldnn::memory::dims dilate{0, 0}; + if (param.dilate.ndim() == 2) { + dilate[0] = param.dilate[0] - 1; + dilate[1] = param.dilate[1] - 1; + } + if (!has_bias) { + mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, + out_md, weight_md, data_md, strides, dilate, padding, padding, mkldnn::padding_kind::zero); + return mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd); + } else { + auto bias_md = GetBiasDesc(data_md); + mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, + out_md, weight_md, bias_md, data_md, strides, dilate, padding, padding, + mkldnn::padding_kind::zero); + return mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd); + } +} + +class MKLDNNDeconvForward { + std::shared_ptr fwd; + std::shared_ptr data; + std::shared_ptr weight; + std::shared_ptr bias; + std::shared_ptr out; + OutDataOp data_op; + + public: + MKLDNNDeconvForward(const DeconvolutionParam& param, + const NDArray &data, + const NDArray &weights, + bool has_bias, + const NDArray &output); + void SetDataHandle(const DeconvolutionParam& param, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); + + void Execute(const std::vector &out_data); + + private: + mkldnn::convolution_backward_data::primitive_desc fwd_pd; +}; // class MKLDNNDeconvForward + +MKLDNNDeconvForward::MKLDNNDeconvForward(const DeconvolutionParam& param, + const NDArray &data, + const NDArray &weights, + bool has_bias, + const NDArray &output) + :fwd_pd(GetDeconvFwdImpl(param, data, weights, has_bias, output)) { + this->data = std::shared_ptr(new mkldnn::memory( + fwd_pd.diff_dst_primitive_desc())); + this->weight = std::shared_ptr(new mkldnn::memory( + fwd_pd.weights_primitive_desc())); + this->out = std::shared_ptr(new mkldnn::memory( + fwd_pd.diff_src_primitive_desc())); + this->fwd = std::shared_ptr( + new mkldnn::convolution_backward_data(fwd_pd, + mkldnn::primitive::at(*this->data), + mkldnn::primitive::at(*this->weight), + *this->out)); +} + +void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + 
auto data_mem = in_data[deconv::kData].GetMKLDNNDataReorder( + fwd_pd.diff_dst_primitive_desc()); + const mkldnn::memory *weight_mem; + if (ctx.is_train) { + // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it + // to the default format for now. + if (in_data[deconv::kWeight].IsMKLDNN()) + const_cast(in_data[deconv::kWeight]).Reorder2Default(); + weight_mem = GetWeights(in_data[deconv::kWeight], + fwd_pd.weights_primitive_desc(), + param.num_group); + } else { + // For inference, we want to reorder the weight array so we don't need to + // reorder data every time. + const_cast(in_data[deconv::kWeight]).Reorder( + fwd_pd.weights_primitive_desc()); + weight_mem = in_data[deconv::kWeight].GetMKLDNNData(); + } + auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut], + fwd_pd.diff_src_primitive_desc(), req[deconv::kOut]); + auto output = out_mem.second; + this->data->set_data_handle(data_mem->get_data_handle()); + this->weight->set_data_handle(weight_mem->get_data_handle()); + this->out->set_data_handle(output->get_data_handle()); + this->data_op = out_mem.first; +} + +void MKLDNNDeconvForward::Execute(const std::vector &out_data) { + MKLDNNStream::Get()->RegisterPrim(*fwd); + CommitOutput(out_data[deconv::kOut], mkldnn_output_t(this->data_op, this->out.get())); + MKLDNNStream::Get()->Submit(); +} + +static void MKLDNNDeconvFwdBiasPostProcess(const DeconvolutionParam& param, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &out_data) { + // add bias, broadcast bias to dim 1: channel + if (!param.no_bias) { + // MKLDNN only supports float right now. + typedef float DType; + Stream *s = ctx.get_stream(); + Tensor bias = in_data[deconv::kBias].data().get(s); + // If the output data is stored in a special MKLDNN format, data() + // automatically converts its format to the default format. + // Unfortunately, MKLDNN doesn't support broadcast. + Tensor out_cpu = out_data[deconv::kOut].data().get(s); + out_cpu += mshadow::expr::broadcast<1>(bias, out_cpu.shape_); + } +} + +typedef MKLDNNParamOpSign MKLDNNDeconvSignature; + +static inline MKLDNNDeconvForward &GetDeconvFwd( + const nnvm::NodeAttrs& attrs, const NDArray &data, + const NDArray &weights, const NDArray *bias, + const NDArray &output) { + static thread_local + std::unordered_map fwds; + const DeconvolutionParam& param = nnvm::get(attrs.parsed); + MKLDNNDeconvSignature key(param); + // Here we can sign the conv op with NDArray because conv primitive will + // decide the right layout for the, so we only need to get the shape and the + // data type of the arrays. + key.AddSign(data); + key.AddSign(weights); + key.AddSign(output); + if (bias) + key.AddSign(*bias); + + auto it = fwds.find(key); + if (it == fwds.end()) { + bool has_bias = (bias != nullptr); + MKLDNNDeconvForward fwd(param, data, weights, has_bias, output); + auto ins_ret = fwds.insert( + std::pair(key, fwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + +void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]); + const DeconvolutionParam& param = nnvm::get(attrs.parsed); + + MKLDNNDeconvForward &deconvFwd = GetDeconvFwd( + attrs, in_data[deconv::kData], in_data[deconv::kWeight], + param.no_bias ? 
nullptr : &in_data[deconv::kBias], out_data[deconv::kOut]); + + deconvFwd.SetDataHandle(param, ctx, in_data, req, out_data); + + deconvFwd.Execute(out_data); + + MKLDNNDeconvFwdBiasPostProcess(param, ctx, in_data, out_data); +} + +void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]); + const std::vector &in_grad = outputs; + const DeconvolutionParam& param = nnvm::get(attrs.parsed); + CHECK_NE(req[deconv::kWeight], kWriteInplace) << "cannot write weight inplace"; + mkldnn::convolution_forward::primitive_desc bwdData_pd = GetDeconvBwdData( + param, inputs[deconv::kData + 1], inputs[deconv::kWeight + 1], false, + inputs[deconv::kOut]); + auto out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder( + bwdData_pd.src_primitive_desc()); + if (req[deconv::kData]) { + auto weight_mem = GetWeights(inputs[deconv::kWeight + 1], + bwdData_pd.weights_primitive_desc(), + param.num_group); + auto in_grad_mem = CreateMKLDNNMem(in_grad[deconv::kData], + bwdData_pd.dst_primitive_desc(), + req[deconv::kData]); + MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_forward(bwdData_pd, + *out_grad_mem, *weight_mem, *in_grad_mem.second)); + CommitOutput(in_grad[deconv::kData], in_grad_mem); + } + if (req[deconv::kWeight]) { + mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd + = GetDeconvBwdWeights(param, inputs[deconv::kData + 1], + inputs[deconv::kWeight + 1], false, inputs[deconv::kOut], bwdData_pd); + if (bwdData_pd.src_primitive_desc() != bwdWeights_pd.src_primitive_desc()) + out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder( + bwdWeights_pd.src_primitive_desc()); + auto data_mem = inputs[deconv::kData + 1].GetMKLDNNDataReorder( + bwdWeights_pd.diff_dst_primitive_desc()); + auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[deconv::kWeight], + bwdWeights_pd.diff_weights_primitive_desc(), + req[deconv::kWeight]); + MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights( + bwdWeights_pd, *out_grad_mem, *data_mem, *in_grad_weight.second)); + CommitOutput(in_grad[deconv::kWeight], in_grad_weight); + } + MKLDNNStream::Get()->Submit(); + if (!param.no_bias) { + typedef float DType; + Stream *s = ctx.get_stream(); + Tensor gbias = in_grad[deconv::kBias].data().get(s); + // If there is bias, the out grad has already been converted to the default + // format, so this shouldn't cause any performance issues. + Tensor grad = inputs[deconv::kOut].data().get(s); + Assign(gbias, req[deconv::kBias], mshadow::expr::sumall_except_dim<1>(grad)); + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc new file mode 100644 index 000000000000..451b94060a41 --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_fully_connected.cc + * \brief + * \author Da Zheng +*/ + +#include "../fully_connected-inl.h" +#include "./mkldnn_base-inl.h" + +#if MXNET_USE_MKLDNN == 1 +namespace mxnet { +namespace op { + +inline static mkldnn::inner_product_forward::primitive_desc GetIPFwd( + const NDArray &data, const NDArray &weight, const NDArray *bias, + const mkldnn::memory::desc &out_md) { + auto data_md = GetMemDesc(data); + auto weight_md = GetMemDesc(weight); + auto engine = CpuEngine::Get()->get_engine(); + if (bias) { + auto bias_md = GetMemDesc(*bias); + mkldnn::inner_product_forward::desc ipFwd_desc(mkldnn::prop_kind::forward_training, + data_md, weight_md, bias_md, out_md); + return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine); + } else { + mkldnn::inner_product_forward::desc ipFwd_desc(mkldnn::prop_kind::forward_training, + data_md, weight_md, out_md); + return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine); + } +} + +inline static mkldnn::inner_product_backward_data::primitive_desc GetIpBwdData( + const NDArray &data, const NDArray &weight, const NDArray &output, + mkldnn::inner_product_forward::primitive_desc ipFwd_pd) { + auto data_md = GetMemDesc(data); + auto weight_md = GetMemDesc(weight); + auto out_md = GetMemDesc(output); + auto engine = CpuEngine::Get()->get_engine(); + mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md); + return mkldnn::inner_product_backward_data::primitive_desc(desc, engine, ipFwd_pd); +} + +inline static mkldnn::inner_product_backward_weights::primitive_desc GetIPBwdWeights( + const NDArray &data, const NDArray &weight, const NDArray *bias, + const NDArray &output, mkldnn::inner_product_forward::primitive_desc ipFwd_pd) { + auto data_md = GetMemDesc(data); + auto weight_md = GetMemDesc(weight); + auto out_md = GetMemDesc(output); + auto engine = CpuEngine::Get()->get_engine(); + if (bias) { + auto bias_md = GetMemDesc(*bias); + mkldnn::inner_product_backward_weights::desc ipBwdWeights_desc(data_md, + weight_md, bias_md, out_md); + return mkldnn::inner_product_backward_weights::primitive_desc( + ipBwdWeights_desc, engine, ipFwd_pd); + } else { + mkldnn::inner_product_backward_weights::desc ipBwdWeights_desc(data_md, + weight_md, out_md); + return mkldnn::inner_product_backward_weights::primitive_desc( + ipBwdWeights_desc, engine, ipFwd_pd); + } +} + +void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + TmpMemMgr::Get()->Init(ctx.requested[fullc::kTempSpace]); + const FullyConnectedParam& param = nnvm::get(attrs.parsed); + const TShape& ishape = in_data[fullc::kData].shape(); + const TShape& oshape = out_data[fullc::kOut].shape(); + NDArray weight = in_data[fullc::kWeight]; + NDArray data = in_data[fullc::kData]; + auto out_md = GetMemDesc(out_data[fullc::kOut]); + if (data.shape().ndim() != 2 && !param.flatten) { + data = data.ReshapeMKLDNN(Shape2(ishape.ProdShape(0, ishape.ndim()-1), + ishape[ishape.ndim()-1])); + mkldnn::memory::dims 
out_dims{static_cast(oshape.ProdShape(0, oshape.ndim()-1)), + static_cast(oshape[ishape.ndim()-1])}; + out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()), + mkldnn::memory::format::any); + } else if (data.shape().ndim() != 2) { + data = data.ReshapeMKLDNN(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim()))); + mkldnn::memory::dims out_dims{static_cast(oshape[0]), + static_cast(oshape.ProdShape(1, oshape.ndim()))}; + out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()), + mkldnn::memory::format::any); + } + + mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight, + param.no_bias ? nullptr : &in_data[fullc::kBias], out_md); + auto data_mem = data.GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc()); + auto weight_mem = weight.GetMKLDNNDataReorder(ipFwd_pd.weights_primitive_desc()); + auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], + ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]); + if (param.no_bias) { + MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward( + ipFwd_pd, *data_mem, *weight_mem, *out_mem.second)); + } else { + auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(ipFwd_pd.bias_primitive_desc()); + MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(ipFwd_pd, + *data_mem, *weight_mem, *bias_mem, *out_mem.second)); + } + CommitOutput(out_data[fullc::kOut], out_mem); + MKLDNNStream::Get()->Submit(); +} + +void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + TmpMemMgr::Get()->Init(ctx.requested[fullc::kTempSpace]); + const std::vector &in_grad = outputs; + const FullyConnectedParam& param = nnvm::get(attrs.parsed); + const TShape& ishape = inputs[fullc::kData + 1].shape(); + const TShape& oshape = inputs[fullc::kOut].shape(); + + NDArray weight = inputs[fullc::kWeight + 1]; + NDArray data = inputs[fullc::kData + 1]; + if (data.shape().ndim() != 2 && !param.flatten) + data = data.ReshapeMKLDNN(Shape2(ishape.ProdShape(0, ishape.ndim()-1), + ishape[ishape.ndim()-1])); + else if (data.shape().ndim() != 2) + data = data.ReshapeMKLDNN(Shape2(ishape[0], + ishape.ProdShape(1, ishape.ndim()))); + NDArray out_grad = inputs[fullc::kOut]; + if (out_grad.shape().ndim() != 2 && !param.flatten) + out_grad = out_grad.ReshapeMKLDNN(Shape2(oshape.ProdShape(0, oshape.ndim()-1), + oshape[oshape.ndim()-1])); + else if (out_grad.shape().ndim() != 2) + out_grad = out_grad.ReshapeMKLDNN(Shape2(oshape[0], + oshape.ProdShape(1, oshape.ndim()))); + + mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight, + param.no_bias ? 
nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad)); + + CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace"; + if (req[fullc::kData]) { + mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetIpBwdData( + data, weight, out_grad, ipFwd_pd); + auto out_grad_mem = out_grad.GetMKLDNNDataReorder( + ipBwdData_pd.diff_dst_primitive_desc()); + auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc()); + auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData], + ipBwdData_pd.diff_src_primitive_desc(), + req[fullc::kData]); + MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_data( + ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second)); + CommitOutput(in_grad[fullc::kData], in_grad_mem); + } + if (req[fullc::kWeight]) { + mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd + = GetIPBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], + out_grad, ipFwd_pd); + auto out_grad_mem = out_grad.GetMKLDNNDataReorder( + ipBwdWeights_pd.diff_dst_primitive_desc()); + auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc()); + auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight], + ipBwdWeights_pd.diff_weights_primitive_desc(), + req[fullc::kWeight]); + mkldnn_output_t in_grad_bias; + if (param.no_bias) { + MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights( + ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second)); + } else { + in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias], + ipBwdWeights_pd.diff_bias_primitive_desc(), + req[fullc::kBias]); + MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights( + ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second, + *in_grad_bias.second)); + } + CommitOutput(in_grad[fullc::kWeight], in_grad_weight); + CommitOutput(in_grad[fullc::kBias], in_grad_bias); + } + MKLDNNStream::Get()->Submit(); +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h new file mode 100644 index 000000000000..e0ecc1873d96 --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file mkldnn_lrn-inl.h
+ * \brief
+ * \author Patric Zhao, patric.zhao@intel.com
+*/
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H_
+
+#if MXNET_USE_MKLDNN == 1
+#include
+#include "../lrn-inl.h"
+#include "./mkldnn_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+static inline algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
+  // TODO(Patric): lrn_within_channel will cause a core dump in the MKLDNN
+  // backward pass. Need to confirm with the MKLDNN team and fix later.
+  return algorithm::lrn_across_channels;
+}
+
+inline static lrn_forward::primitive_desc GetLRNFwd(
+    const LRNParam &param, bool is_train, const memory::desc &src_md) {
+  auto engine = CpuEngine::Get()->get_engine();
+  auto alg_ = GetMKLDNNLRNAlgo(param);
+  auto alpha_ = param.alpha;
+  auto beta_ = param.beta;
+  auto nsize_ = param.nsize;
+  auto k_ = param.knorm;
+  auto kind_ = is_train ? prop_kind::forward_training
+                        : prop_kind::forward_scoring;
+  lrn_forward::desc fwd_desc_(kind_, alg_, src_md, nsize_, alpha_, beta_, k_);
+  return mkldnn::lrn_forward::primitive_desc(fwd_desc_, engine);
+}
+
+inline static mkldnn::lrn_backward::primitive_desc GetLRNBwd(
+    const LRNParam &param, const mkldnn::memory::desc &diff_in_md,
+    const mkldnn::memory::desc &diff_md,
+    const lrn_forward::primitive_desc &lrnFwd_desc) {
+  auto engine = CpuEngine::Get()->get_engine();
+  auto alg_ = GetMKLDNNLRNAlgo(param);
+  auto alpha_ = param.alpha;
+  auto beta_ = param.beta;
+  int nsize_ = param.nsize;
+  auto k_ = param.knorm;
+
+  lrn_backward::desc lrnBwd_desc(alg_, diff_in_md,
+      diff_md, nsize_, alpha_, beta_, k_);
+  return mkldnn::lrn_backward::primitive_desc(lrnBwd_desc,
+      engine, lrnFwd_desc);
+}
+
+void MKLDNNLRN_Forward(const OpContext &ctx, const LRNParam &param,
+                       const NDArray &in_data, const OpReqType &req,
+                       const NDArray &out_data) {
+  auto src_mem = in_data.GetMKLDNNData();
+  auto src_md = src_mem->get_primitive_desc().desc();
+  auto pdesc = GetLRNFwd(param, ctx.is_train, src_md);
+  auto dst_mem = const_cast<NDArray &>(out_data).CreateMKLDNNData(
+      pdesc.dst_primitive_desc());
+  if (ctx.is_train) {
+    std::shared_ptr<mkldnn::memory> ws_mem(
+        new mkldnn::memory(pdesc.workspace_primitive_desc()));
+    MKLDNNStream::Get()->RegisterPrim(
+        lrn_forward(pdesc, mkldnn::primitive::at(*src_mem),
+                    *ws_mem, *dst_mem));
+    MKLDNNStream::Get()->Submit();
+  } else {
+    MKLDNNStream::Get()->RegisterPrim(
+        lrn_forward(pdesc, mkldnn::primitive::at(*src_mem), *dst_mem));
+    MKLDNNStream::Get()->Submit();
+  }
+}
+
+void MKLDNNLRN_Backward(const OpContext &ctx, const LRNParam &param,
+                        const NDArray &out_grad,
+                        const NDArray &in_data,
+                        const OpReqType &req,
+                        const NDArray &in_grad) {
+  if (req == kNullOp) {
+    return;
+  }
+  // Repeat the forward computation to get the workspace.
+  auto data_mem = in_data.GetMKLDNNData();
+  auto data_md = data_mem->get_primitive_desc().desc();
+  auto pdesc_fwd = GetLRNFwd(param, ctx.is_train, data_md);
+
+  // TODO(Patric): To keep the function stateless, we can't pass workspace
+  //               from LRN forward to backward. We have to re-compute
+  //               LRN forward to get the workspace.
+  //               Will refine this code later.
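Aside: the only LRN mode used here, lrn_across_channels, computes y[c] = x[c] * (k + (alpha / n) * sum of x[j]^2 over the n-channel window around c) ^ (-beta). A naive scalar sketch over one pixel's channel vector (assumed semantics, following the standard LRN definition; not the MKLDNN API):

#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> LRNAcrossChannels(const std::vector<float> &x, int nsize,
                                     float alpha, float beta, float k) {
  const int C = static_cast<int>(x.size());
  std::vector<float> y(C);
  for (int c = 0; c < C; ++c) {
    // Window of nsize channels centered on c, clipped at the edges.
    const int lo = std::max(0, c - nsize / 2);
    const int hi = std::min(C - 1, c + nsize / 2);
    float sum = 0.0f;
    for (int j = lo; j <= hi; ++j) sum += x[j] * x[j];
    y[c] = x[c] * std::pow(k + alpha / nsize * sum, -beta);
  }
  return y;
}
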
+ std::shared_ptr ws_mem( + new mkldnn::memory(pdesc_fwd.workspace_primitive_desc())); + std::shared_ptr dst_temp( + new mkldnn::memory(pdesc_fwd.dst_primitive_desc())); + MKLDNNStream::Get()->RegisterPrim( + lrn_forward(pdesc_fwd, mkldnn::primitive::at(*data_mem), + *ws_mem, *dst_temp)); + + auto data_in_md = pdesc_fwd.src_primitive_desc().desc(); + auto diff_mem = out_grad.GetMKLDNNData(); + auto diff_md = diff_mem->get_primitive_desc().desc(); + auto pdesc_bwd = GetLRNBwd(param, data_in_md, diff_md, pdesc_fwd); + auto diff_src_mem = CreateMKLDNNMem(in_grad, + pdesc_bwd.diff_src_primitive_desc(), req); + + MKLDNNStream::Get()->RegisterPrim( + lrn_backward(pdesc_bwd, mkldnn::primitive::at(*data_mem), + mkldnn::primitive::at(*diff_mem), *ws_mem, *diff_src_mem.second)); + MKLDNNStream::Get()->Submit(); +} +} // namespace op +} // namespace mxnet +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H__ diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h new file mode 100644 index 000000000000..9149cb0c6a94 --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_ops-inl.h + * \brief + * \author Da Zheng +*/ + +#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_OPS_INL_H_ +#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_OPS_INL_H_ + +#if MXNET_USE_MKLDNN == 1 + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mxnet { +namespace op { + +/* For fully connected. */ +void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); +void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + +/* For convolution. 
*/ +void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); +void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + +/* For deconvolution */ +void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); +void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + +/* For softmax */ +void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const NDArray &in_data, const OpReqType &req, + const NDArray &out_data); + +/* For sum */ +void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &inputs, const OpReqType &req, + const NDArray &out_data); + +/* For copy */ +void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const NDArray &in_data, const OpReqType &req, + const NDArray &out_data); + +/* For concat */ +void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); +void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + +/* For activation */ +void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const NDArray &in_data, const OpReqType &req, + const NDArray &out_data); +void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const NDArray &out_grad, const NDArray &in_data, + const OpReqType &req, const NDArray &in_grad); + +void Sum(const mkldnn::memory &arr1, const mkldnn::memory &arr2, + const mkldnn::memory &out); + +} // namespace op +} // namespace mxnet +#endif // MXNET_USE_MKLDNN == 1 + +#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_OPS_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h new file mode 100644 index 000000000000..6947f66ee424 --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file mkldnn_pooling-inl.h + * \brief +*/ +#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_POOLING_INL_H_ +#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_POOLING_INL_H_ + +#if MXNET_USE_MKLDNN == 1 + +#include +#include +#include "../pooling-inl.h" +#include "./mkldnn_base-inl.h" + +namespace mxnet { +namespace op { + +class MKLDNNPoolingFwd { + public: + MKLDNNPoolingFwd(const mxnet::NDArray &input, + const mxnet::NDArray &output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int padding_t, int padding_b, int padding_l, int padding_r, + mkldnn::algorithm alg_kind, + bool with_workspace, bool is_train) : + _is_train(is_train), + _with_workspace(with_workspace), + _alg_kind(alg_kind), + fwd(nullptr), data(nullptr), out(nullptr), workspace(nullptr) { + _Init(input, output, + kernel_h, kernel_w, stride_h, stride_w, + padding_t, padding_b, padding_l, padding_r); + } + + ~MKLDNNPoolingFwd() {} + void SetDataHandle(const mxnet::NDArray &data, + const mxnet::NDArray &output, + const mxnet::NDArray *workspace = nullptr); + void Execute(); + + private: + bool _is_train; + bool _with_workspace; + mkldnn::algorithm _alg_kind; + std::shared_ptr fwd_pd; + std::shared_ptr fwd; + std::shared_ptr data; + std::shared_ptr out; + std::shared_ptr workspace; + + private: + void _Init(const mxnet::NDArray &input, + const mxnet::NDArray &output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int padding_t, int padding_b, int padding_l, int padding_r); +}; + +void MKLDNNPoolingFwd::_Init(const mxnet::NDArray &input, const mxnet::NDArray &output, + int kernel_h, int kernel_w, int stride_h, int stride_w, + int padding_t, int padding_b, int padding_l, int padding_r) { + auto src_md = input.GetMKLDNNData()->get_primitive_desc().desc(); + mkldnn::memory::dims dims = {src_md.data.dims[0], + src_md.data.dims[1], + static_cast(output.shape()[2]), + static_cast(output.shape()[3])}; + auto dst_md = mkldnn::memory::desc({dims}, + static_cast(src_md.data.data_type), + static_cast(src_md.data.format)); + auto engine = CpuEngine::Get()->get_engine(); + auto alg_kind = this->_alg_kind; + if (alg_kind != pooling_max && + alg_kind != pooling_avg && + alg_kind != pooling_avg_include_padding && + alg_kind != pooling_avg_exclude_padding) { + LOG(FATAL) << "MKLDNN Pooling: algorithm is not supported"; + } + + auto prop = mkldnn::prop_kind::forward_scoring; + if (this->_is_train && alg_kind != mkldnn::algorithm::pooling_avg) { + prop = mkldnn::prop_kind::forward_training; + } + + if (this->_is_train && prop == mkldnn::prop_kind::forward_scoring) { + LOG(INFO) << "MKLDNN Pooling: training with prop_kind is forward_scoring"; + } + + mkldnn::memory::dims strides = {stride_h, stride_w }; + mkldnn::memory::dims pad_l = {padding_t, padding_l }; + mkldnn::memory::dims pad_r = {padding_b, padding_r }; + mkldnn::memory::dims kernel = {kernel_h, kernel_w }; + + auto fwd_desc = mkldnn::pooling_forward::desc(prop, alg_kind, src_md, dst_md, + strides, kernel, pad_l, pad_r, + mkldnn::padding_kind::zero); + this->fwd_pd.reset(new mkldnn::pooling_forward::primitive_desc(fwd_desc, engine)); + this->data.reset(new mkldnn::memory(input.GetMKLDNNData()->get_primitive_desc())); + this->out.reset(new mkldnn::memory(this->fwd_pd->dst_primitive_desc())); + if (this->_with_workspace) { + this->workspace.reset(new mkldnn::memory(this->fwd_pd->workspace_primitive_desc())); + this->fwd.reset(new mkldnn::pooling_forward(*(this->fwd_pd), + mkldnn::primitive::at(*(this->data)), + *(this->out), + *(this->workspace))); + } else { + 
this->fwd.reset(new mkldnn::pooling_forward(*(fwd_pd), + mkldnn::primitive::at(*(this->data)), + *(this->out))); + } + return; +} + +void MKLDNNPoolingFwd::SetDataHandle(const mxnet::NDArray &data, + const mxnet::NDArray &output, + const mxnet::NDArray *workspace) { + auto data_mem = data.GetMKLDNNData(); + auto out_mem = const_cast<NDArray &>(output).CreateMKLDNNData( + this->fwd_pd->dst_primitive_desc()); + this->data->set_data_handle(data_mem->get_data_handle()); + this->out->set_data_handle(out_mem->get_data_handle()); + if (this->_with_workspace && workspace == nullptr) { + LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input"; + } + + if (this->_with_workspace) { + // auto ws_mem = const_cast<NDArray *>(workspace)->CreateMKLDNNData( + // this->fwd_pd->workspace_primitive_desc()); + auto ws_mem = workspace->GetMKLDNNData(); + this->workspace->set_data_handle(ws_mem->get_data_handle()); + } +} + +void MKLDNNPoolingFwd::Execute() { + if (this->fwd) { + MKLDNNStream::Get()->RegisterPrim(*(this->fwd)); + MKLDNNStream::Get()->Submit(); + } else { + LOG(FATAL) << "MKLDNN Pooling: forward primitive is nullptr"; + } +} + +static inline bool SupportMKLDNNPooling(const PoolingParam &param) { + return param.kernel.ndim() == 2 + && (param.pool_type == pool_enum::kMaxPooling + || param.pool_type == pool_enum::kAvgPooling); +} + +static inline bool SupportMKLDNNPooling(const PoolingParam &param, + const TShape &dshape) { + auto ret = SupportMKLDNNPooling(param); + if (!ret) + return false; + if (param.pooling_convention == pool_enum::kValid) + return true; + if ((dshape[2] + 2 * param.pad[0] - param.kernel[0]) % param.stride[0] == 0 + && (dshape[3] + 2 * param.pad[1] - param.kernel[1]) % param.stride[1] == 0) + return true; + else + return false; +} + +static inline mkldnn::algorithm +GetMKLDNNPoolAlgo(const PoolingParam &param) { + switch (param.pool_type) { + case pool_enum::kMaxPooling: + return mkldnn::algorithm::pooling_max; + break; + case pool_enum::kAvgPooling: + return mkldnn::algorithm::pooling_avg_include_padding; + break; + default: + LOG(FATAL) << "MKLDNN Pooling: Unknown pooling method."; + return mkldnn::algorithm::pooling_max; + } +} + +inline static mkldnn::pooling_forward::primitive_desc +GetPoolingFwd(const PoolingParam &param, + bool is_train, + const memory::desc &data_md, + const memory::desc &out_md) { + CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented"; + int kernel_h_, kernel_w_; + if (param.global_pool) { + kernel_h_ = data_md.data.dims[2]; + kernel_w_ = data_md.data.dims[3]; + } else { + kernel_h_ = param.kernel[0]; + kernel_w_ = param.kernel[1]; + } + + CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero."; + CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero."; + + auto pad_t_ = param.pad[0], pad_b_ = param.pad[0]; + auto pad_l_ = param.pad[1], pad_r_ = param.pad[1]; + auto stride_h_ = param.stride[0], stride_w_ = param.stride[1]; + + auto engine = CpuEngine::Get()->get_engine(); + if (param.global_pool) { + CHECK(pad_t_ == 0 && pad_l_ == 0 && stride_h_ == 1 && stride_w_ == 1) + << "With Global_pooling: true; only pad = 0 and stride = 1"; + } + if (pad_t_ != 0 || pad_l_ != 0) { + CHECK(param.pool_type == pool_enum::kAvgPooling || + param.pool_type == pool_enum::kMaxPooling) + << "Padding implemented only for average and max pooling."; + CHECK_LT(pad_l_, kernel_w_); + CHECK_LT(pad_t_, kernel_h_); + } + + auto alg = GetMKLDNNPoolAlgo(param); + auto kind = prop_kind::forward_scoring; + if (is_train && alg != algorithm::pooling_avg) { + kind = prop_kind::forward_training; + } +
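+  // Review note, not in the original patch: forward_training is what makes
+  // MKL-DNN emit the workspace (the max-pooling indices) that the backward
+  // primitive consumes; pooling_avg never needs one, so forward_scoring is
+  // kept in that case.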
pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md, + {static_cast<int>(stride_h_), + static_cast<int>(stride_w_)}, + {kernel_h_, kernel_w_}, + {static_cast<int>(pad_t_), + static_cast<int>(pad_l_)}, + {static_cast<int>(pad_b_), + static_cast<int>(pad_r_)}, + padding_kind::zero); + return mkldnn::pooling_forward::primitive_desc(poolingFwd_desc, engine); +} + +inline bool MKLDNNRequireWorkspace(const PoolingParam &param) { + return param.pool_type != pool_enum::kAvgPooling; +} + +typedef MKLDNNParamOpSign<PoolingParam> MKLDNNPoolingSignature; + +static inline MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param, + bool is_train, + const NDArray &data, + const NDArray &output) { + static thread_local std::unordered_map<MKLDNNPoolingSignature, + MKLDNNPoolingFwd, + MKLDNNOpHash> pooling_fwds; + + bool with_workspace = is_train && MKLDNNRequireWorkspace(param); + MKLDNNPoolingSignature key(param); + key.AddSign(is_train); + key.AddSign(with_workspace); + key.AddSign(data); + key.AddSign(output); + + auto it = pooling_fwds.find(key); + if (it == pooling_fwds.end()) { + CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented"; + auto data_md = data.GetMKLDNNData()->get_primitive_desc().desc(); + int kernel_h_, kernel_w_; + if (param.global_pool) { + kernel_h_ = data_md.data.dims[2]; + kernel_w_ = data_md.data.dims[3]; + } else { + kernel_h_ = param.kernel[0]; + kernel_w_ = param.kernel[1]; + } + + CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero."; + CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero."; + + auto pad_t_ = param.pad[0], pad_b_ = param.pad[0]; + auto pad_l_ = param.pad[1], pad_r_ = param.pad[1]; + auto stride_h_ = param.stride[0], stride_w_ = param.stride[1]; + + if (param.global_pool) { + CHECK(pad_t_ == 0 && pad_l_ == 0 && stride_h_ == 1 && stride_w_ == 1) + << "With Global_pooling: true; only pad = 0 and stride = 1"; + } + + if (pad_t_ != 0 || pad_l_ != 0) { + CHECK(param.pool_type == pool_enum::kAvgPooling || + param.pool_type == pool_enum::kMaxPooling) + << "Padding implemented only for average and max pooling."; + CHECK_LT(pad_l_, kernel_w_); + CHECK_LT(pad_t_, kernel_h_); + } + + auto alg = GetMKLDNNPoolAlgo(param); + MKLDNNPoolingFwd fwd(data, output, kernel_h_, kernel_w_, stride_h_, stride_w_, + pad_t_, pad_b_, pad_l_, pad_r_, alg, with_workspace, is_train); + auto ins_ret = pooling_fwds.insert( + std::pair<MKLDNNPoolingSignature, MKLDNNPoolingFwd>(key, fwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + +void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam &param, + const NDArray &in_data, const OpReqType &req, + const NDArray &out_data, const NDArray *workspace) { + auto fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data); + fwd.SetDataHandle(in_data, out_data, workspace); + fwd.Execute(); +} + +void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param, + const NDArray &out_grad, const NDArray &in_data, + const NDArray *workspace, const OpReqType &req, + const NDArray &in_grad) { + if (req == kNullOp) { + return; + } + + TmpMemMgr::Get()->Init(ctx.requested[0]); + auto diff_dst_mem = out_grad.GetMKLDNNData(); + auto input_mem = in_data.GetMKLDNNData(); + mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc(); + mkldnn::memory::desc data_md = data_mpd.desc(); + memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1], + static_cast<int>(out_grad.shape()[2]), + static_cast<int>(out_grad.shape()[3])}; + memory::desc out_md({dims}, + static_cast<memory::data_type>(data_md.data.data_type), + static_cast<memory::format>(data_md.data.format)); + auto pdesc_fwd = GetPoolingFwd(param, ctx.is_train, data_md, out_md); + + mkldnn::memory::desc diff_md
= diff_dst_mem->get_primitive_desc().desc(); + memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1], + static_cast<int>(in_grad.shape()[2]), + static_cast<int>(in_grad.shape()[3])}; + memory::desc diff_in_md( + {dims1}, static_cast<memory::data_type>(diff_md.data.data_type), + static_cast<memory::format>(diff_md.data.format)); + auto cpu_engine = data_mpd.get_engine(); + + auto alg = GetMKLDNNPoolAlgo(param); + + int kernel_h_, kernel_w_; + if (param.global_pool) { + kernel_h_ = data_md.data.dims[2]; + kernel_w_ = data_md.data.dims[3]; + } else { + kernel_h_ = param.kernel[0]; + kernel_w_ = param.kernel[1]; + } + pooling_backward::desc desc(alg, diff_in_md, diff_md, + {static_cast<int>(param.stride[0]), + static_cast<int>(param.stride[1])}, + {kernel_h_, kernel_w_}, + {static_cast<int>(param.pad[0]), + static_cast<int>(param.pad[1])}, + {static_cast<int>(param.pad[0]), + static_cast<int>(param.pad[1])}, + padding_kind::zero); + pooling_backward::primitive_desc pdesc(desc, cpu_engine, pdesc_fwd); + + auto diff_src_mem = + CreateMKLDNNMem(in_grad, pdesc.diff_src_primitive_desc(), req); + + if (MKLDNNRequireWorkspace(param)) { + CHECK(workspace != nullptr); + auto workspace_mem = workspace->GetMKLDNNData(); + MKLDNNStream::Get()->RegisterPrim( + pooling_backward(pdesc, *diff_dst_mem, primitive::at(*workspace_mem), + *diff_src_mem.second)); + } else { + MKLDNNStream::Get()->RegisterPrim( + pooling_backward(pdesc, *diff_dst_mem, *diff_src_mem.second)); + } + CommitOutput(in_grad, diff_src_mem); + MKLDNNStream::Get()->Submit(); +} +} // namespace op +} // namespace mxnet +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_POOLING_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_softmax.cc b/src/operator/nn/mkldnn/mkldnn_softmax.cc new file mode 100644 index 000000000000..aa59f13d06da --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_softmax.cc @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_softmax.cc + * \brief + * \author Da Zheng +*/ + +#include "../softmax-inl.h" +#include "./mkldnn_ops-inl.h" +#include "./mkldnn_base-inl.h" + +#if MXNET_USE_MKLDNN == 1 +namespace mxnet { +namespace op { + +void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const NDArray &in_data, const OpReqType &req, + const NDArray &out_data) { + const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed); + auto input_mem = in_data.GetMKLDNNData(); + mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc(); + mkldnn::memory::desc data_md = data_mpd.desc(); + auto cpu_engine = data_mpd.get_engine(); + auto prop = ctx.is_train + ?
mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; + mkldnn::softmax_forward::desc desc = mkldnn::softmax_forward::desc(prop, + data_md, param.axis); + mkldnn::softmax_forward::primitive_desc pdesc(desc, cpu_engine); + + auto output_memory = out_data.GetMKLDNNData(); + MKLDNNStream *stream = MKLDNNStream::Get(); + stream->RegisterPrim(mkldnn::softmax_forward(pdesc, *input_mem, *output_memory)); + stream->Submit(); +} + +} // namespace op +} // namespace mxnet +#endif diff --git a/src/operator/nn/mkldnn/mkldnn_sum.cc b/src/operator/nn/mkldnn/mkldnn_sum.cc new file mode 100644 index 000000000000..1efc285b808f --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_sum.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_sum.cc + * \brief + * \author Da Zheng +*/ +#include <iostream> + +#include "./mkldnn_ops-inl.h" +#include "./mkldnn_base-inl.h" + +#if MXNET_USE_MKLDNN == 1 +namespace mxnet { +namespace op { + +void Sum(const mkldnn::memory &arr1, const mkldnn::memory &arr2, + const mkldnn::memory &out) { + std::vector<mkldnn::memory::primitive_desc> input_pds(2); + std::vector<float> scales(2); + std::vector<mkldnn::primitive::at> inputs; + input_pds[0] = arr1.get_primitive_desc(); + input_pds[1] = arr2.get_primitive_desc(); + CHECK(input_pds[0] == input_pds[1]); + scales[0] = 1; + scales[1] = 1; + inputs.push_back(arr1); + inputs.push_back(arr2); + // TODO(zhengda) I need to reorder memory here. + mkldnn::sum::primitive_desc sum_pd(scales, input_pds); + MKLDNNStream::Get()->RegisterPrim(mkldnn::sum(sum_pd, inputs, out)); +} + +void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector<NDArray> &inputs, const OpReqType &req, + const NDArray &out_data) { + TmpMemMgr::Get()->Init(ctx.requested[0]); + std::vector<mkldnn::primitive::at> in_prims; + std::vector<mkldnn::memory::primitive_desc> in_pds(inputs.size()); + std::vector<float> scales(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + auto in_mem = inputs[i].GetMKLDNNData(); + in_prims.push_back(*in_mem); + in_pds[i] = in_mem->get_primitive_desc(); + scales[i] = 1; + } + mkldnn::sum::primitive_desc pdesc(scales, in_pds); + + auto out_mem = CreateMKLDNNMem(out_data, pdesc.dst_primitive_desc(), req); + MKLDNNStream *stream = MKLDNNStream::Get(); + stream->RegisterPrim(mkldnn::sum(pdesc, in_prims, *out_mem.second)); + CommitOutput(out_data, out_mem); + stream->Submit(); +} + +} // namespace op +} // namespace mxnet +#endif diff --git a/src/operator/nn/pooling-inl.h b/src/operator/nn/pooling-inl.h index 3f511dfaacd9..9b0f18456e1d 100644 --- a/src/operator/nn/pooling-inl.h +++ b/src/operator/nn/pooling-inl.h @@ -78,8 +78,46 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> { DMLC_DECLARE_FIELD(pad).set_default(TShape()) .describe("Pad for pooling: (y, x) or (d, y, x).
Defaults to no padding."); } + + bool operator==(const PoolingParam& other) const { + return this->kernel == other.kernel && + this->stride == other.stride && + this->pad == other.pad && + this->pool_type == other.pool_type && + this->pooling_convention == other.pooling_convention && + this->global_pool == other.global_pool && + this->cudnn_off == other.cudnn_off; + } + +#if MXNET_USE_MKLDNN == 1 + static uint64_t ComputeHash(const TShape &shape) { + uint64_t hash = 0; + for (size_t i = 0; i < shape.ndim(); i++) + hash = hash * 2 + shape[i]; + return hash; + } + + uint64_t GetHash() const { + uint64_t hash = 0; + hash = hash * 2 + ComputeHash(kernel); + hash = hash * 2 + ComputeHash(stride); + hash = hash * 2 + ComputeHash(pad); + hash = hash * 2 + pool_type; + hash = hash * 2 + pooling_convention; + hash = hash * 2 + global_pool; + hash = hash * 2 + cudnn_off; + return hash; + } +#endif }; +/* + * When MKLDNN is enabled, we might want 2 outputs instead of one inputs, which + * also changes the number of inputs for backward. + */ +int GetNumOutputs(const PoolingParam ¶m); +int GetNumBackInputs(const PoolingParam ¶m); + template void PoolingForward(const OpContext& ctx, const PoolingParam ¶m, const TBlob& in_data, const OpReqType& req, @@ -122,9 +160,9 @@ void PoolingCompute(const nnvm::NodeAttrs& attrs, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - CHECK_EQ(inputs.size(), 1U); - CHECK_EQ(outputs.size(), 1U); const PoolingParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), GetNumOutputs(param)); MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { if (pool_enum::kMaxPooling == param.pool_type || pool_enum::kAvgPooling == param.pool_type @@ -142,16 +180,28 @@ void PoolingGradCompute(const nnvm::NodeAttrs& attrs, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - CHECK_EQ(inputs.size(), 3U); + const PoolingParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(inputs.size(), GetNumBackInputs(param)); CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); - const PoolingParam& param = nnvm::get(attrs.parsed); + off_t ograd_idx, in_data_idx, out_data_idx; + // When MKLDNN is enabled, the input data may contains arrays for workspace. 
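+  // Review note, not in the original patch: with an MKLDNN workspace the
+  // backward inputs are [ograd, ograd of workspace, in_data, out_data,
+  // workspace]; without it they are the usual [ograd, in_data, out_data].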
+ if (GetNumBackInputs(param) == 5) { + ograd_idx = 0; + in_data_idx = 2; + out_data_idx = 3; + } else { + ograd_idx = 0; + in_data_idx = 1; + out_data_idx = 2; + } MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { if (pool_enum::kMaxPooling == param.pool_type || pool_enum::kAvgPooling == param.pool_type || pool_enum::kSumPooling == param.pool_type) { - PoolingBackward<xpu, DType>(ctx, param, - inputs[0], inputs[1], inputs[2], req[0], outputs[0]); + PoolingBackward<xpu, DType>(ctx, param, inputs[ograd_idx], + inputs[in_data_idx], inputs[out_data_idx], + req[0], outputs[0]); } else { LOG(FATAL) << "unknown pooling type"; } diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 41ace3cecae0..5c5814dc7410 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -23,8 +23,8 @@ * \brief * \author Bing Xu, Jun Wu, Da Zheng */ -#include "./pooling-inl.h" #include "../elemwise_op_common.h" +#include "./pooling-inl.h" #if MXNET_USE_MKL2017 == 1 #include <mkl_memory.h> #include "../mkl/mkl_memory-inl.h" @@ -33,11 +33,14 @@ #if MXNET_USE_NNPACK == 1 #include "./nnpack/nnpack_pooling-inl.h" #endif // MXNET_USE_NNPACK +#if MXNET_USE_MKLDNN == 1 +#include "./mkldnn/mkldnn_pooling-inl.h" +#endif // MXNET_USE_MKLDNN namespace mxnet { namespace op { -static void PoolingParamParser(nnvm::NodeAttrs* attrs) { +static void PoolingParamParser(nnvm::NodeAttrs *attrs) { using namespace mshadow; PoolingParam param_; param_.Init(attrs->dict); @@ -48,115 +51,231 @@ static void PoolingParamParser(nnvm::NodeAttrs* attrs) { if (param_.stride.ndim() == 0) param_.stride = Shape2(1, 1); if (param_.pad.ndim() == 0) param_.pad = Shape2(0, 0); } else { - CHECK_EQ(param_.kernel.ndim(), 3U) << param_.kernel.ndim() << "D pooling not supported"; + CHECK_EQ(param_.kernel.ndim(), 3U) << param_.kernel.ndim() + << "D pooling not supported"; if (param_.stride.ndim() == 0) param_.stride = Shape3(1, 1, 1); if (param_.pad.ndim() == 0) param_.pad = Shape3(0, 0, 0); } CHECK_EQ(param_.stride.ndim(), param_.kernel.ndim()) - << "stride and kernel should have the same length"; + << "stride and kernel should have the same length"; CHECK_EQ(param_.pad.ndim(), param_.kernel.ndim()) - << "pad and kernel should have the same length"; + << "pad and kernel should have the same length"; attrs->parsed = std::move(param_); } -static bool PoolingShape(const nnvm::NodeAttrs& attrs, - std::vector<TShape> *in_shape, std::vector<TShape> *out_shape) { - const PoolingParam& param_ = nnvm::get<PoolingParam>(attrs.parsed); +int GetNumOutputs(const PoolingParam &param) { +#if MXNET_USE_MKLDNN == 1 + return MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param) ? 2 : 1; +#else + return 1; +#endif +} + +int GetNumBackInputs(const PoolingParam &param) { +#if MXNET_USE_MKLDNN == 1 + return MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param) ?
5 : 3; +#else + return 3; +#endif +} + +static bool PoolingType(const nnvm::NodeAttrs& attrs, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + out_attrs->at(0) = in_attrs->at(0); +#if MXNET_USE_MKLDNN == 1 + const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed); + if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param)) { + CHECK_GT(out_attrs->size(), 1U); + out_attrs->at(1) = mshadow::kInt32; + } +#endif + return true; +} + +static bool PoolingShape(const nnvm::NodeAttrs &attrs, + std::vector<TShape> *in_shape, + std::vector<TShape> *out_shape) { + const PoolingParam &param_ = nnvm::get<PoolingParam>(attrs.parsed); CHECK_EQ(in_shape->size(), 1U); const TShape &dshape = (*in_shape)[0]; - CHECK_GE(dshape.ndim(), 3U) << "Pooling: Input data should be 3D in (batch, channel, x)" << " Or 4D in (batch, channel, y, x) " << " Or 5D in (batch, channel, d, y, x)"; + CHECK_GE(dshape.ndim(), 3U) + << "Pooling: Input data should be 3D in (batch, channel, x)" + << " Or 4D in (batch, channel, y, x) " + << " Or 5D in (batch, channel, d, y, x)"; TShape oshape = dshape; - if (dshape.ndim() == 0) return false; + if (dshape.ndim() == 0) return false; if (param_.kernel.ndim() == 1) { - CHECK_EQ(dshape.ndim(), 3U) << "Pooling: Input data should be 3D in (batch, channel, x)"; + CHECK_EQ(dshape.ndim(), 3U) + << "Pooling: Input data should be 3D in (batch, channel, x)"; if (param_.global_pool) { oshape[2] = 1; } else { CHECK(param_.kernel[0] <= dshape[2] + 2 * param_.pad[0]) - << "kernel size (" << param_.kernel[0] << ") exceeds input (" << dshape[2] - << " padded to " << (dshape[2] + 2*param_.pad[0]) << ")"; + << "kernel size (" << param_.kernel[0] << ") exceeds input (" + << dshape[2] << " padded to " << (dshape[2] + 2 * param_.pad[0]) + << ")"; if (param_.pooling_convention == pool_enum::kValid) { - oshape[2] = 1 + (dshape[2] + 2 * param_.pad[0] - param_.kernel[0]) / - param_.stride[0]; + oshape[2] = 1 + + (dshape[2] + 2 * param_.pad[0] - param_.kernel[0]) / + param_.stride[0]; } else { - oshape[2] = 1 + static_cast<int>(ceil(static_cast<float>( - dshape[2] + 2 * param_.pad[0] - - param_.kernel[0]) / param_.stride[0])); + oshape[2] = 1 + static_cast<int>(ceil( + static_cast<float>(dshape[2] + 2 * param_.pad[0] - + param_.kernel[0]) / + param_.stride[0])); } } out_shape->clear(); out_shape->push_back(oshape); // save output shape +#if MXNET_USE_MKLDNN == 1 + if (MKLDNNRequireWorkspace(param_) && SupportMKLDNNPooling(param_)) + out_shape->push_back(oshape); // for workspace +#endif } else if (param_.kernel.ndim() == 2) { - CHECK_EQ(dshape.ndim(), 4U) << "Pooling: Input data should be 4D in (batch, channel, y, x)"; + CHECK_EQ(dshape.ndim(), 4U) + << "Pooling: Input data should be 4D in (batch, channel, y, x)"; if (param_.global_pool) { oshape[2] = 1; oshape[3] = 1; } else { CHECK(param_.kernel[0] <= dshape[2] + 2 * param_.pad[0]) - << "kernel size (" << param_.kernel[0] << ") exceeds input (" << dshape[2] - << " padded to " << (dshape[2] + 2*param_.pad[0]) << ")"; + << "kernel size (" << param_.kernel[0] << ") exceeds input (" + << dshape[2] << " padded to " << (dshape[2] + 2 * param_.pad[0]) + << ")"; CHECK(param_.kernel[1] <= dshape[3] + 2 * param_.pad[1]) - << "kernel size (" << param_.kernel[1] << ") exceeds input (" << dshape[3] - << " padded to " << (dshape[3] + 2*param_.pad[1]) << ")"; + << "kernel size (" << param_.kernel[1] << ") exceeds input (" + << dshape[3] << " padded to " << (dshape[3] + 2 * param_.pad[1]) + << ")"; if (param_.pooling_convention == pool_enum::kValid) { - oshape[2] = 1 + (dshape[2] + 2 * param_.pad[0] - param_.kernel[0]) / -
param_.stride[0]; - oshape[3] = 1 + (dshape[3] + 2 * param_.pad[1] - param_.kernel[1]) / - param_.stride[1]; + oshape[2] = 1 + + (dshape[2] + 2 * param_.pad[0] - param_.kernel[0]) / + param_.stride[0]; + oshape[3] = 1 + + (dshape[3] + 2 * param_.pad[1] - param_.kernel[1]) / + param_.stride[1]; } else { - oshape[2] = 1 + static_cast<int>(ceil(static_cast<float>( - dshape[2] + 2 * param_.pad[0] - - param_.kernel[0]) / param_.stride[0])); - oshape[3] = 1 + static_cast<int>(ceil(static_cast<float>( - dshape[3] + 2 * param_.pad[1] - - param_.kernel[1]) / param_.stride[1])); + oshape[2] = 1 + static_cast<int>(ceil( + static_cast<float>(dshape[2] + 2 * param_.pad[0] - + param_.kernel[0]) / + param_.stride[0])); + oshape[3] = 1 + static_cast<int>(ceil( + static_cast<float>(dshape[3] + 2 * param_.pad[1] - + param_.kernel[1]) / + param_.stride[1])); } } out_shape->clear(); out_shape->push_back(oshape); // save output shape +#if MXNET_USE_MKLDNN == 1 + if (MKLDNNRequireWorkspace(param_) && SupportMKLDNNPooling(param_)) + out_shape->push_back(oshape); // for workspace +#endif } else if (param_.kernel.ndim() == 3) { CHECK_EQ(dshape.ndim(), 5U) - << "Pooling: Input data should be 5D in (batch, channel, d, y, x)"; - CHECK_LE(param_.kernel[0], dshape[2] + 2 * param_.pad[0]) << "kernel size exceeds input"; - CHECK_LE(param_.kernel[1], dshape[3] + 2 * param_.pad[1]) << "kernel size exceeds input"; - CHECK_LE(param_.kernel[2], dshape[4] + 2 * param_.pad[2]) << "kernel size exceeds input"; + << "Pooling: Input data should be 5D in (batch, channel, d, y, x)"; + CHECK_LE(param_.kernel[0], dshape[2] + 2 * param_.pad[0]) + << "kernel size exceeds input"; + CHECK_LE(param_.kernel[1], dshape[3] + 2 * param_.pad[1]) + << "kernel size exceeds input"; + CHECK_LE(param_.kernel[2], dshape[4] + 2 * param_.pad[2]) + << "kernel size exceeds input"; if (param_.global_pool) { oshape[2] = 1; oshape[3] = 1; oshape[4] = 1; } else { if (param_.pooling_convention == pool_enum::kValid) { - oshape[2] = 1 + (dshape[2] + 2 * param_.pad[0] - param_.kernel[0]) / - param_.stride[0]; - oshape[3] = 1 + (dshape[3] + 2 * param_.pad[1] - param_.kernel[1]) / - param_.stride[1]; - oshape[4] = 1 + (dshape[4] + 2 * param_.pad[2] - param_.kernel[2]) / - param_.stride[2]; + oshape[2] = 1 + + (dshape[2] + 2 * param_.pad[0] - param_.kernel[0]) / + param_.stride[0]; + oshape[3] = 1 + + (dshape[3] + 2 * param_.pad[1] - param_.kernel[1]) / + param_.stride[1]; + oshape[4] = 1 + + (dshape[4] + 2 * param_.pad[2] - param_.kernel[2]) / + param_.stride[2]; } else { - oshape[2] = 1 + static_cast<int>(ceil(static_cast<float>( - dshape[2] + 2 * param_.pad[0] - - param_.kernel[0]) / param_.stride[0])); - oshape[3] = 1 + static_cast<int>(ceil(static_cast<float>( - dshape[3] + 2 * param_.pad[1] - - param_.kernel[1]) / param_.stride[1])); - oshape[4] = 1 + static_cast<int>(ceil(static_cast<float>( - dshape[4] + 2 * param_.pad[2] - - param_.kernel[2]) / param_.stride[2])); + oshape[2] = 1 + static_cast<int>(ceil( + static_cast<float>(dshape[2] + 2 * param_.pad[0] - + param_.kernel[0]) / + param_.stride[0])); + oshape[3] = 1 + static_cast<int>(ceil( + static_cast<float>(dshape[3] + 2 * param_.pad[1] - + param_.kernel[1]) / + param_.stride[1])); + oshape[4] = 1 + static_cast<int>(ceil( + static_cast<float>(dshape[4] + 2 * param_.pad[2] - + param_.kernel[2]) / + param_.stride[2])); } } out_shape->clear(); out_shape->push_back(oshape); // save output shape +#if MXNET_USE_MKLDNN == 1 + if (MKLDNNRequireWorkspace(param_) && SupportMKLDNNPooling(param_)) + out_shape->push_back(oshape); // for workspace +#endif } return true; } +#if MXNET_USE_MKLDNN == 1 +void PoolingComputeExCPU(const
nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector<NDArray> &inputs, + const std::vector<OpReqType> &req, + const std::vector<NDArray> &outputs) { + const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed); + const NDArray *workspace = nullptr; + if (MKLDNNRequireWorkspace(param)) { + CHECK_GT(outputs.size(), 1U); + workspace = &outputs[1]; + } + if (SupportMKLDNN(inputs[0]) + && SupportMKLDNNPooling(param, inputs[0].shape())) { + MKLDNNPoolingCompute(ctx, param, inputs[0], req[0], outputs[0], + workspace); + return; + } + FallBackCompute(PoolingCompute<cpu>, attrs, ctx, inputs, req, outputs); +} + +void PoolingGradComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector<NDArray> &inputs, + const std::vector<OpReqType> &req, + const std::vector<NDArray> &outputs) { + const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed); + const NDArray &out_grad = inputs[0]; + const NDArray *workspace = nullptr; + const NDArray *in_data = nullptr; + if (MKLDNNRequireWorkspace(param)) { + // The first two elements are the gradient of the outputs in forward. + // The third is the input of forward. + // The fourth and the fifth are the outputs of forward. + CHECK_EQ(inputs.size(), 5U); + in_data = &inputs[2]; + workspace = &inputs[4]; + } else { + CHECK_EQ(inputs.size(), 3U); + in_data = &inputs[1]; + } + const NDArray &in_grad = outputs[0]; + if (SupportMKLDNN(inputs[0]) + && SupportMKLDNNPooling(param, inputs[0].shape())) { + MKLDNNPoolingGradCompute(ctx, param, out_grad, *in_data, workspace, + req[0], in_grad); + return; + } + FallBackCompute(PoolingGradCompute<cpu>, attrs, ctx, inputs, req, outputs); +} +#endif + struct PoolingGrad { const char *op_name; - std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n, - const std::vector<nnvm::NodeEntry>& ograds) const { + std::vector<nnvm::NodeEntry> operator()( + const nnvm::NodePtr &n, + const std::vector<nnvm::NodeEntry> &ograds) const { std::vector<nnvm::NodeEntry> heads; heads.push_back(ograds[pool_enum::kOut]); heads.push_back(n->inputs[pool_enum::kData]); @@ -165,10 +284,53 @@ struct PoolingGrad { } }; +inline static bool PoolingStorageType(const nnvm::NodeAttrs &attrs, + const int dev_mask, + DispatchMode *dispatch_mode, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + CHECK_EQ(in_attrs->size(), 1); + + *dispatch_mode = DispatchMode::kFCompute; +#if MXNET_USE_MKLDNN == 1 + const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed); + if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNPooling(param)) { + *dispatch_mode = DispatchMode::kFComputeEx; + } +#else + CHECK_EQ(out_attrs->size(), 1); +#endif + for (size_t i = 0; i < out_attrs->size(); i++) + (*out_attrs)[i] = kDefaultStorage; + return true; +} + +inline static bool BackwardPoolingStorageType(const nnvm::NodeAttrs &attrs, + const int dev_mask, + DispatchMode *dispatch_mode, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed); + CHECK_EQ(in_attrs->size(), GetNumBackInputs(param)); + CHECK_EQ(out_attrs->size(), 1); + + *dispatch_mode = DispatchMode::kFCompute; +#if MXNET_USE_MKLDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNPooling(param)) { + *dispatch_mode = DispatchMode::kFComputeEx; + } +#else + CHECK_EQ(in_attrs->size(), 3); +#endif + for (size_t i = 0; i < out_attrs->size(); i++) + (*out_attrs)[i] = kDefaultStorage; + return true; +} + DMLC_REGISTER_PARAMETER(PoolingParam); NNVM_REGISTER_OP(Pooling) -.describe(R"code(Performs pooling on the input. + .describe(R"code(Performs pooling on the input. The shapes for 1-D pooling are
)code" ADD_FILELINE) .set_num_inputs(1) -.set_num_outputs(1) +.set_num_outputs([](const NodeAttrs& attrs) { + const PoolingParam ¶m = nnvm::get(attrs.parsed); + return GetNumOutputs(param); +}) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FNumVisibleOutputs", + [](const NodeAttrs& attrs) { return 1; }) +#endif +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output"}; +}) .set_attr_parser(PoolingParamParser) -.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferStorageType", PoolingStorageType) +.set_attr("FInferType", PoolingType) .set_attr("FInferShape", PoolingShape) .set_attr("FCompute", PoolingCompute) -.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_Pooling"}) -.add_argument("data", "NDArray-or-Symbol", "Input data to the pooling operator.") +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", PoolingComputeExCPU) +#endif +.set_attr("FGradient", + ElemwiseGradUseInOut{"_backward_Pooling"}) +.add_argument("data", "NDArray-or-Symbol", + "Input data to the pooling operator.") .add_arguments(PoolingParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_Pooling) .set_num_outputs(1) .set_attr("TIsBackward", true) -.set_attr("FInplaceOption", [](const NodeAttrs& attrs){ +.set_attr( + "FInplaceOption", + [](const NodeAttrs &attrs) { #if MXNET_USE_CUDNN == 1 return std::vector >(); #else return std::vector >{{1, 0}}; #endif }) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +#endif +.set_attr("FInferStorageType", + BackwardPoolingStorageType) .set_attr_parser(PoolingParamParser) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", PoolingGradComputeExCPU) +#endif .set_attr("FCompute", PoolingGradCompute); } // namespace op diff --git a/src/operator/nn/pooling.cu b/src/operator/nn/pooling.cu index de7dbf12606d..c3bcecfc77b7 100644 --- a/src/operator/nn/pooling.cu +++ b/src/operator/nn/pooling.cu @@ -51,9 +51,9 @@ void PoolingCompute(const nnvm::NodeAttrs& attrs, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - CHECK_EQ(inputs.size(), 1U); - CHECK_EQ(outputs.size(), 1U); const PoolingParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), GetNumOutputs(param)); #if MXNET_USE_CUDNN == 1 if (!param.cudnn_off && param.kernel.ndim() > 1) { @@ -88,10 +88,21 @@ void PoolingGradCompute(const nnvm::NodeAttrs& attrs, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - CHECK_EQ(inputs.size(), 3U); + const PoolingParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(inputs.size(), GetNumBackInputs(param)); CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); - const PoolingParam& param = nnvm::get(attrs.parsed); + off_t ograd_idx, in_data_idx, out_data_idx; + // When MKLDNN is enabled, the input data may contains arrays for workspace. 
+ if (GetNumBackInputs(param) == 5) { + ograd_idx = 0; + in_data_idx = 2; + out_data_idx = 3; + } else { + ograd_idx = 0; + in_data_idx = 1; + out_data_idx = 2; + } #if MXNET_USE_CUDNN == 1 if (!param.cudnn_off && param.kernel.ndim() > 1) { @@ -99,8 +110,8 @@ void PoolingGradCompute(const nnvm::NodeAttrs& attrs, switch (param.pool_type) { case pool_enum::kMaxPooling: case pool_enum::kAvgPooling: - GetCuDNNPoolingOp<DType>(param).Backward(ctx, - inputs[0], inputs[1], inputs[2], req[0], outputs[0]); + GetCuDNNPoolingOp<DType>(param).Backward(ctx, inputs[ograd_idx], + inputs[in_data_idx], inputs[out_data_idx], req[0], outputs[0]); return; case pool_enum::kSumPooling: LOG(WARNING) << "Sum pooling is not supported by cudnn, MXNet sum pooling is applied."; @@ -114,8 +125,8 @@ void PoolingGradCompute(const nnvm::NodeAttrs& attrs, if (pool_enum::kMaxPooling == param.pool_type || pool_enum::kAvgPooling == param.pool_type || pool_enum::kSumPooling == param.pool_type) { - PoolingBackward<gpu, DType>(ctx, param, inputs[0], - inputs[1], inputs[2], req[0], outputs[0]); + PoolingBackward<gpu, DType>(ctx, param, inputs[ograd_idx], + inputs[in_data_idx], inputs[out_data_idx], req[0], outputs[0]); } else { LOG(FATAL) << "unknown pooling type"; } diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc index 4686fb8c0dc1..3af4ef897daf 100644 --- a/src/operator/nn/softmax.cc +++ b/src/operator/nn/softmax.cc @@ -25,11 +25,50 @@ #include "./softmax-inl.h" #include "../tensor/elemwise_unary_op.h" #include "../tensor/elemwise_binary_op.h" +#include "mkldnn/mkldnn_base-inl.h" +#include "mkldnn/mkldnn_ops-inl.h" namespace mxnet { namespace op { DMLC_REGISTER_PARAMETER(SoftmaxParam); +#if MXNET_USE_MKLDNN == 1 +static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<NDArray>& inputs, + const std::vector<OpReqType>& req, + const std::vector<NDArray>& outputs) { + const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed); + // It seems MKLDNN softmax doesn't support training, + // and it only supports a non-negative axis. + if (SupportMKLDNN(inputs[0]) && !ctx.is_train && param.axis >= 0) { + MKLDNNSoftmaxForward(attrs, ctx, inputs[0], req[0], outputs[0]); + return; + } + FallBackCompute(SoftmaxCompute<cpu, mxnet_op::softmax_fwd>, attrs, ctx, + inputs, req, outputs); +} +#endif + +inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + CHECK_EQ(in_attrs->size(), 1); + CHECK_EQ(out_attrs->size(), 1); + +#if MXNET_USE_MKLDNN == 1 + // We only run MKLDNN op if it runs on CPU. + if (dev_mask == mshadow::cpu::kDevMask) + *dispatch_mode = DispatchMode::kFComputeEx; + else +#endif + *dispatch_mode = DispatchMode::kFCompute; + (*out_attrs)[0] = (*in_attrs)[0]; + return true; +} + MXNET_OPERATOR_REGISTER_UNARY(softmax) .describe(R"code(Applies the softmax function.
@@ -54,6 +93,10 @@ Example:: )code" ADD_FILELINE) .set_attr_parser(ParamParser<SoftmaxParam>) .set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::softmax_fwd>) +#if MXNET_USE_MKLDNN == 1 +.set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxComputeExCPU) +#endif +.set_attr<FInferStorageType>("FInferStorageType", SoftmaxStorageType) .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_softmax"}) .add_arguments(SoftmaxParam::__FIELDS__()); diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h index ebe19d41bbc4..1510766639bc 100644 --- a/src/operator/tensor/cast_storage-inl.h +++ b/src/operator/tensor/cast_storage-inl.h @@ -32,6 +32,9 @@ #ifdef __CUDACC__ #include "./cast_storage-inl.cuh" #endif // __CUDACC__ +#if MXNET_USE_MKLDNN == 1 +#include "../nn/mkldnn/mkldnn_base-inl.h" +#endif namespace mxnet { @@ -342,8 +345,16 @@ void CastStorageComputeImpl(const OpContext& ctx, } else if (src_stype == kCSRStorage && dst_stype == kDefaultStorage) { TBlob ret = output.data(); CastStorageCsrDnsImpl(ctx, input, &ret); +#if MXNET_USE_MKLDNN == 1 + } else if (src_stype == kDefaultStorage && dst_stype == kDefaultStorage) { + // In this case, one of the arrays must use a non-default layout. + CHECK(input.IsMKLDNN() || output.IsMKLDNN()); + auto in_mem = input.GetMKLDNNData(); + const_cast<NDArray &>(output).CopyFrom(*in_mem); + MKLDNNStream::Get()->Submit(); +#endif } else { - LOG(FATAL) << "Not implemented"; + LOG(FATAL) << "Not implemented from " << src_stype << " to " << dst_stype; } } diff --git a/src/operator/tensor/cast_storage.cc b/src/operator/tensor/cast_storage.cc index 9f257b140f7b..81abcc7dc955 100644 --- a/src/operator/tensor/cast_storage.cc +++ b/src/operator/tensor/cast_storage.cc @@ -25,10 +25,50 @@ #include "./cast_storage-inl.h" #include "../elemwise_op_common.h" #include "../tensor/elemwise_unary_op.h" +#include "../nn/mkldnn/mkldnn_base-inl.h" namespace mxnet { namespace op { +#if MXNET_USE_MKLDNN == 1 + +static inline int get_type_size(int dtype) { + MSHADOW_TYPE_SWITCH(dtype, DType, {return sizeof(DType);}); + return -1; +} + +void CastStorageMKLDnsImpl(const OpContext& ctx, const NDArray& src, const NDArray &dst_arr) { + TBlob dns = dst_arr.data(); + CHECK_EQ(ctx.run_ctx.ctx.dev_mask(), Context::kCPU); + CHECK(src.shape() == dns.shape_); + if (src.dtype() != dns.type_flag_) { + // If the input and output have different data types, we have to convert + // the source array into the default layout, cast the data type and copy + // data to the destination array. + const TBlob &src_blob = src.data(); + CHECK(src.ctx() == dst_arr.ctx()); + ndarray::Copy<cpu, cpu>(src.data(), &dns, src.ctx(), dst_arr.ctx(), ctx.run_ctx); + } else { + // This converts the source data to the default format and writes the data to + // the destination directly.
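+    // Review note, not in the original patch: a reorder primitive is only
+    // needed when the stored layout differs from the default format; otherwise
+    // the bytes can be copied verbatim, as the memcpy branch below does.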
+ std::vector<mkldnn::primitive> net; + auto src_mkldnn = src.GetMKLDNNData(); + auto src_pd = src_mkldnn->get_primitive_desc(); + auto def_format = GetDefaultFormat(src_pd.desc()); + if (def_format != src_pd.desc().data.format) { + auto dst_pd = GetPrimitiveDesc(src_pd, def_format); + mkldnn::memory dst_mkldnn(dst_pd, dns.dptr_); + net.push_back(mkldnn::reorder(*src_mkldnn, dst_mkldnn)); + mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait(); + } else { + const TBlob &src_blob = src.data(); + memcpy(dns.dptr_, src_blob.dptr_, src.shape().Size() * get_type_size(dns.type_flag_)); + } +} + +#endif + DMLC_REGISTER_PARAMETER(CastStorageParam); NNVM_REGISTER_OP(cast_storage) .add_alias("_sparse_cast_storage") diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index d7e5e04ce87a..93b8d4687453 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -24,11 +24,68 @@ */ #include "./elemwise_unary_op.h" #include "./elemwise_binary_op-inl.h" +#include "../nn/mkldnn/mkldnn_ops-inl.h" +#include "../nn/mkldnn/mkldnn_base-inl.h" namespace mxnet { namespace op { -MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(elemwise_add, op::mshadow_op::plus) +static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<NDArray>& inputs, + const std::vector<OpReqType>& req, + const std::vector<NDArray>& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); +#if MXNET_USE_MKLDNN == 1 + if (SupportMKLDNN(inputs[0]) && SupportMKLDNN(inputs[1])) { + MKLDNNSumForward(attrs, ctx, inputs, req[0], outputs[0]); + return; + } else if (inputs[0].storage_type() == kDefaultStorage + && inputs[1].storage_type() == kDefaultStorage) { + // This happens if inputs are supposed to be in MKLDNN format + // but MKLDNN doesn't support the data type or the shape. We're + // forced to convert it to the default format. + std::vector<TBlob> in_blobs(2); + std::vector<TBlob> out_blobs(1); + in_blobs[0] = inputs[0].data(); + in_blobs[1] = inputs[1].data(); + out_blobs[0] = outputs[0].data(); + ElemwiseBinaryOp::Compute<cpu, op::mshadow_op::plus>(attrs, ctx, in_blobs, + req, out_blobs); + return; + } +#endif + ElemwiseBinaryOp::ComputeEx<cpu, op::mshadow_op::plus>(attrs, ctx, inputs, + req, outputs); +} + +static inline bool ElemwiseAddStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + CHECK_EQ(in_attrs->size(), 2); + CHECK_EQ(out_attrs->size(), 1); + bool ret = ElemwiseStorageType<2, 1, true, true, true>(attrs, dev_mask, dispatch_mode, + in_attrs, out_attrs); +#if MXNET_USE_MKLDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask + && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) + && out_attrs->at(0) == kDefaultStorage) { + *dispatch_mode = DispatchMode::kFComputeEx; + } +#endif + return ret; +} + +MXNET_OPERATOR_REGISTER_BINARY(elemwise_add) +.set_attr<FInferStorageType>("FInferStorageType", ElemwiseAddStorageType) +.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, op::mshadow_op::plus>) +.set_attr<FComputeEx>("FComputeEx<cpu>", ElemwiseAddEx) +.set_attr<FResourceRequest>("FResourceRequest", /* For Sparse CSR */ + [](const NodeAttrs& attrs) { + return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};}) MXNET_ADD_SPARSE_OP_ALIAS(elemwise_add) .add_alias("_add").add_alias("_plus").add_alias("_Plus") .describe(R"code(Adds arguments element-wise. @@ -46,6 +103,41 @@ The storage type of ``elemwise_add`` output depends on storage types of inputs // this must differ from elemwise_add to prevent the "addto" optimization in the forward pass.
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_grad_add, op::mshadow_op::plus); +static void _backward_ElemwiseAddEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<NDArray>& inputs, + const std::vector<OpReqType>& req, + const std::vector<NDArray>& outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 2U); +#if MXNET_USE_MKLDNN == 1 + if (inputs[0].IsMKLDNN()) { + MKLDNNCopy(attrs, ctx, inputs[0], req[0], outputs[0]); + MKLDNNCopy(attrs, ctx, inputs[0], req[1], outputs[1]); + return; + } +#endif + ElemwiseBinaryOp::BackwardUseNoneEx<cpu, mshadow_op::identity, mshadow_op::identity>( + attrs, ctx, inputs, req, outputs); +} + +static inline bool _backward_ElemwiseAddStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + CHECK_EQ(in_attrs->size(), 1); + CHECK_EQ(out_attrs->size(), 2); + bool ret = ElemwiseStorageType<1, 2, true, true, true>(attrs, dev_mask, dispatch_mode, + in_attrs, out_attrs); +#if MXNET_USE_MKLDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask) { + *dispatch_mode = DispatchMode::kFComputeEx; + } +#endif + return ret; +} + NNVM_REGISTER_OP(_backward_add) .set_num_inputs(1) .set_num_outputs(2) @@ -55,13 +147,15 @@ NNVM_REGISTER_OP(_backward_add) return std::vector<std::pair<int, int> >{{0, 0}, {0, 1}}; }) +#if MXNET_USE_MKLDNN == 1 +.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) { + return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; +}) +#endif .set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::BackwardUseNone< cpu, mshadow_op::identity, mshadow_op::identity>) -.set_attr<FComputeEx>("FComputeEx<cpu>", - ElemwiseBinaryOp::BackwardUseNoneEx<cpu, mshadow_op::identity, mshadow_op::identity>) -.set_attr<FInferStorageType>("FInferStorageType", - ElemwiseStorageType<1, 2, true, true, true>); +.set_attr<FComputeEx>("FComputeEx<cpu>", _backward_ElemwiseAddEx) +.set_attr<FInferStorageType>("FInferStorageType", _backward_ElemwiseAddStorageType); MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(elemwise_sub, op::mshadow_op::minus) MXNET_ADD_SPARSE_OP_ALIAS(elemwise_sub) diff --git a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc index 67923790ee83..034377ba79ee 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc @@ -65,7 +65,7 @@ static bool BinaryScalarStorageTypeWithDenseResultStorageType(const NodeAttrs& a const auto dispatch_ex = invalid_ctx ?
DispatchMode::kFComputeFallback : DispatchMode::kFComputeEx; const double alpha = nnvm::get<double>(attrs.parsed); - if (instype == kDefaultStorage) { + if (common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) { dispatched = storage_type_assign(&out_attrs[0], kDefaultStorage, dispatch_mode, DispatchMode::kFCompute); } @@ -92,7 +92,7 @@ static bool BinaryScalarStorageType(const nnvm::NodeAttrs& attrs, const auto in_stype = in_attrs->at(0); auto &out_stype = out_attrs->at(0); bool dispatched = false; - if (!dispatched && in_stype == kDefaultStorage) { + if (!dispatched && (in_stype == kDefaultStorage)) { // dns -> dns dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute); diff --git a/src/operator/tensor/elemwise_sum.cc b/src/operator/tensor/elemwise_sum.cc index 041a0be00796..59a284aea6e1 100644 --- a/src/operator/tensor/elemwise_sum.cc +++ b/src/operator/tensor/elemwise_sum.cc @@ -24,6 +24,8 @@ */ #include "./elemwise_sum.h" #include "../../ndarray/ndarray_function.h" +#include "../nn/mkldnn/mkldnn_ops-inl.h" +#include "../../common/utils.h" namespace mxnet { namespace op { @@ -79,9 +81,28 @@ bool ElementWiseSumForwardInferStorageType(const nnvm::NodeAttrs& attrs, std::vector<int> *out_attrs) { CHECK(!in_attrs->empty()); CHECK_EQ(out_attrs->size(), 1U); - return ElemwiseStorageAttr(attrs, dev_mask, dispatch_mode, - in_attrs, out_attrs); + bool ret = ElemwiseStorageAttr(attrs, dev_mask, dispatch_mode, + in_attrs, out_attrs); +#if MXNET_USE_MKLDNN == 1 + // We should always use FComputeEx. + if (dev_mask == mshadow::cpu::kDevMask + && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) + && out_attrs->at(0) == kDefaultStorage) { + *dispatch_mode = DispatchMode::kFComputeEx; + } +#endif + return ret; +} + +#if MXNET_USE_MKLDNN == 1 +static inline bool IsMKLDNN(const std::vector<NDArray> &arrs) { + for (auto &arr : arrs) { + if (!arr.IsMKLDNN()) + return false; + } + return true; } +#endif void ElementWiseSumComputeExCPU(const nnvm::NodeAttrs& attrs, const OpContext& op_ctx, @@ -92,13 +113,30 @@ void ElementWiseSumComputeExCPU(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); if (req[0] == kNullOp) return; - CHECK_EQ(req[0], kWriteTo) << "ElementWiseSumComputeExCPU only supports req = kWriteTo"; if (inputs[0].storage_type() == kRowSparseStorage) { + CHECK_EQ(req[0], kWriteTo) + << "ElementWiseSumComputeExCPU only supports req = kWriteTo"; mshadow::Stream<cpu>* s = op_ctx.get_stream<cpu>(); Resource rsc = ResourceManager::Get()->Request(op_ctx.run_ctx.get_ctx(), ResourceRequest(ResourceRequest::kTempSpace)); NDArray out_nd = outputs[0]; mxnet::ndarray::ElementwiseSum(s, rsc, inputs, &out_nd); +#if MXNET_USE_MKLDNN == 1 + } else if (IsMKLDNN(inputs)) { + MKLDNNSumForward(attrs, op_ctx, inputs, req[0], outputs[0]); +#endif + } else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) { + // This case happens when we want to create an MKLDNN NDArray but the type + // or the shape isn't supported by MKLDNN. In this case, NDArray falls back + // to the default storage type and, thus, we have to handle the default + // storage in FComputeEx.
+ std::vector<TBlob> in_blobs(inputs.size()); + std::vector<TBlob> out_blobs(outputs.size()); + for (size_t i = 0; i < in_blobs.size(); i++) + in_blobs[i] = inputs[i].data(); + for (size_t i = 0; i < out_blobs.size(); i++) + out_blobs[i] = outputs[i].data(); + ElementWiseSumCompute<cpu>(attrs, op_ctx, in_blobs, req, out_blobs); } else { LOG(FATAL) << "Not implemented: " << operator_string(attrs, op_ctx, inputs, req, outputs); } diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index 079a33e87548..8b32199f313a 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -24,6 +24,7 @@ #include <mxnet/base.h> #include "elemwise_unary_op.h" #include "./elemwise_binary_op-inl.h" +#include "../nn/mkldnn/mkldnn_ops-inl.h" namespace mxnet { namespace op { @@ -108,12 +109,66 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sigmoid, unary_bwd<mshadow_op::sigmoid_grad>); // copy +static void CopyEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<NDArray>& inputs, + const std::vector<OpReqType>& req, + const std::vector<NDArray>& outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + const auto in_stype = inputs[0].storage_type(); + const auto out_stype = outputs[0].storage_type(); +#if MXNET_USE_MKLDNN == 1 + if (inputs[0].IsMKLDNN()) { + MKLDNNCopy(attrs, ctx, inputs[0], req[0], outputs[0]); + return; + } else if (in_stype == kDefaultStorage && out_stype == kDefaultStorage) { + // This happens if inputs are supposed to be in MKLDNN format + // but MKLDNN doesn't support the data type or the shape. We're + // forced to convert it to the default format. + std::vector<TBlob> in_blobs(1); + std::vector<TBlob> out_blobs(1); + in_blobs[0] = inputs[0].data(); + out_blobs[0] = outputs[0].data(); + UnaryOp::IdentityCompute<cpu>(attrs, ctx, in_blobs, req, out_blobs); + return; + } +#endif + UnaryOp::IdentityComputeEx<cpu>(attrs, ctx, inputs, req, outputs); +} + +static inline bool CopyStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + CHECK_EQ(in_attrs->size(), 1); + CHECK_EQ(out_attrs->size(), 1); + bool ret = ElemwiseStorageType<1, 1, false, true, true>(attrs, dev_mask, dispatch_mode, + in_attrs, out_attrs); +#if MXNET_USE_MKLDNN == 1 + // We have to make sure all inputs are in the default layout; otherwise we + // may need to fall back.
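+  // Review note, not in the original patch: MKLDNN arrays still report
+  // kDefaultStorage, so dispatching FComputeEx for default-default storage on
+  // CPU lets CopyEx reorder them, while sparse inputs keep the generic path.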
+ if (dev_mask == mshadow::cpu::kDevMask + && in_attrs->at(0) == kDefaultStorage + && out_attrs->at(0) == kDefaultStorage) { + *dispatch_mode = DispatchMode::kFComputeEx; + } +#endif + return ret; +} + MXNET_OPERATOR_REGISTER_UNARY(_copy) .MXNET_DESCRIBE("Returns a copy of the input.") .add_alias("identity") -.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1, false, true, true>) +.set_attr<FInferStorageType>("FInferStorageType", CopyStorageType) .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>) -.set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::IdentityComputeEx<cpu>) +.set_attr<FComputeEx>("FComputeEx<cpu>", CopyEx) +#if MXNET_USE_MKLDNN == 1 +.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) { + return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; +}) +#endif .set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity", [](const NodeAttrs& attrs){ return std::vector<bool>{true}; @@ -128,9 +183,14 @@ NNVM_REGISTER_OP(_backward_copy) [](const NodeAttrs& attrs){ return std::vector<std::pair<int, int> >{{0, 0}}; }) -.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1, false, true, true>) +.set_attr<FInferStorageType>("FInferStorageType", CopyStorageType) .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>) -.set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::IdentityComputeEx<cpu>) +.set_attr<FComputeEx>("FComputeEx<cpu>", CopyEx) +#if MXNET_USE_MKLDNN == 1 +.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) { + return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; +}) +#endif .set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity", [](const NodeAttrs& attrs){ return std::vector<bool>{true}; diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index e8fdce491484..4129113a6a75 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -25,6 +25,8 @@ // this will be invoked by gcc and compile CPU version #include "./matrix_op-inl.h" #include "./elemwise_unary_op.h" +#include "../nn/mkldnn/mkldnn_ops-inl.h" +#include "../nn/mkldnn/mkldnn_base-inl.h" namespace mxnet { namespace op { @@ -180,6 +182,51 @@ If the argument `reverse` is set to 1, then the special values are inferred from .add_argument("data", "NDArray-or-Symbol", "Input data to reshape.") .add_arguments(ReshapeParam::__FIELDS__()); +static void FlattenEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<NDArray>& inputs, + const std::vector<OpReqType>& req, + const std::vector<NDArray>& outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); +#if MXNET_USE_MKLDNN == 1 + const auto in_stype = inputs[0].storage_type(); + const auto out_stype = outputs[0].storage_type(); + if (inputs[0].IsMKLDNN()) { + MKLDNNCopy(attrs, ctx, inputs[0], req[0], outputs[0]); + // If the output is a special MKLDNN layout and the number of dimensions + // is larger than 2, we should use the default layout.
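+    // Review note, not in the original patch: a blocked MKLDNN layout for a
+    // 4-D activation cannot stand in for the flattened 2-D result, hence the
+    // reorder back to the default layout below.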
+ FallBackCompute(UnaryOp::IdentityCompute, attrs, ctx, inputs, req, outputs); + return; + } +#endif +} + +static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1); + CHECK_EQ(out_attrs->size(), 1); + bool ret = ElemwiseStorageType<1, 1, false, true, true>(attrs, dev_mask, dispatch_mode, + in_attrs, out_attrs); +#if MXNET_USE_MKLDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask + && in_attrs->at(0) == kDefaultStorage + && out_attrs->at(0) == kDefaultStorage) { + *dispatch_mode = DispatchMode::kFComputeEx; + } +#endif + return ret; +} NNVM_REGISTER_OP(Flatten) .add_alias("flatten") @@ -210,8 +257,15 @@ Example:: .set_num_outputs(1) .set_attr("FInferShape", FlattenShape) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferStorageType", FlattenStorageType) .set_attr("FGradient", ElemwiseGradUseNone{ "_backward_copy" }) .set_attr("FCompute", UnaryOp::IdentityCompute) +.set_attr("FComputeEx", FlattenEx) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +#endif .set_attr("FInplaceOption", [](const NodeAttrs& attrs) { return std::vector >{{0, 0}}; diff --git a/src/storage/cpu_device_storage.h b/src/storage/cpu_device_storage.h index f0dd61f01ac0..fcea4ef74cdf 100644 --- a/src/storage/cpu_device_storage.h +++ b/src/storage/cpu_device_storage.h @@ -54,7 +54,11 @@ class CPUDeviceStorage { /*! * \brief Alignment of allocation. */ +#if MXNET_USE_MKLDNN == 1 + static constexpr size_t alignment_ = 4096; +#else static constexpr size_t alignment_ = 16; +#endif }; // class CPUDeviceStorage inline void* CPUDeviceStorage::Alloc(size_t size) { diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h index 6a220bdad6d7..1074d0182198 100644 --- a/tests/cpp/include/test_core_op.h +++ b/tests/cpp/include/test_core_op.h @@ -314,7 +314,9 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer // Set up forward attrs_ = ParseAttrs(op_, args); - const int num_inputs = op_->num_inputs; + int num_inputs = op_->num_inputs; + if (op_->get_num_inputs) + num_inputs = op_->get_num_inputs(attrs_); if (!inputs.empty()) { CHECK_EQ(inputs.size(), static_cast(num_inputs)); diff --git a/tests/cpp/operator/activation_perf.cc b/tests/cpp/operator/activation_perf.cc index fe51be533510..e482848705ad 100644 --- a/tests/cpp/operator/activation_perf.cc +++ b/tests/cpp/operator/activation_perf.cc @@ -41,7 +41,7 @@ TEST(ACTIVATION_PERF, ExecuteBidirectional) { TShape shape({5, 5}); kwargs_t kwargs = basic_activation_args; kwargs.push_back({"act_type", "tanh"}); - test::op::CoreOpRunner runner; + test::op::LegacyOpRunner runner; runner.RunBidirectional(false, { shape }, kwargs, 1); } @@ -52,7 +52,7 @@ TEST(ACTIVATION_PERF, TimingCPU) { kwargs_t kwargs = basic_activation_args; // Which math function is arbitrary since it will have roughly constant timing among approaches kwargs.push_back({"act_type", "tanh"}); - test::op::CoreOpRunner runner; + test::op::LegacyOpRunner runner; runner.RunBidirectional(false, { TShape({10, 10, 10, 10}) }, kwargs, 1); // prime code and cache diff --git a/tests/cpp/operator/fully_conn_perf.cc b/tests/cpp/operator/fully_conn_perf.cc index 8c32e51e3161..c8d8021f6f6e 100644 --- a/tests/cpp/operator/fully_conn_perf.cc +++ b/tests/cpp/operator/fully_conn_perf.cc @@ -41,7 +41,7 @@ const kwargs_t basic_fullyconn_args = { {"num_hidden", 
"250"} }; TEST(FULLY_CONNECTED, ExecuteBidirectionalFullyConnected) { TShape shape({5, 5}); kwargs_t kwargs = basic_fullyconn_args; - test::op::CoreOpRunner runner; + test::op::LegacyOpRunner runner; runner.RunBidirectional(false, { shape }, kwargs, 1); } @@ -50,7 +50,7 @@ TEST(FULLY_CONNECTED, ExecuteBidirectionalFullyConnected) { */ TEST(FULLY_CONNECTED, FullyConnectedTimingCPU) { kwargs_t kwargs = basic_fullyconn_args; - test::op::CoreOpRunner runner; + test::op::LegacyOpRunner runner; runner.RunBidirectional(false, { TShape({10, 10, 10, 10}) }, kwargs, 1); // prime code and cache diff --git a/tests/python/unittest/test_executor.py b/tests/python/unittest/test_executor.py index 40a922cf5c2c..ad69ab53352c 100644 --- a/tests/python/unittest/test_executor.py +++ b/tests/python/unittest/test_executor.py @@ -140,22 +140,20 @@ def test_dot(): def test_reshape(): x = mx.sym.Variable('x') - y = mx.sym.FullyConnected(x, num_hidden=4) + y = mx.sym.Dropout(x, p=0.2) exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req='null') exe.arg_arrays[0][:] = 1 - exe.arg_arrays[1][:] = mx.nd.ones((4,4)) - exe.arg_arrays[2][:] = 0 new_exe = exe.reshape(x=(3,4)) new_exe.forward(is_train=False) # test sub exec forward - assert np.all(new_exe.outputs[0].asnumpy() == 4) + assert np.all(new_exe.outputs[0].asnumpy() == 1) # test shared memory - assert np.all(exe.outputs[0].asnumpy()[:3] == 4) + assert np.all(exe.outputs[0].asnumpy()[:3] == 1) # test base exec forward exe.forward(is_train=False) - assert np.all(exe.outputs[0].asnumpy() == 4) + assert np.all(exe.outputs[0].asnumpy() == 1) if __name__ == "__main__": import nose