From 1abdf93a99cab1d721e8b08d776ab8532daa56fd Mon Sep 17 00:00:00 2001 From: Yutian Li Date: Mon, 10 Aug 2015 01:00:03 +0800 Subject: [PATCH 01/21] [storage] storage backends --- .gitignore | 4 ++ Makefile | 4 +- include/mxnet/cuda_utils.h | 115 +++++++++++++++++++++++++++++++++++++ include/mxnet/storage.h | 25 ++++---- src/storage/cpu_storage.cc | 18 ++++++ src/storage/cpu_storage.h | 41 +++++++++++++ src/storage/gpu_storage.cc | 32 +++++++++++ src/storage/gpu_storage.h | 36 ++++++++++++ src/storage/storage.cc | 29 +++++----- 9 files changed, 280 insertions(+), 24 deletions(-) create mode 100644 include/mxnet/cuda_utils.h create mode 100644 src/storage/cpu_storage.cc create mode 100644 src/storage/cpu_storage.h create mode 100644 src/storage/gpu_storage.cc create mode 100644 src/storage/gpu_storage.h diff --git a/.gitignore b/.gitignore index 1c1cd99ab7a4..1aa2ce47da11 100644 --- a/.gitignore +++ b/.gitignore @@ -49,3 +49,7 @@ Debug *.swp *.swo *.swn + +# Emacs +.clang_complete +.dir-locals.el diff --git a/Makefile b/Makefile index c94c2705f886..e830c3eac7df 100644 --- a/Makefile +++ b/Makefile @@ -58,7 +58,7 @@ endif BIN = test/api_registry_test OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o +OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o cpu_storage.o gpu_storage.o CUOBJ = SLIB = lib/libmxnet.so ALIB = lib/libmxnet.a @@ -76,6 +76,8 @@ $(DMLC_CORE)/libdmlc.a: + cd $(DMLC_CORE); make libdmlc.a config=$(ROOTDIR)/$(config); cd $(ROOTDIR) storage.o: src/storage/storage.cc +cpu_storage.o: src/storage/cpu_storage.cc +gpu_storage.o: src/storage/gpu_storage.cc engine.o: src/dag_engine/simple_engine.cc #engine.o: src/dag_engine/threaded_engine.cc src/common/concurrent_blocking_queue.h src/common/spin_lock.h narray.o: src/narray/narray.cc diff --git a/include/mxnet/cuda_utils.h b/include/mxnet/cuda_utils.h new file mode 100644 index 000000000000..17c78aaf432c --- /dev/null +++ b/include/mxnet/cuda_utils.h @@ -0,0 +1,115 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file cuda_utils.h + * \brief CUDA debugging utilities. 
+ */ +#ifndef MXNET_CUDA_UTILS_H_ +#define MXNET_CUDA_UTILS_H_ + +#include + +#ifdef MXNET_USE_CUDA + +#include +#include +#include + +inline const char* CublasGetErrorString(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + default: + break; + } + return "Unknown cuBLAS status"; +} + +inline const char* CurandGetErrorString(curandStatus_t status) { + switch (status) { + case CURAND_STATUS_SUCCESS: + return "CURAND_STATUS_SUCCESS"; + case CURAND_STATUS_VERSION_MISMATCH: + return "CURAND_STATUS_VERSION_MISMATCH"; + case CURAND_STATUS_NOT_INITIALIZED: + return "CURAND_STATUS_NOT_INITIALIZED"; + case CURAND_STATUS_ALLOCATION_FAILED: + return "CURAND_STATUS_ALLOCATION_FAILED"; + case CURAND_STATUS_TYPE_ERROR: + return "CURAND_STATUS_TYPE_ERROR"; + case CURAND_STATUS_OUT_OF_RANGE: + return "CURAND_STATUS_OUT_OF_RANGE"; + case CURAND_STATUS_LENGTH_NOT_MULTIPLE: + return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; + case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: + return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; + case CURAND_STATUS_LAUNCH_FAILURE: + return "CURAND_STATUS_LAUNCH_FAILURE"; + case CURAND_STATUS_PREEXISTING_FAILURE: + return "CURAND_STATUS_PREEXISTING_FAILURE"; + case CURAND_STATUS_INITIALIZATION_FAILED: + return "CURAND_STATUS_INITIALIZATION_FAILED"; + case CURAND_STATUS_ARCH_MISMATCH: + return "CURAND_STATUS_ARCH_MISMATCH"; + case CURAND_STATUS_INTERNAL_ERROR: + return "CURAND_STATUS_INTERNAL_ERROR"; + } + return "Unknown cuRAND status"; +} + +#define CHECK_CUDA_ERROR(msg) \ + { \ + cudaError_t e = cudaGetLastError(); \ + CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \ + } + +#define CUDA_CALL(func) \ + { \ + cudaError_t e = (func); \ + CHECK_EQ(e, cudaSuccess) << "CUDA: " << cudaGetErrorString(e); \ + } + +#define CUBLAS_CALL(func) \ + { \ + cublasStatus_t e = (func); \ + CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \ + << "cuBLAS: " << CublasGetErrorString(e); \ + } + +#define CURAND_CALL(func) \ + { \ + curandStatus_t e = (func); \ + CHECK_EQ(e, CURAND_STATUS_SUCCESS) \ + << "cuRAND: " << CurandGetErrorString(e); \ + } + +#endif // MXNET_USE_CUDA + +#ifdef MXNET_USE_CUDNN + +#include + +#define CUDNN_CALL(func) \ + { \ + cudnnStatus_t e = (func); \ + CHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e); \ + } + +#endif // MXNET_USE_CUDNN + +#endif // MXNET_CUDA_UTILS_H_ diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index 3bb123b44816..a900537baff3 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -1,5 +1,5 @@ /*! - * Copyright (c) 2015 by Contributors + * Copyright (c) 2015 by Contributors * \file storage.h * \brief the memory allocator that manages the memory across multiple devices */ @@ -9,6 +9,7 @@ #include "./tensor_blob.h" namespace mxnet { + /*! 
\brief memory allocator of storage */ class StorageManager { public: @@ -17,14 +18,9 @@ class StorageManager { */ struct Handle { /*! \brief pointer to the data */ - void *dptr; + void* dptr; /*! \brief context information about device and deviceID */ Context ctx; - /*! - * \brief internal handle reserved for manager, - * user should not change or use this - */ - void *handle_; }; /*! * \brief allocate a new contiguous memory for a given size @@ -32,14 +28,23 @@ class StorageManager { * \param ctx context information about the device and deviceID * \return Handle struct */ - virtual Handle Alloc(size_t size, Context ctx) = 0; + Handle Alloc(size_t size, Context ctx); /*! * \brief free the space represened the handle * \param handle the handle to memory to be freed */ - virtual void Free(Handle handle) = 0; + void Free(Handle handle); /*! \return storage manager singleton */ - static StorageManager *Get(); + static StorageManager* Get(); + + private: + /*! + * \brief disabled constructors + */ + StorageManager() {} + DISALLOW_COPY_AND_ASSIGN(StorageManager); }; // class StorageManager + } // namespace mxnet + #endif // MXNET_STORAGE_H_ diff --git a/src/storage/cpu_storage.cc b/src/storage/cpu_storage.cc new file mode 100644 index 000000000000..b19e67ef2e8e --- /dev/null +++ b/src/storage/cpu_storage.cc @@ -0,0 +1,18 @@ +/*! + * Copyright (c) 2015 by Contributors + */ +#include "./cpu_storage.h" +#include +#include + +namespace mxnet { +namespace storage { + +void* CpuStorage::Alloc(size_t size) { + return CHECK_NOTNULL(aligned_alloc(alignment_, size)); +} + +void CpuStorage::Free(void* ptr) { free(ptr); } + +} // namespace storage +} // namespace mxnet diff --git a/src/storage/cpu_storage.h b/src/storage/cpu_storage.h new file mode 100644 index 000000000000..03aa0eee46ab --- /dev/null +++ b/src/storage/cpu_storage.h @@ -0,0 +1,41 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file cpu_storage.h + * \brief CPU storage implementation. + */ +#ifndef MXNET_STORAGE_CPU_STORAGE_H_ +#define MXNET_STORAGE_CPU_STORAGE_H_ + +#include "mxnet/base.h" + +namespace mxnet { +namespace storage { + +/*! + * \brief CPU storage implementation. + */ +class CpuStorage { + public: + /*! + * \brief Aligned allocation on CPU. + * \param size Size to allocate. + * \return Pointer to the storage. + */ + static void* Alloc(size_t size); + /*! + * \brief Deallocation. + * \param ptr Pointer to deallocate. + */ + static void Free(void* ptr); + + private: + /*! + * \brief Alignment of allocation. + */ + static constexpr size_t alignment_ = 16; +}; + +} // namespace storage +} // namespace mxnet + +#endif // MXNET_STORAGE_CPU_STORAGE_H_ diff --git a/src/storage/gpu_storage.cc b/src/storage/gpu_storage.cc new file mode 100644 index 000000000000..cc30d9038644 --- /dev/null +++ b/src/storage/gpu_storage.cc @@ -0,0 +1,32 @@ +/*! 
+ * Copyright (c) 2015 by Contributors + */ +#include "./gpu_storage.h" +#include "mxnet/cuda_utils.h" +#ifdef MXNET_USE_CUDA +#include +#endif // MXNET_USE_CUDA + +namespace mxnet { +namespace storage { + +void* GpuStorage::Alloc(size_t size) { + void* ret; +#ifdef MXNET_USE_CUDA + CUDA_CALL(cudaMalloc(&ret, size)); +#else // MXNET_USE_CUDA + LOG(FATAL) << "Please compile with CUDA enabled"; +#endif // MXNET_USE_CUDA + return ret; +} + +void GpuStorage::Free(void* ptr) { +#ifdef MXNET_USE_CUDA + CUDA_CALL(cudaFree(ptr)); +#else // MXNET_USE_CUDA + LOG(FATAL) << "Please compile with CUDA enabled"; +#endif // MXNET_USE_CUDA +} + +} // namespace storage +} // namespace mxnet diff --git a/src/storage/gpu_storage.h b/src/storage/gpu_storage.h new file mode 100644 index 000000000000..5834dd7b58ba --- /dev/null +++ b/src/storage/gpu_storage.h @@ -0,0 +1,36 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file gpu_storage.h + * \brief GPU storage implementation. + */ + +#ifndef MXNET_STORAGE_GPU_STORAGE_H_ +#define MXNET_STORAGE_GPU_STORAGE_H_ + +#include "mxnet/base.h" + +namespace mxnet { +namespace storage { + +/*! + * \brief GPU storage implementation. + */ +class GpuStorage { + public: + /*! + * \brief Allocation. + * \param size Size to allocate. + * \return Pointer to the storage. + */ + static void* Alloc(size_t size); + /*! + * \brief Deallocation. + * \param ptr Pointer to deallocate. + */ + static void Free(void* ptr); +}; + +} // namespace storage +} // namespace mxnet + +#endif // MXNET_STORAGE_GPU_STORAGE_H_ diff --git a/src/storage/storage.cc b/src/storage/storage.cc index 3b539551f02d..12e25e68ddaf 100644 --- a/src/storage/storage.cc +++ b/src/storage/storage.cc @@ -1,21 +1,21 @@ -// Copyright (c) 2015 by Contributors +/*! + * Copyright (c) 2015 by Contributors + */ #include #include namespace mxnet { -class NaiveStorageManager : public StorageManager { - public: - virtual Handle Alloc(size_t size, Context ctx); - virtual void Free(Handle handle); -}; -StorageManager::Handle -NaiveStorageManager::Alloc(size_t size, Context ctx) { +// class NaiveStorageManager : public StorageManager { +// public: +// virtual Handle Alloc(size_t size, Context ctx); +// virtual void Free(Handle handle); +// }; + +StorageManager::Handle StorageManager::Alloc(size_t size, Context ctx) { Handle hd; hd.ctx = ctx; - hd.handle_ = NULL; if (ctx.dev_mask == cpu::kDevMask) { hd.dptr = calloc(size, sizeof(real_t)); - // cudaMallocHost(&hd.dptr, size); } else { #if MXNET_USE_CUDA cudaMalloc(&hd.dptr, size); @@ -23,7 +23,8 @@ NaiveStorageManager::Alloc(size_t size, Context ctx) { } return hd; } -void NaiveStorageManager::Free(StorageManager::Handle handle) { + +void StorageManager::Free(StorageManager::Handle handle) { if (handle.ctx.dev_mask == cpu::kDevMask) { free(handle.dptr); handle.dptr = NULL; @@ -34,8 +35,10 @@ void NaiveStorageManager::Free(StorageManager::Handle handle) { #endif } } -StorageManager *StorageManager::Get() { - static NaiveStorageManager inst; + +StorageManager* StorageManager::Get() { + static StorageManager inst; return &inst; } + } // namespace mxnet From d22b0345af8ee9472764f8ff1adc1e2f6e6df0fd Mon Sep 17 00:00:00 2001 From: Yutian Li Date: Mon, 10 Aug 2015 01:53:11 +0800 Subject: [PATCH 02/21] [storage] naive storage manager --- src/storage/gpu_storage.h | 1 - src/storage/naive_storage_manager.h | 49 ++++++++++++++++++++++++++++ src/storage/pooled_storage_manager.h | 39 ++++++++++++++++++++++ src/storage/storage_manager.h | 39 ++++++++++++++++++++++ 4 files changed, 127 
insertions(+), 1 deletion(-)
 create mode 100644 src/storage/naive_storage_manager.h
 create mode 100644 src/storage/pooled_storage_manager.h
 create mode 100644 src/storage/storage_manager.h

diff --git a/src/storage/gpu_storage.h b/src/storage/gpu_storage.h
index 5834dd7b58ba..08a9befe41aa 100644
--- a/src/storage/gpu_storage.h
+++ b/src/storage/gpu_storage.h
@@ -3,7 +3,6 @@
  * \file gpu_storage.h
  * \brief GPU storage implementation.
  */
-
 #ifndef MXNET_STORAGE_GPU_STORAGE_H_
 #define MXNET_STORAGE_GPU_STORAGE_H_
 
diff --git a/src/storage/naive_storage_manager.h b/src/storage/naive_storage_manager.h
new file mode 100644
index 000000000000..c0cad9355262
--- /dev/null
+++ b/src/storage/naive_storage_manager.h
@@ -0,0 +1,49 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file naive_storage_manager.h
+ * \brief Naive storage manager.
+ */
+#ifndef MXNET_STORAGE_NAIVE_STORAGE_MANAGER_H_
+#define MXNET_STORAGE_NAIVE_STORAGE_MANAGER_H_
+
+#include "./storage_manager.h"
+#include "mxnet/base.h"
+
+namespace mxnet {
+namespace storage {
+
+/*!
+ * \brief Naive storage manager.
+ */
+template <class DeviceStorage>
+class NaiveStorageManager final : public StorageManager {
+ public:
+  /*!
+   * \brief Default constructor.
+   */
+  NaiveStorageManager() = default;
+  /*!
+   * \brief Default destructor.
+   */
+  ~NaiveStorageManager() = default;
+  void* Alloc(size_t size) override;
+  void Free(void* ptr) override;
+
+ private:
+  DISALLOW_COPY_AND_ASSIGN(NaiveStorageManager);
+};
+
+template <class DeviceStorage>
+void* NaiveStorageManager<DeviceStorage>::Alloc(size_t size) {
+  return DeviceStorage::Alloc(size);
+}
+
+template <class DeviceStorage>
+void NaiveStorageManager<DeviceStorage>::Free(void* ptr) {
+  DeviceStorage::Free(ptr);
+}
+
+}  // namespace storage
+}  // namespace mxnet
+
+#endif  // MXNET_STORAGE_NAIVE_STORAGE_MANAGER_H_
diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h
new file mode 100644
index 000000000000..6420aa86125b
--- /dev/null
+++ b/src/storage/pooled_storage_manager.h
@@ -0,0 +1,39 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file pooled_storage_manager.h
+ * \brief Storage manager with a memory pool.
+ */
+#ifndef MXNET_STORAGE_POOLED_STORAGE_MANAGER_H_
+#define MXNET_STORAGE_POOLED_STORAGE_MANAGER_H_
+
+#include "./storage_manager.h"
+#include "mxnet/base.h"
+
+namespace mxnet {
+namespace storage {
+
+/*!
+ * \brief Storage manager with a memory pool.
+ */
+template <class DeviceStorage>
+class PooledStorageManager final : public StorageManager {
+ public:
+  /*!
+   * \brief Default constructor.
+   */
+  PooledStorageManager() = default;
+  /*!
+   * \brief Default destructor.
+   */
+  ~PooledStorageManager() = default;
+  void* Alloc(size_t size) override;
+  void Free(void* ptr) override;
+
+ private:
+  DISALLOW_COPY_AND_ASSIGN(PooledStorageManager);
+};
+
+}  // namespace storage
+}  // namespace mxnet
+
+#endif  // MXNET_STORAGE_POOLED_STORAGE_MANAGER_H_
diff --git a/src/storage/storage_manager.h b/src/storage/storage_manager.h
new file mode 100644
index 000000000000..98129675ddf5
--- /dev/null
+++ b/src/storage/storage_manager.h
@@ -0,0 +1,39 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file storage_manager.h
+ * \brief Storage manager.
+ */
+#ifndef MXNET_STORAGE_STORAGE_MANAGER_H_
+#define MXNET_STORAGE_STORAGE_MANAGER_H_
+
+#include <cstddef>
+
+namespace mxnet {
+namespace storage {
+
+/*!
+ * \brief Storage manager interface.
+ */
+class StorageManager {
+ public:
+  /*!
+   * \brief Allocation.
+   * \param size Size to allocate.
+   * \return Pointer to the storage.
+   */
+  virtual void* Alloc(size_t size) = 0;
+  /*!
+   * \brief Deallocation.
+   * \param ptr Pointer to deallocate.
+   */
+  virtual void* Free(void* ptr) = 0;
+  /*!
+   * \brief Destructor.
+   */
+  virtual ~StorageManager() = default;
+};
+
+}  // namespace storage
+}  // namespace mxnet
+
+#endif  // MXNET_STORAGE_STORAGE_MANAGER_H_
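Note on the design introduced in this patch: the managers are policy templates.
StorageManager is the small virtual interface, while the DeviceStorage template
parameter supplies the raw Alloc/Free for one device type, so instantiations
such as NaiveStorageManager<CpuStorage> and NaiveStorageManager<GpuStorage> are
stamped out from the same code. A minimal sketch of how the pieces compose,
with HostStorage as a hypothetical stand-in for the CpuStorage/GpuStorage
policies from PATCH 01 (simplified, not code from the series):

#include <cstddef>
#include <cstdlib>

// Hypothetical policy: same static interface as CpuStorage/GpuStorage.
struct HostStorage {
  static void* Alloc(std::size_t size) { return std::malloc(size); }
  static void Free(void* ptr) { std::free(ptr); }
};

// Simplified NaiveStorageManager: forwards straight to the policy.
template <class DeviceStorage>
class NaiveStorageManager {
 public:
  void* Alloc(std::size_t size) { return DeviceStorage::Alloc(size); }
  void Free(void* ptr) { DeviceStorage::Free(ptr); }
};

int main() {
  NaiveStorageManager<HostStorage> manager;
  void* p = manager.Alloc(1024);  // resolved at compile time, no virtual call
  manager.Free(p);
  return 0;
}

Because the device policy is a template parameter, these forwarding calls cost
nothing at run time; the virtual StorageManager base becomes useful in PATCH
09, where managers for different devices are kept in one container.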
From c191d9beb012b1c5f5d487380184a5f88cfe8c82 Mon Sep 17 00:00:00 2001
From: Yutian Li
Date: Mon, 10 Aug 2015 23:06:24 +0800
Subject: [PATCH 03/21] [storage] enable c11 feature

---
 src/storage/cpu_storage.cc | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/storage/cpu_storage.cc b/src/storage/cpu_storage.cc
index b19e67ef2e8e..75447ee4039c 100644
--- a/src/storage/cpu_storage.cc
+++ b/src/storage/cpu_storage.cc
@@ -5,6 +5,10 @@
 #include <dmlc/logging.h>
 #include <cstdlib>
 
+#ifndef _ISOC11_SOURCE
+#define _ISOC11_SOURCE
+#endif  // _ISOC11_SOURCE
+
 namespace mxnet {
 namespace storage {
 
From 61ea32c6855058defbe60ba0091c946777702b62 Mon Sep 17 00:00:00 2001
From: Yutian Li
Date: Mon, 10 Aug 2015 23:10:03 +0800
Subject: [PATCH 04/21] [storage] move define to top

---
 src/storage/cpu_storage.cc | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/storage/cpu_storage.cc b/src/storage/cpu_storage.cc
index 75447ee4039c..315252de51d6 100644
--- a/src/storage/cpu_storage.cc
+++ b/src/storage/cpu_storage.cc
@@ -1,14 +1,11 @@
 /*!
  * Copyright (c) 2015 by Contributors
  */
+#define _ISOC11_SOURCE
 #include "./cpu_storage.h"
 #include <dmlc/logging.h>
 #include <cstdlib>
-#ifndef _ISOC11_SOURCE
-#define _ISOC11_SOURCE
-#endif  // _ISOC11_SOURCE
-
 namespace mxnet {
 namespace storage {
 
From a3cabcd5916c6cb41e2ae7bb7f40c58865584f96 Mon Sep 17 00:00:00 2001
From: Yutian Li
Date: Mon, 10 Aug 2015 23:17:46 +0800
Subject: [PATCH 05/21] [storage] try change header

---
 src/storage/cpu_storage.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/storage/cpu_storage.cc b/src/storage/cpu_storage.cc
index 315252de51d6..15631c64eecb 100644
--- a/src/storage/cpu_storage.cc
+++ b/src/storage/cpu_storage.cc
@@ -4,7 +4,7 @@
 #define _ISOC11_SOURCE
 #include "./cpu_storage.h"
 #include <dmlc/logging.h>
-#include <cstdlib>
+#include <stdlib.h>
 
 namespace mxnet {
 namespace storage {
 
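Note: patches 03 through 05 are all circling the same portability problem.
aligned_alloc was introduced in C11; glibc only declares it when a feature-test
macro such as _ISOC11_SOURCE is visible before the first standard header is
processed, and <cstdlib> in C++ mode may not declare ::aligned_alloc at all. A
minimal sketch of the ordering that PATCH 04 is trying to establish (assumes
glibc; not code from the series):

// The feature-test macro must come before any standard header.
#define _ISOC11_SOURCE
#include <stdlib.h>  // plain C header, more likely to expose aligned_alloc

int main() {
  void* p = aligned_alloc(16, 1024);  // size must be a multiple of alignment
  free(p);
  return p == nullptr ? 1 : 0;
}

PATCH 06 below gives up on aligned_alloc entirely and falls back to the
pre-C11 memalign from <malloc.h>, trading standard conformance for predictable
availability on the toolchains of the day.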
From b2bc69f988314fee2a2b02bf1ab501853191969c Mon Sep 17 00:00:00 2001
From: Yutian Li
Date: Mon, 10 Aug 2015 23:23:30 +0800
Subject: [PATCH 06/21] [storage] use old function

---
 src/storage/cpu_storage.cc | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/storage/cpu_storage.cc b/src/storage/cpu_storage.cc
index 15631c64eecb..e7d23843f207 100644
--- a/src/storage/cpu_storage.cc
+++ b/src/storage/cpu_storage.cc
@@ -1,16 +1,15 @@
 /*!
  * Copyright (c) 2015 by Contributors
  */
-#define _ISOC11_SOURCE
 #include "./cpu_storage.h"
 #include <dmlc/logging.h>
-#include <stdlib.h>
+#include <malloc.h>
 
 namespace mxnet {
 namespace storage {
 
 void* CpuStorage::Alloc(size_t size) {
-  return CHECK_NOTNULL(aligned_alloc(alignment_, size));
+  return CHECK_NOTNULL(memalign(alignment_, size));
 }

From 21674efec0dfb436b6d7c461d7474297b844793b Mon Sep 17 00:00:00 2001
From: Yutian Li
Date: Mon, 10 Aug 2015 23:32:17 +0800
Subject: [PATCH 07/21] [storage] use if

---
 include/mxnet/cuda_utils.h | 4 ++--
 src/storage/gpu_storage.cc | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/include/mxnet/cuda_utils.h b/include/mxnet/cuda_utils.h
index 17c78aaf432c..5329649f63bb 100644
--- a/include/mxnet/cuda_utils.h
+++ b/include/mxnet/cuda_utils.h
@@ -8,7 +8,7 @@
 #include <dmlc/logging.h>
 
-#ifdef MXNET_USE_CUDA
+#if MXNET_USE_CUDA
 
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <curand.h>
@@ -100,7 +100,7 @@ inline const char* CurandGetErrorString(curandStatus_t status) {
 
 #endif  // MXNET_USE_CUDA
 
-#ifdef MXNET_USE_CUDNN
+#if MXNET_USE_CUDNN
 
 #include <cudnn.h>
diff --git a/src/storage/gpu_storage.cc b/src/storage/gpu_storage.cc
index cc30d9038644..d0f2b83b2815 100644
--- a/src/storage/gpu_storage.cc
+++ b/src/storage/gpu_storage.cc
@@ -3,7 +3,7 @@
  */
 #include "./gpu_storage.h"
 #include "mxnet/cuda_utils.h"
-#ifdef MXNET_USE_CUDA
+#if MXNET_USE_CUDA
 #include <cuda_runtime.h>
 #endif  // MXNET_USE_CUDA
 
@@ -12,7 +12,7 @@ namespace storage {
 
 void* GpuStorage::Alloc(size_t size) {
   void* ret;
-#ifdef MXNET_USE_CUDA
+#if MXNET_USE_CUDA
   CUDA_CALL(cudaMalloc(&ret, size));
 #else  // MXNET_USE_CUDA
   LOG(FATAL) << "Please compile with CUDA enabled";
@@ -21,7 +21,7 @@ void* GpuStorage::Alloc(size_t size) {
 }
 
 void GpuStorage::Free(void* ptr) {
-#ifdef MXNET_USE_CUDA
+#if MXNET_USE_CUDA
   CUDA_CALL(cudaFree(ptr));
 #else  // MXNET_USE_CUDA
   LOG(FATAL) << "Please compile with CUDA enabled";
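Note: the #ifdef-to-#if switch above is not cosmetic. The build is expected to
define MXNET_USE_CUDA unconditionally, to either 0 or 1, so #ifdef (which only
tests definedness) would take the CUDA branch even in a CPU-only build, while
#if tests the macro's value. A self-contained illustration (the =0 definition
stands in for a CPU-only configuration):

#define MXNET_USE_CUDA 0

#ifdef MXNET_USE_CUDA
// Taken even though CUDA is disabled: #ifdef only checks definedness.
#endif

#if MXNET_USE_CUDA
// Correctly skipped: #if evaluates the macro's value, which is 0 here.
#endif

int main() { return 0; }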
@@ -27,12 +29,48 @@ class PooledStorageManager final : public StorageManager { */ ~PooledStorageManager() = default; void* Alloc(size_t size) override; - void Free(void* ptr) override; + void Free(void* ptr, size_t size) override; private: + void ReleaseAll(); + size_t used_memory_ = 0; + std::unordered_map> memory_pool_; DISALLOW_COPY_AND_ASSIGN(PooledStorageManager); }; +templace +void* PooledStorageManager::Alloc(size_t size) { + auto&& reuse_it = memory_pool_.find(size); + if (reuse_it == memory_pool_.end() || reuse_it->second.size() == 0) { + if (kThreshold <= used_memory_) { + ReleaseAll(); + } + used_memory_ += size; + return DeviceStorage::Alloc(size); + } else { + auto&& reuse_pool = reuse_it->second; + auto ret = reuse_pool.back(); + reuse_pool.pop_back(); + return ret; + } +} + +templace +void PooledStorageManager::Free(void* ptr, size_t size) { + auto&& reuse_pool = memory_pool_[size]; + reuse_pool.push_back(ptr); +} + +template +void PooledStorageManager::ReleaseAll() { + for (auto&& i : memory_pool_) { + for (auto&& j : i.second) { + DeviceStorage::Free(i.second); + used_memory_ -= i.first; + } + } +} + } // namespace storage } // namespace mxnet diff --git a/src/storage/storage_manager.h b/src/storage/storage_manager.h index 98129675ddf5..94967a210e80 100644 --- a/src/storage/storage_manager.h +++ b/src/storage/storage_manager.h @@ -25,8 +25,9 @@ class StorageManager { /*! * \brief Deallocation. * \param ptr Pointer to deallocate. + * \param size Size of the storage. */ - virtual void* Free(void* ptr) = 0; + virtual void* Free(void* ptr, size_t size) = 0; /*! * \brief Destructor. */ From 84bf795386613b53a333084f8c93db1475543859 Mon Sep 17 00:00:00 2001 From: Yutian Li Date: Tue, 11 Aug 2015 00:57:40 +0800 Subject: [PATCH 09/21] [storage] putting things together --- Makefile | 2 +- include/mxnet/narray.h | 10 +-- include/mxnet/storage.h | 56 +++++++++++----- src/storage/cpu_storage.h | 2 +- src/storage/gpu_storage.h | 2 +- src/storage/naive_storage_manager.h | 6 +- src/storage/pooled_storage_manager.h | 13 ++-- src/storage/storage.cc | 99 +++++++++++++++++++++------- src/storage/storage_manager.h | 4 +- 9 files changed, 132 insertions(+), 62 deletions(-) diff --git a/Makefile b/Makefile index e830c3eac7df..5ddb285f8c8e 100644 --- a/Makefile +++ b/Makefile @@ -58,7 +58,7 @@ endif BIN = test/api_registry_test OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o cpu_storage.o gpu_storage.o +OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o cpu_storage.o gpu_storage.o storage.o CUOBJ = SLIB = lib/libmxnet.so ALIB = lib/libmxnet.a diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index 4e7b4448e667..9a73870c003b 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -126,14 +126,14 @@ class NArray { /*! \brief the real data chunk that backs NArray */ struct Chunk { /*! \brief storage handlefrom storage engine */ - StorageManager::Handle shandle; + Storage::Handle shandle; /*! \brief variable from DAG engine */ DAGEngine::Variable var; /*! \brief holds the data content */ TBlob data; /*! * \brief if this is true, this means the data do not come - * from StorageManager, and do not need to be freed + * from Storage, and do not need to be freed */ bool static_data; /*! \brief whether allocation is delayed */ @@ -161,7 +161,7 @@ class NArray { /*! 
\brief check if delay alloc is on, do alloc if not yet done */ inline void CheckAndAlloc(void) { if (delay_alloc) { - shandle = StorageManager::Get()->Alloc(data.shape_.Size() * sizeof(real_t), shandle.ctx); + shandle = Storage::Get()->Alloc(data.shape_.Size() * sizeof(real_t), shandle.ctx); data = TBlob(static_cast(shandle.dptr), data.shape_, shandle.ctx.dev_mask); delay_alloc = false; } @@ -172,9 +172,9 @@ class NArray { DAGEngine::Get()->PushDelete([](RunContext s) {}, shandle.ctx, var); } else { CHECK(!delay_alloc) << "deleted before allocation"; - StorageManager::Handle h = this->shandle; + Storage::Handle h = this->shandle; DAGEngine::Get()->PushDelete([h](RunContext s) { - StorageManager::Get()->Free(h); + Storage::Get()->Free(h); }, shandle.ctx, var); } } diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index a900537baff3..d9f4a9f17de0 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -1,49 +1,69 @@ /*! * Copyright (c) 2015 by Contributors * \file storage.h - * \brief the memory allocator that manages the memory across multiple devices + * \brief Storage manager across multiple devices. */ #ifndef MXNET_STORAGE_H_ #define MXNET_STORAGE_H_ + +#include #include "./base.h" #include "./tensor_blob.h" namespace mxnet { -/*! \brief memory allocator of storage */ -class StorageManager { +/*! + * \brief Storage manager across multiple devices. + */ +class Storage { public: /*! - * \brief storage handle the represents storage information + * \brief Storage handle. */ struct Handle { - /*! \brief pointer to the data */ + /*! + * \brief Pointer to the data. + */ void* dptr; - /*! \brief context information about device and deviceID */ + /*! + * \brief Size of the storage. + */ + size_t size; + /*! + * \brief Context information about device and ID. + */ Context ctx; }; /*! - * \brief allocate a new contiguous memory for a given size - * \param size the total size of memory in bytes - * \param ctx context information about the device and deviceID - * \return Handle struct + * \brief Allocate a new contiguous memory for a given size. + * \param size Total size of memory in bytes. + * \param ctx Context information about the device and ID. + * \return Handle struct. */ Handle Alloc(size_t size, Context ctx); /*! - * \brief free the space represened the handle - * \param handle the handle to memory to be freed + * \brief Free storage. + * \param handle Handle struect. */ void Free(Handle handle); - /*! \return storage manager singleton */ - static StorageManager* Get(); + /*! + * \brief Destructor. + */ + ~Storage(); + /*! + * \return Storage singleton. + */ + static Storage* Get(); private: /*! - * \brief disabled constructors + * \brief Hidden constructors. */ - StorageManager() {} - DISALLOW_COPY_AND_ASSIGN(StorageManager); -}; // class StorageManager + Storage(); + struct Impl; + std::unique_ptr impl_; + DISALLOW_COPY_AND_ASSIGN(Storage); +}; // class Storage } // namespace mxnet diff --git a/src/storage/cpu_storage.h b/src/storage/cpu_storage.h index 03aa0eee46ab..3c031959d9be 100644 --- a/src/storage/cpu_storage.h +++ b/src/storage/cpu_storage.h @@ -33,7 +33,7 @@ class CpuStorage { * \brief Alignment of allocation. 
*/ static constexpr size_t alignment_ = 16; -}; +}; // class CpuStorage } // namespace storage } // namespace mxnet diff --git a/src/storage/gpu_storage.h b/src/storage/gpu_storage.h index 08a9befe41aa..5e0293f2edd6 100644 --- a/src/storage/gpu_storage.h +++ b/src/storage/gpu_storage.h @@ -27,7 +27,7 @@ class GpuStorage { * \param ptr Pointer to deallocate. */ static void Free(void* ptr); -}; +}; // class GpuStorage } // namespace storage } // namespace mxnet diff --git a/src/storage/naive_storage_manager.h b/src/storage/naive_storage_manager.h index d145cada3f43..3611b775ce74 100644 --- a/src/storage/naive_storage_manager.h +++ b/src/storage/naive_storage_manager.h @@ -31,15 +31,15 @@ class NaiveStorageManager final : public StorageManager { private: DISALLOW_COPY_AND_ASSIGN(NaiveStorageManager); -}; +}; // class NaiveStorageManager template -void* NaiveStorageManager::Alloc(size_t size) { +void* NaiveStorageManager::Alloc(size_t size) { return DeviceStorage::Alloc(size); } template -void NaiveStorageManager::Free(void* ptr) { +void NaiveStorageManager::Free(void* ptr, size_t) { DeviceStorage::Free(ptr); } diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h index fb168391dc5e..b3a45b895985 100644 --- a/src/storage/pooled_storage_manager.h +++ b/src/storage/pooled_storage_manager.h @@ -36,14 +36,14 @@ class PooledStorageManager final : public StorageManager { size_t used_memory_ = 0; std::unordered_map> memory_pool_; DISALLOW_COPY_AND_ASSIGN(PooledStorageManager); -}; +}; // class PooledStorageManager -templace +template void* PooledStorageManager::Alloc(size_t size) { auto&& reuse_it = memory_pool_.find(size); if (reuse_it == memory_pool_.end() || reuse_it->second.size() == 0) { if (kThreshold <= used_memory_) { - ReleaseAll(); + ReleaseAll(); } used_memory_ += size; return DeviceStorage::Alloc(size); @@ -55,8 +55,9 @@ void* PooledStorageManager::Alloc(size_t size) { } } -templace -void PooledStorageManager::Free(void* ptr, size_t size) { +template +void PooledStorageManager::Free(void* ptr, + size_t size) { auto&& reuse_pool = memory_pool_[size]; reuse_pool.push_back(ptr); } @@ -65,7 +66,7 @@ template void PooledStorageManager::ReleaseAll() { for (auto&& i : memory_pool_) { for (auto&& j : i.second) { - DeviceStorage::Free(i.second); + DeviceStorage::Free(j); used_memory_ -= i.first; } } diff --git a/src/storage/storage.cc b/src/storage/storage.cc index 12e25e68ddaf..a5bfca9c9046 100644 --- a/src/storage/storage.cc +++ b/src/storage/storage.cc @@ -1,44 +1,93 @@ /*! 
* Copyright (c) 2015 by Contributors */ +#include "mxnet/storage.h" #include -#include +#include +#include "./storage_manager.h" +#include "./naive_storage_manager.h" +#include "./pooled_storage_manager.h" +#include "./cpu_storage.h" +#include "./gpu_storage.h" +#include "mxnet/cuda_utils.h" + namespace mxnet { -// class NaiveStorageManager : public StorageManager { -// public: -// virtual Handle Alloc(size_t size, Context ctx); -// virtual void Free(Handle handle); -// }; +struct Storage::Impl { + static constexpr size_t kPoolThreshold = 4096 * 1024 * 1024ul; + + template + using CurrentStorageManager = storage::PooledStorageManager; + + static void ActivateDevice(Context ctx) { + switch (ctx.dev_mask) { + case cpu::kDevMask: + break; + case gpu::kDevMask: +#if MXNET_USE_CUDA + CUDA_CALL(cudaSetDevice(ctx.dev_id)); +#else // MXNET_USE_CUDA + LOG(FATAL) << "Please compile with CUDA enabled"; +#endif // MXNET_USE_CUDA + break; + default: + LOG(FATAL) << "Unimplemented device"; + } + } + + std::unordered_map< + int, std::unordered_map>> + storage_managers; +}; // struct Storage::Impl -StorageManager::Handle StorageManager::Alloc(size_t size, Context ctx) { +Storage::Handle Storage::Alloc(size_t size, Context ctx) { Handle hd; hd.ctx = ctx; - if (ctx.dev_mask == cpu::kDevMask) { - hd.dptr = calloc(size, sizeof(real_t)); - } else { -#if MXNET_USE_CUDA - cudaMalloc(&hd.dptr, size); -#endif + auto&& device = impl_->storage_managers[ctx.dev_mask]; + auto&& device_id_it = device.find(ctx.dev_id); + // Allocate device if necessary. + if (device_id_it == device.end()) { + switch (ctx.dev_mask) { + case cpu::kDevMask: + device_id_it = + device.emplace(std::make_pair( + ctx.dev_id, + std::unique_ptr{ + new Storage::Impl::CurrentStorageManager< + storage::CpuStorage>{}})).first; + break; + case gpu::kDevMask: + device_id_it = + device.emplace(std::make_pair( + ctx.dev_id, + std::unique_ptr{ + new Storage::Impl::CurrentStorageManager< + storage::GpuStorage>{}})).first; + break; + default: + LOG(FATAL) << "Unimplemented device"; + } } + Impl::ActivateDevice(ctx); + hd.dptr = device_id_it->second->Alloc(size); + hd.size = size; return hd; } -void StorageManager::Free(StorageManager::Handle handle) { - if (handle.ctx.dev_mask == cpu::kDevMask) { - free(handle.dptr); - handle.dptr = NULL; - // cudaFreeHost(handle.dptr); - } else { -#if MXNET_USE_CUDA - cudaFree(handle.dptr); -#endif - } +void Storage::Free(Storage::Handle handle) { + Impl::ActivateDevice(handle.ctx); + impl_->storage_managers.at(handle.ctx.dev_mask) + .at(handle.ctx.dev_id) + ->Free(handle.dptr, handle.size); } -StorageManager* StorageManager::Get() { - static StorageManager inst; +Storage::~Storage() = default; + +Storage* Storage::Get() { + static Storage inst; return &inst; } +Storage::Storage() : impl_{new Impl{}} {} + } // namespace mxnet diff --git a/src/storage/storage_manager.h b/src/storage/storage_manager.h index 94967a210e80..3d264ab278ca 100644 --- a/src/storage/storage_manager.h +++ b/src/storage/storage_manager.h @@ -27,12 +27,12 @@ class StorageManager { * \param ptr Pointer to deallocate. * \param size Size of the storage. */ - virtual void* Free(void* ptr, size_t size) = 0; + virtual void Free(void* ptr, size_t size) = 0; /*! * \brief Destructor. 
*/ virtual ~StorageManager() = default; -}; +}; // namespace StorageManager } // namespace storage } // namespace mxnet From 1ad28eb3a1f4da937f0e432c3b310e3299ac5f3e Mon Sep 17 00:00:00 2001 From: Yutian Li Date: Tue, 11 Aug 2015 01:01:43 +0800 Subject: [PATCH 10/21] [storage] fix Makefile --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 5ddb285f8c8e..f5c851b0d4fd 100644 --- a/Makefile +++ b/Makefile @@ -56,7 +56,7 @@ endif #BIN = test/test_threaded_engine test/api_registry_test BIN = test/api_registry_test -OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o +OBJ = narray_op_cpu.o static_operator.o static_operator_cpu.o # add threaded engine after it is done OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o cpu_storage.o gpu_storage.o storage.o CUOBJ = From 36cde6073ee268de0616d193c4cbed1170d6119d Mon Sep 17 00:00:00 2001 From: Yutian Li Date: Sun, 16 Aug 2015 21:39:21 +0800 Subject: [PATCH 11/21] [storage] refactor things a bit --- Makefile | 4 +- {include/mxnet => src/common}/cuda_utils.h | 32 ++++++----- src/common/utils.h | 56 ++++++++++++++++++ .../{cpu_storage.h => cpu_device_storage.h} | 24 +++++--- src/storage/cpu_storage.cc | 18 ------ src/storage/gpu_device_storage.h | 57 +++++++++++++++++++ src/storage/gpu_storage.cc | 32 ----------- src/storage/gpu_storage.h | 35 ------------ src/storage/storage.cc | 46 +++++++-------- 9 files changed, 172 insertions(+), 132 deletions(-) rename {include/mxnet => src/common}/cuda_utils.h (79%) create mode 100644 src/common/utils.h rename src/storage/{cpu_storage.h => cpu_device_storage.h} (51%) delete mode 100644 src/storage/cpu_storage.cc create mode 100644 src/storage/gpu_device_storage.h delete mode 100644 src/storage/gpu_storage.cc delete mode 100644 src/storage/gpu_storage.h diff --git a/Makefile b/Makefile index f5c851b0d4fd..1ac4c583a384 100644 --- a/Makefile +++ b/Makefile @@ -58,7 +58,7 @@ endif BIN = test/api_registry_test OBJ = narray_op_cpu.o static_operator.o static_operator_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o cpu_storage.o gpu_storage.o storage.o +OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o storage.o CUOBJ = SLIB = lib/libmxnet.so ALIB = lib/libmxnet.a @@ -76,8 +76,6 @@ $(DMLC_CORE)/libdmlc.a: + cd $(DMLC_CORE); make libdmlc.a config=$(ROOTDIR)/$(config); cd $(ROOTDIR) storage.o: src/storage/storage.cc -cpu_storage.o: src/storage/cpu_storage.cc -gpu_storage.o: src/storage/gpu_storage.cc engine.o: src/dag_engine/simple_engine.cc #engine.o: src/dag_engine/threaded_engine.cc src/common/concurrent_blocking_queue.h src/common/spin_lock.h narray.o: src/narray/narray.cc diff --git a/include/mxnet/cuda_utils.h b/src/common/cuda_utils.h similarity index 79% rename from include/mxnet/cuda_utils.h rename to src/common/cuda_utils.h index 5329649f63bb..a2730481c828 100644 --- a/include/mxnet/cuda_utils.h +++ b/src/common/cuda_utils.h @@ -3,8 +3,8 @@ * \file cuda_utils.h * \brief CUDA debugging utilities. 
*/ -#ifndef MXNET_CUDA_UTILS_H_ -#define MXNET_CUDA_UTILS_H_ +#ifndef MXNET_COMMON_CUDA_UTILS_H_ +#define MXNET_COMMON_CUDA_UTILS_H_ #include @@ -14,6 +14,9 @@ #include #include +namespace common { +namespace cuda { + inline const char* CublasGetErrorString(cublasStatus_t error) { switch (error) { case CUBLAS_STATUS_SUCCESS: @@ -72,6 +75,9 @@ inline const char* CurandGetErrorString(curandStatus_t status) { return "Unknown cuRAND status"; } +} // namespace cuda +} // namespace common + #define CHECK_CUDA_ERROR(msg) \ { \ cudaError_t e = cudaGetLastError(); \ @@ -84,18 +90,18 @@ inline const char* CurandGetErrorString(curandStatus_t status) { CHECK_EQ(e, cudaSuccess) << "CUDA: " << cudaGetErrorString(e); \ } -#define CUBLAS_CALL(func) \ - { \ - cublasStatus_t e = (func); \ - CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \ - << "cuBLAS: " << CublasGetErrorString(e); \ +#define CUBLAS_CALL(func) \ + { \ + cublasStatus_t e = (func); \ + CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \ + << "cuBLAS: " << common::cuda::CublasGetErrorString(e); \ } -#define CURAND_CALL(func) \ - { \ - curandStatus_t e = (func); \ - CHECK_EQ(e, CURAND_STATUS_SUCCESS) \ - << "cuRAND: " << CurandGetErrorString(e); \ +#define CURAND_CALL(func) \ + { \ + curandStatus_t e = (func); \ + CHECK_EQ(e, CURAND_STATUS_SUCCESS) \ + << "cuRAND: " << common::cuda::CurandGetErrorString(e); \ } #endif // MXNET_USE_CUDA @@ -112,4 +118,4 @@ inline const char* CurandGetErrorString(curandStatus_t status) { #endif // MXNET_USE_CUDNN -#endif // MXNET_CUDA_UTILS_H_ +#endif // MXNET_COMMON_CUDA_UTILS_H_ diff --git a/src/common/utils.h b/src/common/utils.h new file mode 100644 index 000000000000..b3c46897eb5d --- /dev/null +++ b/src/common/utils.h @@ -0,0 +1,56 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file utils.h + * \brief Basic utilility functions. + */ +#ifndef MXNET_COMMON_UTILS_H_ +#define MXNET_COMMON_UTILS_H_ + +#if DMLC_USE_CXX11 +#include +#include +#include +#endif // DMLC_USE_CXX11 + +namespace common { + +#if DMLC_USE_CXX11 + +namespace helper { + +template +struct UniqueIf { + using SingleObject = std::unique_ptr; +}; + +template +struct UniqueIf { + using UnknownBound = std::unique_ptr; +}; + +template +struct UniqueIf { + using KnownBound = void; +}; + +} // namespace helper + +template +typename helper::UniqueIf::SingleObject MakeUnique(Args&&... args) { + return std::unique_ptr(new T(std::forward(args)...)); +}; + +template +typename helper::UniqueIf::UnknownBound MakeUnique(size_t n) { + using U = typename std::remove_extent::type; + return std::unique_ptr(new U[n]{}); +} + +template +typename helper::UniqueIf::KnownBound MakeUnique(Args&&...) = delete; + +#endif // DMLC_USE_CXX11 + +} // namespace common + +#endif // MXNET_COMMON_UTILS_H_ diff --git a/src/storage/cpu_storage.h b/src/storage/cpu_device_storage.h similarity index 51% rename from src/storage/cpu_storage.h rename to src/storage/cpu_device_storage.h index 3c031959d9be..0b69b6b4fd8d 100644 --- a/src/storage/cpu_storage.h +++ b/src/storage/cpu_device_storage.h @@ -1,11 +1,13 @@ /*! * Copyright (c) 2015 by Contributors - * \file cpu_storage.h + * \file cpu_device_storage.h * \brief CPU storage implementation. */ -#ifndef MXNET_STORAGE_CPU_STORAGE_H_ -#define MXNET_STORAGE_CPU_STORAGE_H_ +#ifndef MXNET_STORAGE_CPU_DEVICE_STORAGE_H_ +#define MXNET_STORAGE_CPU_DEVICE_STORAGE_H_ +#include +#include #include "mxnet/base.h" namespace mxnet { @@ -14,28 +16,34 @@ namespace storage { /*! * \brief CPU storage implementation. 
*/ -class CpuStorage { +class CPUDeviceStorage { public: /*! * \brief Aligned allocation on CPU. * \param size Size to allocate. * \return Pointer to the storage. */ - static void* Alloc(size_t size); + inline static void* Alloc(size_t size); /*! * \brief Deallocation. * \param ptr Pointer to deallocate. */ - static void Free(void* ptr); + inline static void Free(void* ptr); private: /*! * \brief Alignment of allocation. */ static constexpr size_t alignment_ = 16; -}; // class CpuStorage +}; // class CPUDeviceStorage + +inline void* CPUDeviceStorage::Alloc(size_t size) { + return CHECK_NOTNULL(memalign(alignment_, size)); +} + +inline void CPUDeviceStorage::Free(void* ptr) { free(ptr); } } // namespace storage } // namespace mxnet -#endif // MXNET_STORAGE_CPU_STORAGE_H_ +#endif // MXNET_STORAGE_CPU_DEVICE_STORAGE_H_ diff --git a/src/storage/cpu_storage.cc b/src/storage/cpu_storage.cc deleted file mode 100644 index e7d23843f207..000000000000 --- a/src/storage/cpu_storage.cc +++ /dev/null @@ -1,18 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - */ -#include "./cpu_storage.h" -#include -#include - -namespace mxnet { -namespace storage { - -void* CpuStorage::Alloc(size_t size) { - return CHECK_NOTNULL(memalign(alignment_, size)); -} - -void CpuStorage::Free(void* ptr) { free(ptr); } - -} // namespace storage -} // namespace mxnet diff --git a/src/storage/gpu_device_storage.h b/src/storage/gpu_device_storage.h new file mode 100644 index 000000000000..22956d192893 --- /dev/null +++ b/src/storage/gpu_device_storage.h @@ -0,0 +1,57 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file gpu_device_storage.h + * \brief GPU storage implementation. + */ +#ifndef MXNET_STORAGE_GPU_DEVICE_STORAGE_H_ +#define MXNET_STORAGE_GPU_DEVICE_STORAGE_H_ + +#include "mxnet/base.h" +#include "../common/cuda_utils.h" +#if MXNET_USE_CUDA +#include +#endif // MXNET_USE_CUDA + +namespace mxnet { +namespace storage { + +/*! + * \brief GPU storage implementation. + */ +class GPUDeviceStorage { + public: + /*! + * \brief Allocation. + * \param size Size to allocate. + * \return Pointer to the storage. + */ + inline static void* Alloc(size_t size); + /*! + * \brief Deallocation. + * \param ptr Pointer to deallocate. + */ + inline static void Free(void* ptr); +}; // class GPUDeviceStorage + +inline void* GPUDeviceStorage::Alloc(size_t size) { + void* ret; +#if MXNET_USE_CUDA + CUDA_CALL(cudaMalloc(&ret, size)); +#else // MXNET_USE_CUDA + LOG(FATAL) << "Please compile with CUDA enabled"; +#endif // MXNET_USE_CUDA + return ret; +} + +inline void GPUDeviceStorage::Free(void* ptr) { +#if MXNET_USE_CUDA + CUDA_CALL(cudaFree(ptr)); +#else // MXNET_USE_CUDA + LOG(FATAL) << "Please compile with CUDA enabled"; +#endif // MXNET_USE_CUDA +} + +} // namespace storage +} // namespace mxnet + +#endif // MXNET_STORAGE_GPU_DEVICE_STORAGE_H_ diff --git a/src/storage/gpu_storage.cc b/src/storage/gpu_storage.cc deleted file mode 100644 index d0f2b83b2815..000000000000 --- a/src/storage/gpu_storage.cc +++ /dev/null @@ -1,32 +0,0 @@ -/*! 
- * Copyright (c) 2015 by Contributors - */ -#include "./gpu_storage.h" -#include "mxnet/cuda_utils.h" -#if MXNET_USE_CUDA -#include -#endif // MXNET_USE_CUDA - -namespace mxnet { -namespace storage { - -void* GpuStorage::Alloc(size_t size) { - void* ret; -#if MXNET_USE_CUDA - CUDA_CALL(cudaMalloc(&ret, size)); -#else // MXNET_USE_CUDA - LOG(FATAL) << "Please compile with CUDA enabled"; -#endif // MXNET_USE_CUDA - return ret; -} - -void GpuStorage::Free(void* ptr) { -#if MXNET_USE_CUDA - CUDA_CALL(cudaFree(ptr)); -#else // MXNET_USE_CUDA - LOG(FATAL) << "Please compile with CUDA enabled"; -#endif // MXNET_USE_CUDA -} - -} // namespace storage -} // namespace mxnet diff --git a/src/storage/gpu_storage.h b/src/storage/gpu_storage.h deleted file mode 100644 index 5e0293f2edd6..000000000000 --- a/src/storage/gpu_storage.h +++ /dev/null @@ -1,35 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file gpu_storage.h - * \brief GPU storage implementation. - */ -#ifndef MXNET_STORAGE_GPU_STORAGE_H_ -#define MXNET_STORAGE_GPU_STORAGE_H_ - -#include "mxnet/base.h" - -namespace mxnet { -namespace storage { - -/*! - * \brief GPU storage implementation. - */ -class GpuStorage { - public: - /*! - * \brief Allocation. - * \param size Size to allocate. - * \return Pointer to the storage. - */ - static void* Alloc(size_t size); - /*! - * \brief Deallocation. - * \param ptr Pointer to deallocate. - */ - static void Free(void* ptr); -}; // class GpuStorage - -} // namespace storage -} // namespace mxnet - -#endif // MXNET_STORAGE_GPU_STORAGE_H_ diff --git a/src/storage/storage.cc b/src/storage/storage.cc index a5bfca9c9046..b8e6669a100b 100644 --- a/src/storage/storage.cc +++ b/src/storage/storage.cc @@ -4,20 +4,25 @@ #include "mxnet/storage.h" #include #include +#include #include "./storage_manager.h" #include "./naive_storage_manager.h" #include "./pooled_storage_manager.h" -#include "./cpu_storage.h" -#include "./gpu_storage.h" -#include "mxnet/cuda_utils.h" +#include "./cpu_device_storage.h" +#include "./gpu_device_storage.h" +#include "../common/cuda_utils.h" +#include "../common/utils.h" namespace mxnet { struct Storage::Impl { static constexpr size_t kPoolThreshold = 4096 * 1024 * 1024ul; + static constexpr size_t kMaxNumberOfDevices = 2; + static constexpr size_t kMaxNumberOfDeviceIDs = 16; template - using CurrentStorageManager = storage::PooledStorageManager; + using CurrentStorageManager = + storage::PooledStorageManager; static void ActivateDevice(Context ctx) { switch (ctx.dev_mask) { @@ -35,41 +40,36 @@ struct Storage::Impl { } } - std::unordered_map< - int, std::unordered_map>> - storage_managers; + // std::unordered_map< + // int, std::unordered_map>> + // storage_managers; + std::array, + kMaxNumberOfDeviceIDs>, + kMaxNumberOfDevices> storage_managers; }; // struct Storage::Impl Storage::Handle Storage::Alloc(size_t size, Context ctx) { Handle hd; hd.ctx = ctx; - auto&& device = impl_->storage_managers[ctx.dev_mask]; - auto&& device_id_it = device.find(ctx.dev_id); + auto&& device = impl_->storage_managers.at(ctx.dev_mask); + auto&& device_id_it = device.at(ctx.dev_id); // Allocate device if necessary. 
- if (device_id_it == device.end()) { + if (!device_id_it) { switch (ctx.dev_mask) { case cpu::kDevMask: - device_id_it = - device.emplace(std::make_pair( - ctx.dev_id, - std::unique_ptr{ - new Storage::Impl::CurrentStorageManager< - storage::CpuStorage>{}})).first; + device_id_it = common::MakeUnique< + Storage::Impl::CurrentStorageManager>(); break; case gpu::kDevMask: - device_id_it = - device.emplace(std::make_pair( - ctx.dev_id, - std::unique_ptr{ - new Storage::Impl::CurrentStorageManager< - storage::GpuStorage>{}})).first; + device_id_it = common::MakeUnique< + Storage::Impl::CurrentStorageManager>(); break; default: LOG(FATAL) << "Unimplemented device"; } } Impl::ActivateDevice(ctx); - hd.dptr = device_id_it->second->Alloc(size); + hd.dptr = device_id_it->Alloc(size); hd.size = size; return hd; } From 1b3163c10b4a6e297ab753e75722f3d24fe303a8 Mon Sep 17 00:00:00 2001 From: Yutian Li Date: Sun, 16 Aug 2015 21:40:12 +0800 Subject: [PATCH 12/21] [storage] fix extra comma --- src/common/utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/utils.h b/src/common/utils.h index b3c46897eb5d..e10f240a31cc 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -38,7 +38,7 @@ struct UniqueIf { template typename helper::UniqueIf::SingleObject MakeUnique(Args&&... args) { return std::unique_ptr(new T(std::forward(args)...)); -}; +} template typename helper::UniqueIf::UnknownBound MakeUnique(size_t n) { From dda24c23faac638841eb5b5b7443dc6c225b45d4 Mon Sep 17 00:00:00 2001 From: Yutian Li Date: Sun, 16 Aug 2015 22:10:32 +0800 Subject: [PATCH 13/21] [storage] a simple test for storage --- Makefile | 3 ++- include/mxnet/storage.h | 4 +-- src/storage/naive_storage_manager.h | 2 +- src/storage/pooled_storage_manager.h | 2 +- src/storage/storage.cc | 12 ++++----- test/test_storage.cc | 39 ++++++++++++++++++++++++++++ 6 files changed, 51 insertions(+), 11 deletions(-) create mode 100644 test/test_storage.cc diff --git a/Makefile b/Makefile index 1ac4c583a384..948816967f83 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,7 @@ ifneq ($(ADD_LDFLAGS), NONE) endif #BIN = test/test_threaded_engine test/api_registry_test -BIN = test/api_registry_test +BIN = test/api_registry_test test/test_storage OBJ = narray_op_cpu.o static_operator.o static_operator_cpu.o # add threaded engine after it is done OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o storage.o @@ -96,6 +96,7 @@ lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) lib/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ) test/api_registry_test: test/api_registry_test.cc lib/libmxnet.a +test/test_storage: test/test_storage.cc lib/libmxnet.a #test/test_threaded_engine: test/test_threaded_engine.cc api/libmxnet.a $(BIN) : diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index d9f4a9f17de0..9c65d954d5c5 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -7,8 +7,8 @@ #define MXNET_STORAGE_H_ #include -#include "./base.h" -#include "./tensor_blob.h" +#include "base.h" +#include "tensor_blob.h" namespace mxnet { diff --git a/src/storage/naive_storage_manager.h b/src/storage/naive_storage_manager.h index 3611b775ce74..a476f5ea2acc 100644 --- a/src/storage/naive_storage_manager.h +++ b/src/storage/naive_storage_manager.h @@ -6,7 +6,7 @@ #ifndef MXNET_STORAGE_NAIVE_STORAGE_MANAGER_H_ #define MXNET_STORAGE_NAIVE_STORAGE_MANAGER_H_ -#include "./storage_manager.h" +#include "storage_manager.h" #include "mxnet/base.h" namespace mxnet { diff --git 
a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h index b3a45b895985..c7e1e0cde3c2 100644 --- a/src/storage/pooled_storage_manager.h +++ b/src/storage/pooled_storage_manager.h @@ -8,7 +8,7 @@ #include #include -#include "./storage_manager.h" +#include "storage_manager.h" #include "mxnet/base.h" namespace mxnet { diff --git a/src/storage/storage.cc b/src/storage/storage.cc index b8e6669a100b..98c7981b6a88 100644 --- a/src/storage/storage.cc +++ b/src/storage/storage.cc @@ -5,11 +5,11 @@ #include #include #include -#include "./storage_manager.h" -#include "./naive_storage_manager.h" -#include "./pooled_storage_manager.h" -#include "./cpu_device_storage.h" -#include "./gpu_device_storage.h" +#include "storage_manager.h" +#include "naive_storage_manager.h" +#include "pooled_storage_manager.h" +#include "cpu_device_storage.h" +#include "gpu_device_storage.h" #include "../common/cuda_utils.h" #include "../common/utils.h" @@ -17,7 +17,7 @@ namespace mxnet { struct Storage::Impl { static constexpr size_t kPoolThreshold = 4096 * 1024 * 1024ul; - static constexpr size_t kMaxNumberOfDevices = 2; + static constexpr size_t kMaxNumberOfDevices = 3; static constexpr size_t kMaxNumberOfDeviceIDs = 16; template diff --git a/test/test_storage.cc b/test/test_storage.cc new file mode 100644 index 000000000000..33995a055dc5 --- /dev/null +++ b/test/test_storage.cc @@ -0,0 +1,39 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file test_storage.cc + * \brief Test for storage. + */ +#include +#include +#include "mxnet/storage.h" + +int main() { + constexpr size_t kSize = 1024; + auto&& storage = mxnet::Storage::Get(); + mxnet::Context context_cpu{}; + auto&& handle = storage->Alloc(kSize, context_cpu); + assert(handle.ctx == context_cpu); + assert(handle.size == kSize); + auto ptr = handle.dptr; + storage->Free(handle); + handle = storage->Alloc(kSize, context_cpu); + assert(handle.ctx == context_cpu); + assert(handle.size == kSize); + assert(handle.dptr == ptr); + printf("Success on CPU!\n"); + +#if MXNET_USE_CUDA + mxnet::Context context_gpu{mxnet::gpu::kDevMask, 0}; + handle = storage->Alloc(kSize, context_gpu); + assert(handle.ctx == context_gpu); + assert(handle.size == kSize); + ptr = handle.dptr; + storage->Free(handle); + handle = storage->Alloc(kSize, context_gpu); + assert(handle.ctx == context_gpu); + assert(handle.size == kSize); + assert(handle.dptr == ptr); + printf("Success on GPU!\n"); +#endif // MXNET_USE_CUDA + return 0; +} From 7c85007e1ac52ba6017b1cb4b63a2d5430cf0af9 Mon Sep 17 00:00:00 2001 From: Yutian Li Date: Sun, 16 Aug 2015 22:14:57 +0800 Subject: [PATCH 14/21] [storage] Merge branch 'master' of github.com:dmlc/mxnet --- Makefile | 15 +- include/mxnet/atomic_symbol.h | 83 --- include/mxnet/base.h | 36 +- include/mxnet/c_api.h | 113 ++-- include/mxnet/context.h | 80 +++ include/mxnet/dag_engine.h | 2 +- include/mxnet/narray.h | 4 +- include/mxnet/operator.h | 364 ++++++++++-- include/mxnet/registry.h | 62 +-- include/mxnet/static_graph.h | 59 -- include/mxnet/static_operator.h | 73 --- include/mxnet/storage.h | 2 +- include/mxnet/symbol.h | 145 ----- include/mxnet/symbolic.h | 311 +++++++++++ include/mxnet/tensor_blob.h | 53 -- make/config.mk | 2 +- python/mxnet/symbol.py | 109 +++- python/mxnet/symbol_creator.py | 106 +++- python/test_python.py | 2 +- python/test_symbol.py | 27 + src/c_api.cc | 169 ++++-- src/narray/narray_op.h | 2 +- src/operator/fully_connected-inl.h | 172 ++++++ src/operator/fully_connected.cc | 22 + src/operator/fully_connected.cu | 14 
+ .../operator_common.h} | 49 +- src/{static_operator => operator}/param.h | 17 +- .../static_operator/activation_op-inl.h | 9 +- .../static_operator/convolution_op-inl.h | 9 +- .../static_operator/dropout_op-inl.h | 9 +- .../static_operator/mshadow_op.h | 6 +- .../static_operator/pooling_op-inl.h | 9 +- .../static_operator/reshape_op-inl.h | 9 +- src/operator/static_operator_wrapper.cc | 97 ---- src/registry.cc | 24 +- src/static_operator/fully_connect_op-inl.h | 163 ------ src/static_operator/fully_connect_op.cc | 33 -- src/static_operator/fully_connect_op.cu | 14 - src/static_operator/static_operator-inl.h | 49 -- src/static_operator/static_operator.cc | 44 -- src/static_operator/static_operator_cpu.cc | 20 - src/static_operator/static_operator_gpu.cu | 23 - src/symbol/static_graph.cc | 87 +++ src/symbol/symbol.cc | 517 +++++++++++++----- 44 files changed, 1901 insertions(+), 1314 deletions(-) delete mode 100644 include/mxnet/atomic_symbol.h create mode 100644 include/mxnet/context.h delete mode 100644 include/mxnet/static_graph.h delete mode 100644 include/mxnet/static_operator.h delete mode 100644 include/mxnet/symbol.h create mode 100644 include/mxnet/symbolic.h delete mode 100644 include/mxnet/tensor_blob.h create mode 100644 python/test_symbol.py create mode 100644 src/operator/fully_connected-inl.h create mode 100644 src/operator/fully_connected.cc create mode 100644 src/operator/fully_connected.cu rename src/{static_operator/static_operator_common.h => operator/operator_common.h} (53%) rename src/{static_operator => operator}/param.h (90%) rename src/{ => operator}/static_operator/activation_op-inl.h (88%) rename src/{ => operator}/static_operator/convolution_op-inl.h (97%) rename src/{ => operator}/static_operator/dropout_op-inl.h (91%) rename src/{ => operator}/static_operator/mshadow_op.h (92%) rename src/{ => operator}/static_operator/pooling_op-inl.h (95%) rename src/{ => operator}/static_operator/reshape_op-inl.h (89%) delete mode 100644 src/operator/static_operator_wrapper.cc delete mode 100644 src/static_operator/fully_connect_op-inl.h delete mode 100644 src/static_operator/fully_connect_op.cc delete mode 100644 src/static_operator/fully_connect_op.cu delete mode 100644 src/static_operator/static_operator-inl.h delete mode 100644 src/static_operator/static_operator.cc delete mode 100644 src/static_operator/static_operator_cpu.cc delete mode 100644 src/static_operator/static_operator_gpu.cu create mode 100644 src/symbol/static_graph.cc diff --git a/Makefile b/Makefile index 948816967f83..fdda325b8240 100644 --- a/Makefile +++ b/Makefile @@ -56,16 +56,16 @@ endif #BIN = test/test_threaded_engine test/api_registry_test BIN = test/api_registry_test test/test_storage -OBJ = narray_op_cpu.o static_operator.o static_operator_cpu.o +OBJ = narray_op_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o storage.o +OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o static_graph.o storage.o CUOBJ = SLIB = lib/libmxnet.so ALIB = lib/libmxnet.a LIB_DEP = $(DMLC_CORE)/libdmlc.a ifeq ($(USE_CUDA), 1) - CUOBJ += narray_op_gpu.o static_operator_gpu.o fully_connect_op_gpu.o + CUOBJ += narray_op_gpu.o fully_connected_gpu.o endif .PHONY: clean all test lint doc @@ -77,19 +77,16 @@ $(DMLC_CORE)/libdmlc.a: storage.o: src/storage/storage.cc engine.o: src/dag_engine/simple_engine.cc -#engine.o: src/dag_engine/threaded_engine.cc src/common/concurrent_blocking_queue.h 
src/common/spin_lock.h narray.o: src/narray/narray.cc narray_op_cpu.o: src/narray/narray_op_cpu.cc src/narray/narray_op-inl.h narray_op_gpu.o: src/narray/narray_op_gpu.cu src/narray/narray_op-inl.h -static_operator.o: src/static_operator/static_operator.cc -static_operator_cpu.o: src/static_operator/static_operator_cpu.cc -static_operator_gpu.o: src/static_operator/static_operator_gpu.cu symbol.o: src/symbol/symbol.cc +static_graph.o : src/symbol/static_graph.cc registry.o: src/registry.cc c_api.o: src/c_api.cc operator.o: src/operator/static_operator_wrapper.cc -fully_connect_op_cpu.o: src/static_operator/fully_connect_op.cc -fully_connect_op_gpu.o: src/static_operator/fully_connect_op.cu +fully_connected_cpu.o: src/operator/fully_connected.cc +fully_connected_gpu.o: src/operator/fully_connected.cu lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) diff --git a/include/mxnet/atomic_symbol.h b/include/mxnet/atomic_symbol.h deleted file mode 100644 index 54f8223e80a3..000000000000 --- a/include/mxnet/atomic_symbol.h +++ /dev/null @@ -1,83 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file atomic_symbol.h - * \brief atomic symbol interface of mxnet - */ -#ifndef MXNET_ATOMIC_SYMBOL_H_ -#define MXNET_ATOMIC_SYMBOL_H_ - -#include -#include -#include -#include -#include "./base.h" -#include "./tensor_blob.h" - -namespace mxnet { -// forward declare StaticOperator -class StaticOperator; -/*! - * \brief AtomicSymbol is the base class of all atomic symbols. - * This is not meant to be used by user, it should be wrapped in Symbol, so that the same instance - * of AtomicSymbol can be shared in the graphs of different Symbols - */ -class AtomicSymbol { - public: - /*! - * \brief virtual destructor - */ - virtual ~AtomicSymbol() {} - /*! \brief get the descriptions of inputs for this symbol */ - virtual std::vector DescribeArguments() const { - // default implementation returns "data" - return std::vector(1, std::string("data")); - } - /*! \brief get the descriptions of outputs for this symbol */ - virtual std::vector DescribeReturns() const { - // default implementation returns "output" - return std::vector(1, std::string("output")); - } - /*! - * \brief set param for the symbol from string - * \param name parameter name - * \param val string for the configuration - */ - virtual void SetParam(const char *name, const char *val) {} - /*! - * \brief infer the shapes of outputs and unknown input arguments - * \param in_shape the shape of input arguments of the operator - * this should be of same length as the vector returned by DescribeArgs - * in_shape allows unknown elements, which are checked by shape.ndim() == 0. - * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape - * For known shapes, InferShape will check shape consistency - * - * common practice: set the shape of data input, and usually weight's shape can be infered - * - * \param out_shape the shape of outputs of the operator - * InferShape will modify the vector to fill output TShape - * \return if the shape inference is successful, return true, else return false. - */ - virtual bool InferShape(std::vector *in_shape, std::vector *out_shape) const = 0; - /*! - * \brief Copy this AtomicSymbol and returns a pointer to the copied object. - * this is a virtual function because different subclass of AtomicSymbol would copy differently. - * \return a pointer to the copied atomic symbol - */ - virtual AtomicSymbol* Copy() const = 0; - /*! 
- * \brief Bind this AtomicSymbol to a context and get back a static operator - * Bind function of AtomicSymbol does not return NArrayOperator, but static operator. - * Calling bind from the Symbol wrapper would generate a NArrayOperator. - */ - template - StaticOperator* Bind(Context ctx) const; - /*! - * \brief return the type string of the atomic symbol - * subclasses override this function. - */ - virtual std::string TypeString() const = 0; - friend class Symbol; -}; - -} // namespace mxnet -#endif // MXNET_ATOMIC_SYMBOL_H_ diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 67c3a1b24b74..fe260e082148 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -41,38 +41,10 @@ typedef mshadow::index_t index_t; /*! \brief data type that will be used to store ndarray */ typedef mshadow::default_real_t real_t; -/*! \brief option to pass into the forward function */ -struct Option { - /*! \brief whether it is training phase*/ - int is_train; -}; -/*! \brief gradient request type the request can have */ -enum GradReqType { - /*! \brief no operation, do not write gradient */ - kNullOp = 0, - /*! \brief write gradient to provided space */ - kWriteTo = 1, - /*! \brief same as kWriteTo, but provided space is same as space of input-data */ - kWriteInplace = 2, - /*! \brief add to the provided space */ - kAddTo = 3 -}; -/*! \brief input argument type of the operator have */ -enum ArgType { - /*! \brief data argument */ - kDataArg = 0, - /*! \brief weight argument */ - kWeightArg = 1, - /*! \brief bias argument */ - kBiasArg = 2 -}; -/*! \brief Property for engine schedule */ -enum Property { - /*! \brief Op contains interanl state, won't influence engine schedule */ - kContainInteralState = 1, - /*! \brief Op forward require random number, will influence engine schedule */ - kForwardRequireRnd = 2, -}; +/*! \brief dynamic shape type */ +typedef mshadow::TShape TShape; +/*! \brief storage container type */ +typedef mshadow::TBlob TBlob; } // namespace mxnet #endif // MXNET_BASE_H_ diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 4f0a9ea5f87a..a9a15c4a8007 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -209,15 +209,23 @@ MXNET_DLL int MXFuncInvoke(FunctionHandle fun, // Part 3: symbolic configuration generation //-------------------------------------------- /*! - * \brief create symbol from config - * \param cfg configuration string - * \param out created symbol handle + * \brief list all the available AtomicSymbolEntry + * \param out_size the size of returned array + * \param out_array the output AtomicSymbolCreator array * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymbolCreateFromConfig(const char *cfg, - SymbolHandle *out); +MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, + AtomicSymbolCreator **out_array); /*! - * \brief create Symbol by wrapping AtomicSymbol + * \brief Get the name of AtomicSymbol. + * \param creator the AtomicSymbolCreator + * \param out the returned name of the creator + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, + const char **out); +/*! + * \brief Create an AtomicSymbol. 
* \param creator the AtomicSymbolCreator * \param num_param the number of parameters * \param keys the keys to the params @@ -225,56 +233,95 @@ MXNET_DLL int MXSymbolCreateFromConfig(const char *cfg, * \param out pointer to the created symbol handle * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymbolCreateFromAtomicSymbol(AtomicSymbolCreator creator, - int num_param, - const char **keys, - const char **vals, - SymbolHandle *out); +MXNET_DLL int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, + int num_param, + const char **keys, + const char **vals, + SymbolHandle *out); +/*! + * \brief Create a Variable Symbol. + * \param name name of the variable + * \param out pointer to the created symbol handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolCreateVariable(const char *name, SymbolHandle *out); +/*! + * \brief Create a Symbol by grouping list of symbols together + * \param num_symbols number of symbols to be grouped + * \param symbols array of symbol handles + * \param out pointer to the created symbol handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolCreateGroup(mx_uint num_symbols, + SymbolHandle *symbols, + SymbolHandle *out); +/*! + * \brief Create symbol from config. + * \param cfg configuration string + * \param out created symbol handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolCreateFromConfig(const char *cfg, + SymbolHandle *out); /*! - * \brief free the symbol handle + * \brief Free the symbol handle. * \param symbol the symbol * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolFree(SymbolHandle symbol); /*! - * \brief list all the available AtomicSymbolEntry - * \param out_size the size of returned array - * \param out_array the output AtomicSymbolCreator array + * \brief Copy the symbol to another handle + * \param symbol the source symbol + * \param out used to hold the result of copy * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, - AtomicSymbolCreator **out_array); +MXNET_DLL int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out); /*! - * \brief get the singleton Symbol of the AtomicSymbol if any - * \param creator the AtomicSymbolCreator - * \param out the returned singleton Symbol of the AtomicSymbol the creator stands for + * \brief Print the content of symbol, used for debug. + * \param symbol the symbol + * \param out_str pointer to hold the output string of the printing. * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymbolGetSingleton(AtomicSymbolCreator creator, - SymbolHandle *out); +MXNET_DLL int MXSymbolPrint(SymbolHandle symbol, const char **out_str); /*! - * \brief get the singleton Symbol of the AtomicSymbol if any - * \param creator the AtomicSymbolCreator - * \param out the returned name of the creator + * \brief List arguments in the symbol. + * \param symbol the symbol + * \param out_size output size + * \param out_str_array pointer to hold the output string array * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, - const char **out); +MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol, + mx_uint *out_size, + const char ***out_str_array); +/*! + * \brief List returns in the symbol. 
+ * \param symbol the symbol + * \param out_size output size + * \param out_str_array pointer to hold the output string array + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolListReturns(SymbolHandle symbol, + mx_uint *out_size, + const char ***out_str_array); /*! - * \brief compose the symbol on other symbol + * \brief Compose the symbol on other symbols. + * + * This function will change the sym handle. + * To achieve function apply behavior, copy the symbol first + * before applying. + * * \param sym the symbol to apply + * \param name the name of symbol * \param num_args number of arguments * \param keys the key of keyword args (optional) * \param args arguments to sym - * \param out the resulting symbol * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolCompose(SymbolHandle sym, + const char *name, mx_uint num_args, const char** keys, - SymbolHandle* args, - SymbolHandle* out); + SymbolHandle* args); //-------------------------------------------- // Part 4: operator interface on NArray //-------------------------------------------- @@ -338,6 +385,7 @@ MXNET_DLL int MXOpForward(OperatorHandle op, * \param op the operator handle * \param grad_next array of output gradients * \param in_data array of input narray to the operator + * \param out_data array of output narray to the operator * \param out_grad array that holds the gradients on these inputs * can be NULL if that position's request is kNullOp * \param reqs gradient request type * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXOpBackward(OperatorHandle op, NArrayHandle *grad_next, NArrayHandle *in_data, + NArrayHandle *out_data, NArrayHandle *out_grad, mx_uint *reqs); diff --git a/include/mxnet/context.h b/include/mxnet/context.h new file mode 100644 index 000000000000..262ba2e787d4 --- /dev/null +++ b/include/mxnet/context.h @@ -0,0 +1,80 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file context.h + * \brief Context information and resources in mxnet. + */ +#ifndef MXNET_CONTEXT_H_ +#define MXNET_CONTEXT_H_ + +namespace mxnet { + +/*! \brief Context information about the execution environment */ +struct Context { + /*! \brief the device type we run the op on, can be cpu::kDevMask or gpu::kDevMask */ + int dev_mask; + /*! \brief device id we are going to run it on */ + int dev_id; + /*! \brief constructor */ + Context() : dev_mask(cpu::kDevMask), dev_id(0) {} + /*! + * \brief constructor of context + * \param dev_mask the device mask + * \param dev_id the device id + */ + Context(int dev_mask, int dev_id) + : dev_mask(dev_mask), dev_id(dev_id) {} + /*! + * \brief check if current context equals another one + * \param b another context to compare + * \return whether dev mask and id are the same + */ + inline bool operator==(const Context &b) const { + return dev_mask == b.dev_mask && dev_id == b.dev_id; + } +}; + +/*! + * \brief execution time context. + * The information needed in runtime for actual execution. + */ +struct RunContext { + /*! + * \brief the stream of the device, can be NULL or Stream* in GPU mode + */ + void *stream; +}; + +/*! + * \brief Additional resources + */ +struct Resource { + /*! \brief Resource type, indicating what the pointer type is */ + enum Type { + /*! \brief mshadow::Random object */ + kRandom, + /*! \brief Temporal space */ + kTempSpace + }; + /*! \brief pointer to the resource */ + void *ptr; +}; + +/*! + * \brief The resources that can be requested by Operator + */ +struct ResourceRequest { + /*!
\brief type of resources */ + Resource::Type type; + /*! \brief size requirement if it is a temp space request */ + size_t space_size; + /*! \brief default constructor */ + ResourceRequest() {} + /*! + * \brief default constructor, allow implicit conversion + * \param type type of resources + */ + ResourceRequest(Resource::Type type) : type(type) {} // NOLINT(*) +}; + +} // namespace mxnet +#endif // MXNET_CONTEXT_H_ diff --git a/include/mxnet/dag_engine.h b/include/mxnet/dag_engine.h index cf4008f9eb95..18b804b5a2d8 100644 --- a/include/mxnet/dag_engine.h +++ b/include/mxnet/dag_engine.h @@ -15,7 +15,7 @@ #include #include #include "./base.h" -#include "./tensor_blob.h" +#include "./context.h" namespace mxnet { /*! diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index 9a73870c003b..92257b3f0269 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -5,12 +5,14 @@ */ #ifndef MXNET_NARRAY_H_ #define MXNET_NARRAY_H_ + #include #include #include #include "./base.h" +#include "./context.h" #include "./storage.h" -#include "./tensor_blob.h" +#include "./context.h" #include "./dag_engine.h" // check c++11 #if DMLC_USE_CXX11 == 0 diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 97f8ca035ebd..0299ef2bf167 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -1,78 +1,334 @@ /*! * Copyright (c) 2015 by Contributors * \file operator.h - * \brief operator interface of mxnet + * \brief Operator interface of mxnet. * \author Naiyan Wang */ #ifndef MXNET_OPERATOR_H_ #define MXNET_OPERATOR_H_ -// this file will be seen by cuda, no c++11 for now + #include #include +#include +#include #include "./base.h" -#include "./tensor_blob.h" -#include "./static_operator.h" -#include "./narray.h" -#include "./dag_engine.h" +#include "./context.h" namespace mxnet { +/*! \brief operation request type to Forward and Backward */ +enum OpReqType { + /*! \brief no operation, do not write anything */ + kNullOp, + /*! \brief write gradient to provided space */ + kWriteTo, + /*! + * \brief perform an inplace write, + * Target shares memory with one of input arguments. + * This option only happens when the target can safely reuse an input's memory. + */ + kWriteInplace, + /*! \brief add to the provided space */ + kAddTo +}; + /*! - * \brief operator interface - * operator is an object can be scheduled by DAG engine directly. + * \brief All the possible information needed by Operator.Forward and Backward + * This is the superset of RunContext. + * We use this data structure to bookkeep everything needed by Forward and Backward. + * \sa Resource + */ +struct OpContext { + /*! \brief whether it is training phase */ + int is_train; + /*! \brief Stream we are running on */ + void *stream; + /*! \brief Resources requested by the operator */ + std::vector requested; + /*! + * \brief set the RunContext related parts + * \param ctx the context + */ + inline void SetRunContext(const RunContext &ctx) { + stream = ctx.stream; + } +};
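The OpReqType above is what lets the graph executor choose, per output, between discarding, overwriting, or accumulating a result. A minimal sketch of how a kernel can honor the request for a single element follows; the real operators apply the same idea over whole tensors through mshadow expressions, so this helper is illustrative and not part of the patch:

    // Illustrative helper, not part of the patch: write one computed value
    // to dst while honoring the caller's OpReqType request.
    inline void AssignReq(mxnet::real_t *dst, mxnet::real_t val,
                          mxnet::OpReqType req) {
      switch (req) {
        case mxnet::kNullOp:        // caller asked for no write at all
          break;
        case mxnet::kWriteTo:
        case mxnet::kWriteInplace:  // an inplace write is still a plain store
          *dst = val;
          break;
        case mxnet::kAddTo:         // accumulate, e.g. when summing gradients
          *dst += val;
          break;
      }
    }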
+ +/*! + * \brief Operator interface. + * Operator defines the basic operation unit of the optimized computation graph in mxnet. + * This interface relies on pre-allocated memory in TBlob, the caller needs to set + * the memory region in TBlob correctly before calling Forward and Backward. + * + * Operator is generated by OperatorProperty. + * To add new operators (aka layers of neural nets) to mxnet, developers need to create + * a new OperatorProperty and its corresponding Operator. * - * This interface relies on NArray. The user should prepare the input NArray and - * output NArray by themselves. - * \sa Operator + * \sa TBlob, TShape, OperatorProperty */ class Operator { public: /*! \brief destructor */ virtual ~Operator() {} /*! - * \brief describe property of op - * \return a bit map in int + * \brief perform a forward operation of Operator, save the output to TBlob. + * \param ctx runtime context available to this call + * \param in_data array of input data, it is const + * \param req the request types of the saving operation, can only be kWriteTo or kWriteInplace. + * \param out_data array of output data holders; + * the space of each TBlob in out_data must be pre-allocated with InferShape + * \sa OpReqType, OpContext */ - virtual int DescribeProperty() const = 0; + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) = 0; /*! - * \brief perform a forward operation of operator, save the output to NArray - * This method only pushes an execution request to the DAG engine, and - * return immediately. Actual execution is conducted by the DAG engine. - * \param opt option on Forward such as whether this is training phase - * \param ctx runtime context - * \param in_data array of input data, it is const - * \param out_data array of output data, - * the space of NArray in out_data must be pre-allocated with InferShape - * \sa NArray - */ - virtual void Forward(Option opt, - RunContext ctx, - const std::vector &in_data, - const std::vector &out_data) = 0; - /*! - * \brief perform a backward operation of the operator to get the gradient - * This method only pushes an execution request to the DAG engine, and - * return immediately. Actual execution is conducted by the DAG engine. - * \param ctx runtime context - * \param grad_next the gradient value of the output of the operator, used by chain rule. - * \param in_data the array of input data - * \param out_grad array of output gradient - * \param req request types of the gradient saving operation - * only inplace will change input data - * \sa GradReqType, NArray - */ - virtual void Backward(RunContext ctx, - const std::vector &grad_next, - const std::vector &in_data, - const std::vector &out_grad, - const std::vector &req) = 0; - /*! - * \brief Create a wrapper of static operator to wrap it into Operator. - * This function takes ownership of op - * \param op static operator to wrap from - * \param ctx context of the created operator - * \return a wrapper operator - */ - static Operator *CreateWrapper(StaticOperator *op, Context ctx); -}; // class operator + * \brief Perform a Backward Operation, write gradient to the in_grad. + * \param ctx runtime context available to this call + * \param out_grad the gradient value we get from output of the Operator + * \param in_data the array of input data. + * \param out_data the array of output data. + * \param req request types of the saving operation, can be all types. + * \param in_grad the array of gradients we need to write to. + * \sa OpReqType, OpContext + */ + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad) = 0; +}; + +#if DMLC_USE_CXX11 +// OperatorProperty allows C++11, while Operator does not rely on it. /*! + * \brief OperatorProperty is an object that stores all information about Operator. + * It also contains methods to generate context(device) specific operators.
+ * + * It also contains various functions that can be optionally overridden to + * provide optimization chances for the computation engine. + */ +class OperatorProperty { + public: + /*! + * \brief virtual destructor + */ + virtual ~OperatorProperty() {} + /*! + * \brief Get input arguments of the Operator. + * \return vector of arguments. + */ + virtual std::vector ListArguments() const { + return {"data"}; + } + /*! + * \brief Get name of return values of Operator + * \return name of return values. + */ + virtual std::vector ListReturns() const { + return {"output"}; + } + /*! \return number of real return values of the Operator */ + virtual int NumReturns() const { + return 1; + } + /*! + * \brief get number of visible return values during Symbol creation. + * If NumVisibleReturns() == k and NumReturns() == n, + * the first k returns will be presented in the resulting symbol. + * + * The rest of the returns can be used for auxiliary states for Backward. + * For example, Dropout will return [data, mask], with NumVisibleReturns() == 1. + * So when the user calls sym = Dropout(input), only data is presented in sym. + * But all the returns will be presented in the out_data parameter of Backward if requested. + * + * \return number of default return values + */ + virtual int NumVisibleReturns() const { + return NumReturns(); + } + /*! + * \brief Set the parameters of the Operator. + * \param name parameter name + * \param val string for the configuration + */ + virtual void SetParam(const char *name, const char *val) {} + /*! + * \brief infer the shapes of outputs and unknown input arguments + * \param in_shape the shape of input arguments of the operator + * this should be of the same length as the vector returned by ListArguments + * in_shape allows unknown elements, which are checked by shape.ndim() == 0. + * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape + * For known shapes, InferShape will check shape consistency + * + * common practice: set the shape of the data input, and usually the weight's shape can be inferred + * + * \param out_shape the shape of outputs of the operator + * InferShape will modify the vector to fill output TShape + * \return if the shape inference is successful, return true, else return false. + */ + virtual bool InferShape(std::vector *in_shape, + std::vector *out_shape) const = 0; + /*! + * \brief Copy this OperatorProperty. + * \return a pointer to the copied OperatorProperty + */ + virtual OperatorProperty* Copy() const = 0; + /*! + * \brief Create an Operator on a specific context + */ + virtual Operator* CreateOperator(Context ctx) const = 0; + /*! + * \brief return the type string of the Operator + * subclasses override this function. + */ + virtual std::string TypeString() const = 0; + //-------------------------------------------------------- + // All the below functions are optional to override. + //-------------------------------------------------------- + /*! + * \brief Declare additional resource required in forward pass. + * These additional resources will be presented in OpContext.requested + * in the same order of the returned Resource. + * \return Additional resource request + */ + virtual std::vector ForwardResource() const { + return std::vector(); + } + /*! + * \brief Declare additional resource required in backward pass. + * These additional resources will be presented in OpContext.requested + * in the same order of the returned Resource.
+ * \return Additional resource request + */ + virtual std::vector BackwardResource() const { + return std::vector(); + } + /*! + * \brief Declare the input requirement of Backward pass. + * + * Only the returned list of variables will be used in Backward. + * This function is used for memory optimization. + * It is advised to override it and only return what is actually needed. + * If this function is not overridden, all the variables will be valid in Backward. + * + * \code + * // The following code declares that Backward needs out_grad[0], in_data[0], in_data[1] + * vector BackwardInputs(const vector &out_grad, + * const vector &in_data, + * const vector &out_data) const { + * return {out_grad[0], in_data[0], in_data[1]}; + * } + * \endcode + * \param out_grad gradient of outputs in backward pass. + * \param in_data the input data in forward pass. + * \param out_data the output data in forward pass. + * \return an integer vector indicating the input requirements + * \sa BackwardInputs + */ + virtual std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const { + // By default, require all inputs, outputs, and output gradients. + // Remember to override this function to get better performance. + std::vector ret = out_grad; + ret.insert(ret.end(), in_data.begin(), in_data.end()); + ret.insert(ret.end(), out_data.begin(), out_data.end()); + return ret; + } + /*! + * \brief Get possible forward inplace options. + * This function enables optimization to reuse memory of inputs in output. + * Only override when necessary, by default in-place is disabled. + * + * \code + * // The following code says out_data[0] can share data with in_data[0] + * vector > ForwardInplaceOption(const vector &in_data, + * const vector &out_data) const { + * return {{out_data[0], in_data[0]}}; + * } + * \endcode + * \return list of pairs of integers taken from the inputs vector, + * indicating possible in place operations. + */ + virtual std::vector > ForwardInplaceOption( + const std::vector &in_data, + const std::vector &out_data) const { + return std::vector >(); + } + /*! + * \brief Get possible backward inplace options. + * This function enables optimization to reuse memory of inputs in output. + * Only override when necessary, by default in-place is disabled. + * + * \code + * // The following code says in_grad[0] can share data with in_data[0] + * vector > BackwardInplaceOption( + * const std::vector &out_grad, + * const std::vector &in_data, + * const std::vector &out_data, + * const std::vector &in_grad) const { + * return {{in_grad[0], in_data[0]}}; + * } + * \endcode + * \return list of pairs of integers taken from the inputs vector, + * indicating possible in place operations. + */ + virtual std::vector > BackwardInplaceOption( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &in_grad) const { + return std::vector >(); + } + /*! + * \brief Get Backward Input Dependency for generic types of data. + * Normally T can be a pointer to Symbol::DataEntry, or NArray. + * This function will select the result list of T according to DeclareBackwardDependency. + * + * \param in_data the input data in forward pass. + * \param out_data the output data in forward pass. + * \param out_grad gradient of outputs in backward pass. + * \tparam T the generic type parameter. + * \return vector of inputs the Backward Operation depends on.
+ * \sa DeclareBackwardDependency + */ + template + inline std::vector BackwardInputs(const std::vector &in_data, + const std::vector &out_data, + const std::vector &out_grad) const { + int cnt = 0; + std::vector all_vec; + std::vector in_data_idx, out_data_idx, out_grad_idx; + for (size_t i = 0; i < in_data.size(); ++i) { + in_data_idx.push_back(cnt++); + all_vec.push_back(in_data[i]); + } + for (size_t i = 0; i < out_data.size(); ++i) { + out_data_idx.push_back(cnt++); + all_vec.push_back(out_data[i]); + } + for (size_t i = 0; i < out_grad.size(); ++i) { + out_grad_idx.push_back(cnt++); + all_vec.push_back(out_grad[i]); + } + std::vector ret_idx = this->DeclareBackwardDependency( + out_grad_idx, in_data_idx, out_data_idx); + std::vector ret; + for (size_t i = 0; i < ret_idx.size(); ++i) { + ret.push_back(all_vec[ret_idx[i]]); + } + return ret; + } + /*! + * \brief create OperatorProperty + * \param type_name the type string of the OperatorProperty + * \return a newly constructed OperatorProperty + */ + static OperatorProperty *Create(const char* type_name); +}; +#endif } // namespace mxnet #endif // MXNET_OPERATOR_H_ diff --git a/include/mxnet/registry.h b/include/mxnet/registry.h index df9c27b9a4ad..ddc0a3ca22a0 100644 --- a/include/mxnet/registry.h +++ b/include/mxnet/registry.h @@ -10,9 +10,10 @@ #include #include #include +#include #include "./base.h" #include "./narray.h" -#include "./symbol.h" +#include "./operator.h" namespace mxnet { @@ -63,9 +64,6 @@ class Registry { } }; -/*! NArrayFunctionEntry requires c++11 */ -#if DMLC_USE_CXX11 -#include /*! \brief mask information on how functions can be exposed */ enum FunctionTypeMask { /*! \brief all the use_vars should go before scalar */ @@ -216,61 +214,45 @@ struct NArrayFunctionEntry { #define REGISTER_NARRAY_FUN(name) \ static auto __ ## name ## _narray_fun__ = \ ::mxnet::Registry::Get()->Register("" # name) -#endif // DMLC_USE_CXX11 -class Symbol; -/*! \brief AtomicSymbolEntry to register */ -struct AtomicSymbolEntry { + +/*! \brief OperatorPropertyEntry to register */ +struct OperatorPropertyEntry { /*! \brief typedef Creator function */ - typedef AtomicSymbol*(*Creator)(); - /*! \brief if AtomicSymbol use param */ + typedef OperatorProperty*(*Creator)(); + /*! \brief if OperatorProperty use param */ bool use_param; /*! \brief name of the entry */ std::string name; - /*! \brief function body to create AtomicSymbol */ + /*! \brief function body to create OperatorProperty */ Creator body; - /*! \brief singleton is created when no param is needed for the AtomicSymbol */ - Symbol *singleton_symbol; /*! \brief constructor */ - explicit AtomicSymbolEntry(const std::string& name) - : use_param(true), name(name), body(NULL), singleton_symbol(NULL) {} - /*! - * \brief set if param is needed by this AtomicSymbol - */ - inline AtomicSymbolEntry &set_use_param(bool use_param) { - this->use_param = use_param; - return *this; - } + explicit OperatorPropertyEntry(const std::string& name) + : use_param(true), name(name), body(NULL) {} /*! * \brief set the function body */ - inline AtomicSymbolEntry &set_body(Creator body) { + inline OperatorPropertyEntry &set_body(Creator body) { this->body = body; return *this; } - /*! - * \brief return the singleton symbol - */ - Symbol *GetSingletonSymbol(); - /*! \brief destructor */ - ~AtomicSymbolEntry(); /*!
* \brief invoke the function - * \return the created AtomicSymbol + * \return the created OperatorProperty */ - inline AtomicSymbol* operator () () const { + inline OperatorProperty* operator () () const { return body(); } private: /*! \brief disable copy constructor */ - AtomicSymbolEntry(const AtomicSymbolEntry& other) {} + OperatorPropertyEntry(const OperatorPropertyEntry& other) {} /*! \brief disable assignment operator */ - const AtomicSymbolEntry& operator = (const AtomicSymbolEntry& other) { return *this; } + const OperatorPropertyEntry& operator = (const OperatorPropertyEntry& other) { return *this; } }; /*! - * \brief macro to register AtomicSymbol to AtomicSymbolFactory + * \brief macro to register OperatorProperty to OperatorPropertyFactory * * Example: the following code is an example to register aplus * \code * * * \endcode */ -#define REGISTER_ATOMIC_SYMBOL(name, AtomicSymbolType) \ - ::mxnet::AtomicSymbol* __make_ ## AtomicSymbolType ## __() { \ - return new AtomicSymbolType; \ +#define REGISTER_OP_PROPERTY(name, OperatorPropertyType) \ + ::mxnet::OperatorProperty* __make_ ## OperatorPropertyType ## __() { \ + return new OperatorPropertyType; \ } \ - static ::mxnet::AtomicSymbolEntry& __ ## name ## _atomic_symbol__ = \ - ::mxnet::Registry< ::mxnet::AtomicSymbolEntry >::Get()->Register("" # name) \ - .set_body(__make_ ## AtomicSymbolType ## __) + static ::mxnet::OperatorPropertyEntry& __ ## name ## _atomic_symbol__ = \ + ::mxnet::Registry< ::mxnet::OperatorPropertyEntry >::Get()->Register("" # name) \ + .set_body(__make_ ## OperatorPropertyType ## __) } // namespace mxnet #endif // MXNET_REGISTRY_H_ diff --git a/include/mxnet/static_graph.h b/include/mxnet/static_graph.h deleted file mode 100644 index e06cbf4f9c42..000000000000 --- a/include/mxnet/static_graph.h +++ /dev/null @@ -1,59 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file static_graph.h - * \brief the static graph of symbols - */ -#ifndef MXNET_STATIC_GRAPH_H_ -#define MXNET_STATIC_GRAPH_H_ - -#include -#include -#include -#include -#include "./atomic_symbol.h" namespace mxnet { -/*! \brief static graph interface - * static graph is an internal representation of symbol graph. - * - * The main purpose for static graph for binding a composite operator - */ -struct StaticGraph { - /*! \brief Node in static graph */ - struct StaticNode { - /*! \brief wrapped atomic symbol */ - AtomicSymbol* sym_; - /*! \brief name of the node */ - std::string name_; - }; - /*! \brief node name to id dictionary */ - std::unordered_map name_id_map; - /*! \brief all nodes in the graph */ - std::vector nodes; - /*! \brief output id for each node */ - std::vector > output_index; - /*! \brief connected graph for each node */ - std::vector > connected_graph; - /*!
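With the entry renamed, adding an operator comes down to subclassing OperatorProperty and invoking the macro above. A hedged sketch of a minimal registration follows; "Identity" and IdentityProp are illustrative names rather than symbols from this patch, and a real property would return a device-specific Operator from CreateOperator:

    // Hypothetical minimal property: one input, one output of the same shape.
    class IdentityProp : public mxnet::OperatorProperty {
     public:
      bool InferShape(std::vector<mxnet::TShape> *in_shape,
                      std::vector<mxnet::TShape> *out_shape) const override {
        if ((*in_shape)[0].ndim() == 0) return false;  // input shape still unknown
        out_shape->clear();
        out_shape->push_back((*in_shape)[0]);          // output mirrors the input shape
        return true;
      }
      mxnet::OperatorProperty* Copy() const override {
        return new IdentityProp(*this);
      }
      mxnet::Operator* CreateOperator(mxnet::Context ctx) const override {
        return nullptr;  // a real property returns a cpu/gpu Operator here
      }
      std::string TypeString() const override { return "Identity"; }
    };

    REGISTER_OP_PROPERTY(Identity, IdentityProp);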
\brief find node by using name - * \param name node name - * \param sym symbol need to be copied into node - * \return node id - */ - int FindNodeByName(const std::string& name, const AtomicSymbol* sym) { - int id = 0; - if (name_id_map.find(name) == name_id_map.end()) { - name_id_map[name] = name_id_map.size(); - StaticNode static_node; - static_node.sym_ = sym->Copy(); - static_node.name_ = name; - nodes.push_back(static_node); - output_index.push_back(std::vector()); - connected_graph.push_back(std::vector()); - id = name_id_map.size(); - } else { - id = name_id_map[name]; - } - return id; - } -}; -} // namespace mxnet -#endif // MXNET_STATIC_GRAPH_H_ diff --git a/include/mxnet/static_operator.h b/include/mxnet/static_operator.h deleted file mode 100644 index 27efccd2ce58..000000000000 --- a/include/mxnet/static_operator.h +++ /dev/null @@ -1,73 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file static_operator.h - * \brief static operator interface of mxnet - */ -#ifndef MXNET_STATIC_OPERATOR_H_ -#define MXNET_STATIC_OPERATOR_H_ -// this file will be seen by cuda, no c++11 for now -#include -#include -#include "./base.h" -#include "./tensor_blob.h" - -namespace mxnet { -/*! - * \brief static StaticOperator interface (current interface have not yet todo with scheduler), - * StaticOperator is a stateful object that can be used to call forward and backprop - * - * This interface relies on pre-allocated memory in TBlob, the caller need to set - * the memory region in TBlob correctly before calling Forward and Backward - * - * \sa TBlob, TShape - */ -class StaticOperator { - public: - /*! \brief destructor */ - virtual ~StaticOperator() {} - /*! - * \brief describe property of op - * \return a bit map in int - */ - virtual int DescribeProperty() const { - // default most of layer only conatin internal state - return kContainInteralState; - } - /*! - * \brief perform a forward operation of StaticOperator, save the output to TBlob - * \param opt option on Forward such as whether this is training phase - * \param ctx runtime context - * \param in_data array of input data, it is const - * \param out_data array of output data, - * the space of TBlob in out_data must be pre-allocated with InferShape - */ - virtual void Forward(Option opt, - RunContext ctx, - const std::vector &in_data, - const std::vector &out_data) = 0; - /*! - * \brief perform a backward operation of the StaticOperator to get the gradient - * \param ctx runtime context - * \param grad_next the gradient value we get from output of the StaticOperator - * \param in_data the array of input data - * \param out_grad array of output gradient, there could be three possible TBlob - * in the each element in the array - * \param req request types of the gradient saving operation - * only inplace will change input data - * \sa GradReqType - */ - virtual void Backward(RunContext ctx, - const std::vector &grad_next, - const std::vector &in_data, - const std::vector &out_grad, - const std::vector &req) = 0; - /*! 
- * \brief factory function, create a new StaticOperator - * \param type the type of StaticOperator - * \param ctx the context device type of StaticOperator - * \return a pointer of StaticOperator object - */ - static StaticOperator *Create(const char *type, Context ctx); -}; -} // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_H_ diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index 9c65d954d5c5..575dc4cde1a2 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -8,7 +8,7 @@ #include #include "base.h" -#include "tensor_blob.h" +#include "context.h" namespace mxnet { diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h deleted file mode 100644 index df1e78438560..000000000000 --- a/include/mxnet/symbol.h +++ /dev/null @@ -1,145 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file symbol.h - * \brief symbol interface of mxnet - */ -#ifndef MXNET_SYMBOL_H_ -#define MXNET_SYMBOL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "./base.h" -#include "./tensor_blob.h" -#include "./operator.h" -#include "./static_graph.h" - -namespace mxnet { -class CompositeOperator; -/*! - * \brief Symbol is the wrapper of AtomicSymbol, the reason for this indirection is that Symbol - * should support expressions and often passed by value. While AtomicSymbol have many subclasses, - * passing by value would result in object slicing. - * - * Symbol is always composite, the head Node is the output node of the symbol. - * A atomic symbol can be seen as a special case of the composite symbol with only the head node. - */ -class Symbol { - public: - /*! - * \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol - * with input symbols. - */ - struct Node { - /*! \brief wrapped atomic symbol */ - AtomicSymbol* sym_; - /*! \brief name of the node */ - std::string name_; - /*! \brief inputs to this node */ - std::vector > in_symbol_; - /*! \brief index of the inputs if the inputs are tuple */ - std::vector in_index_; - /*! \brief the output shape of the wrapped symbol */ - std::vector out_shape_; - /*! - * \brief constructor - */ - explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "") : - sym_(sym), name_(name) { - } - /*! - * \brief destructor - */ - ~Node() { - if (sym_) { - delete sym_; - } - } - }; - - protected: - /*! \brief the head node of the Symbol, it could be shared in many graphs */ - std::shared_ptr head_; - /*! \brief if the head has multiple return values, index is used to specify */ - int index_; - /*! \brief find the nodes that use placeholder arguments */ - std::shared_ptr > > arg_users_; - /*! \brief find arg users */ - void FindArgUsers(); - /** - * @brief Recursively parse the symbol to equivalent static graph. - * - * @param node The current node in dfs - * @param graph The static graph - */ - void Dfs(const std::shared_ptr node, StaticGraph *graph); - - public: - /*! - * \brief declare virtual destructor in case it is subclassed. - */ - virtual ~Symbol() {} - /*! - * \brief bind to device and returns an operator. - * \param ctx context of the operator - * \return returns the pointer to a created operator. It is on the user to delete. 
- */ - virtual CompositeOperator* Bind(Context ctx) const { return nullptr; } - /** - * \brief Bind the symbol to a composite operator - * \param ctx context of the operator - * \param in A map denotes name and corresponding NArray for binding - * \return The composite operator - */ - virtual CompositeOperator* Bind(Context ctx, const std::unordered_map& in); - /*! - * \brief copy the symbol - * \return a deep copy of the graph - */ - virtual Symbol Copy() const; - /*! - * \brief compose with arguments - * \param args positional arguments for the symbol - * \return a new Symbol which is the composition of current symbol with its arguments - */ - virtual Symbol operator () (const std::vector& args) const; - /*! - * \brief compose with named arguments - * \param kwargs keyword arguments for the symbol - * \return a new symbol which is the composition of current symbol with its arguments - */ - virtual Symbol operator () (const std::unordered_map& kwargs) const; - /*! - * \brief get the index th element from the returned tuple. - */ - virtual Symbol operator[] (int index) const; - /*! - * \brief arguments information - * \return the arguments list of this symbol, they can be either named or unnamed (empty string). - */ - virtual std::vector ListArgs(); - /** - * @brief Convert current symbol to its equivalent static graph representation. - * @return the static graph - */ - virtual StaticGraph ToStaticGraph(); - /*! - * \brief create Symbol by wrapping AtomicSymbol - */ - static Symbol Create(AtomicSymbol* atomic_symbol); - /*! - * \brief create atomic symbol wrapped in symbol - * \param type_name the type string of the AtomicSymbol - * \param param the parameter stored as key value pairs - * \return the constructed Symbol - */ - static Symbol Create(const std::string& type_name, - const std::vector >& param); -}; - -} // namespace mxnet -#endif // MXNET_SYMBOL_H_ diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h new file mode 100644 index 000000000000..dc00f5a33fb6 --- /dev/null +++ b/include/mxnet/symbolic.h @@ -0,0 +1,311 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file symbolic.h + * \brief Symbolic interface of mxnet. + * \author Min Lin, Bing Xu +*/ +#ifndef MXNET_SYMBOLIC_H_ +#define MXNET_SYMBOLIC_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "./base.h" +#include "./narray.h" +#include "./operator.h" + +// check c++11 +#if DMLC_USE_CXX11 == 0 +#error "C++11 is required for the symbolic module" +#endif + +namespace mxnet { +/*! + * \brief StaticGraph is the configuration of computation graphs. + * This is the "configuration file" of mxnet. + * It can be converted to/from Symbol, and can be used to bind to operators. + */ +class StaticGraph { + public: + /*! \brief represents a data entry in the graph */ + struct DataEntry { + /*! \brief the source node id in the computation graph */ + uint32_t source_id; + /*! \brief index of output from the source. */ + uint32_t index; + }; + /*! \brief Operation Node in static graph */ + struct Node { + /*! \brief wrapped operator property */ + std::unique_ptr op; + /*! \brief name of the node */ + std::string name; + /*! \brief inputs (node_id, index) of the node */ + std::vector inputs; + }; + /*! \brief all nodes in the graph */ + std::vector nodes; + /*! \brief indices of the nodes that correspond to arguments */ + std::vector arg_nodes; + /*! \brief outputs (heads) of the graph */ + std::vector outputs; + // functions to help inference in static graph /*!
+ * \brief Perform a topological sort on the graph + * \return a topological order of node indices. + */ + std::vector TopoSort() const; + /*! + * \brief infer the node shapes in the computation graph. + * + * When calling this function, the user can set the known shape information into the right positions. + * Unknown shapes are indicated by shape.ndim() == 0. + * + * \param topo_order The topological order of node index, as created by TopoSort. + * \param node_out_shapes The shapes of each output of the nodes in the graph. + * \return if the shape inference is successful, return true, else return false. + */ + bool InferNodeShapes(const std::vector &topo_order, + std::vector > *node_out_shapes) const; + /*! + * \brief infer the shapes of outputs and unknown input arguments + * \param in_shape the shape of input arguments of the operator + * this should be of the same length as the vector returned by ListArguments + * in_shape allows unknown elements, which are checked by shape.ndim() == 0. + * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape + * For known shapes, InferShape will check shape consistency + * + * common practice: set the shape of the data input, and usually the weight's shape can be inferred + * + * \param out_shape the shape of outputs of the operator + * InferShape will modify the vector to fill output TShape + * \return if the shape inference is successful, return true, else return false. + */ + bool InferShape(std::vector *in_shape, + std::vector *out_shape) const; +}; + +/*! + * \brief Symbol is used to represent a dynamically generated symbolic computation graph. + * + * This class is used as a tool to generate computation graphs (aka configuration) of the network. + * Symbol is always composite, the head Node is the output node of the symbol. + * An atomic symbol can be seen as a special case of the composite symbol with only the head node. + * + * The symbol can be converted from/to StaticGraph, the actual configuration used by mxnet. + * Symbol offers a more flexible way to compose nodes than StaticGraph, which makes it a good + * tool for generating configurations from language bindings such as python. + * \sa StaticGraph + */ +class Symbol { + public: + /*! + * \brief copy the symbol + * \return a deep copy of the graph + */ + Symbol Copy() const; + /*! + * \brief print the symbol info to output stream. + * \param os the output stream we like to print to + */ + void Print(std::ostream &os) const; // NOLINT(*) + /*! + * \brief List the argument names. + * + * The position of the returned list also corresponds to calling position in operator() + * \return the arguments list of this symbol, they can be either named or unnamed (empty string). + */ + std::vector ListArguments() const; + /*! \return get the descriptions of outputs for this symbol */ + std::vector ListReturns() const; + /*! + * \brief get the index-th element from the returned tuple. + * \param index index of multi output + * \return the symbol corresponding to the indexed element. + */ + Symbol operator[] (size_t index) const; + /*! + * \brief Compose the symbol with arguments, this changes the current symbol. + * + * The positional arguments passed in must be complete (contain all arguments). + * + * \param args positional arguments for the symbol + * \param name name of returned symbol. + */ + void Compose(const std::vector& args, + const std::string& name); + /*! + * \brief Compose the symbol with arguments, this changes the current symbol.
+ * The kwargs passed in can be incomplete; + * + * the arguments that are not specified will keep their original names. + * + * \param kwargs keyword arguments for the symbol + * \param name name of returned symbol. + */ + void Compose(const std::unordered_map& kwargs, + const std::string& name); + /*! + * \brief Convert a list of symbols into static graph + * + * The user can then call the bind function on the static graph + * + * \param out_graph the pointer holder of the output graph + */ + void ToStaticGraph(StaticGraph *out_graph) const; + /*! + * \brief Apply the symbol as a function, compose with arguments + * \param args positional arguments for the symbol + * \param name name of returned symbol. + * \return a new Symbol which is the composition of current symbol with its arguments + */ + Symbol operator () (const std::vector& args, const std::string& name) const; + /*! + * \brief compose with named arguments + * \param kwargs keyword arguments for the symbol + * \param name name of returned symbol. + * \return a new symbol which is the composition of current symbol with its arguments + */ + Symbol operator () (const std::unordered_map& kwargs, + const std::string& name) const; + /*! + * \brief infer the shapes of outputs and unknown input arguments + * \param in_shape the shape of input arguments of the operator + * this should be of the same length as the vector returned by ListArguments + * in_shape allows unknown elements, which are checked by shape.ndim() == 0. + * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape + * For known shapes, InferShape will check shape consistency + * + * common practice: set the shape of the data input, and usually the weight's shape can be inferred + * + * \param out_shape the shape of outputs of the operator + * InferShape will modify the vector to fill output TShape + * \return if the shape inference is successful, return true, else return false. + */ + bool InferShape(std::vector *in_shape, std::vector *out_shape) const; + /*! + * \brief get number of outputs of this symbol + * \return number of outputs + */ + inline size_t NumReturns() const { + return heads_.size(); + } + /*! + * \brief create Symbol by wrapping OperatorProperty + * This function takes the ownership of op + * + * \param op the OperatorProperty of the Operator + * \return Symbol + * \sa OperatorProperty::Create + */ + static Symbol Create(OperatorProperty *op); + /*! + * \brief create the equivalent symbol from a static graph + * \param graph the static graph + * \return the created symbol + */ + static Symbol Create(const StaticGraph &graph); + /*! + * \brief create an equivalent symbol by grouping the symbols together + * \param symbols list of symbols + * \return the grouped symbol + */ + static Symbol CreateGroup(const std::vector &symbols); + /*! + * \brief create variable symbol node + * \param name name of the variable + * \return the new variable + */ + static Symbol CreateVariable(const std::string &name); + + protected: + // Declare node, internal data structure. + struct Node; + /*! \brief an entry that represents output data from a node */ + struct DataEntry { + /*! \brief the source node of this data */ + std::shared_ptr source; + /*! \brief index of output from the source. */ + uint32_t index; + /*! \brief enabled default copy constructor */ + DataEntry() {} + /*! \brief constructor from index */ + DataEntry(std::shared_ptr source, uint32_t index) + : source(source), index(index) {} + }; /*!
+ * \brief the head nodes of the Symbol; + * these are the output entries of the symbol. + */ + std::vector heads_; + + private: + /*! \return whether the symbol is atomic */ + inline bool is_atomic() const; + /*! + * \brief Visit all the nodes in left-to-right depth first order. + * + * This function will visit the graph in DFS order, call fvisit exactly once + * for each Node, and store the result in out_result. + * + * \param fvisit function applied for each visit. + * \tparam FVisit visiting function type + */ + template + inline void DFSVisit(FVisit fvisit) const; + /*! + * \brief Find duplicate arguments in the composition + * \param out the map of argument-name -> occurrence count + * \return the maximum duplication factor + */ + int FindDuplicateArgs(std::unordered_map *out) const; +}; + +/*! + * \brief Executor of a computation graph. + * Executor can be created by Binding a symbol. + */ +class Executor { + public: + /*! \brief destructor */ + virtual ~Executor() {} + /*! + * \brief Perform a Forward operation of Operator + * After this operation, the user can get the results by using the function heads(). + */ + virtual void Forward() = 0; + /*! + * \brief Perform a Backward operation of the Operator. + * This must be called after Forward. + * After this operation, NArrays specified by grad_in_args_store will be updated accordingly. + * \param head_grads the gradient of head nodes to be backpropagated. + */ + virtual void Backward(const std::vector &head_grads) = 0; + /*! + * \brief get array of heads in the executor. + * \return array of heads in the executor. + */ + virtual const std::vector &heads() const = 0; + /*! + * \brief Create an executor by binding a symbol with context and arguments. + * If the user does not want to compute the gradients of the i-th argument, grad_req_type[i] can be kNullOp. + * + * \param ctx the context of binding. + * \param symbol the symbol that specifies the output of Forward pass. + * \param in_args the NArrays that store the input arguments to the symbol. + * \param arg_grad_store NArrays that are used to store the gradient output of the input arguments. + * \param grad_req_type requirement type of gradient saving. Can only be in {kNullOp, kAddTo, kWriteTo}. + * \return a new executor. + */ + static Executor *Bind(Symbol symbol, + Context ctx, + const std::vector &in_args, + const std::vector &arg_grad_store, + const std::vector &grad_req_type); +}; // class Executor } // namespace mxnet #endif // MXNET_SYMBOLIC_H_
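Taken together, the Symbol and Executor interfaces support a configure-bind-run loop. The sketch below is hedged: the "fc1" name and the registered "FullyConnected" property are assumptions based on the operator files in this patch, and grad_req_type is assumed to be the OpReqType enum from operator.h.

    #include <string>
    #include <unordered_map>
    #include <vector>
    #include "mxnet/symbolic.h"

    // Sketch: compose a one-layer net, bind it, and run Forward/Backward once.
    // in_args, arg_grads, and head_grads must line up with net.ListArguments().
    void ToyTrainStep(const std::vector<mxnet::NArray> &in_args,
                      const std::vector<mxnet::NArray> &arg_grads,
                      const std::vector<mxnet::NArray> &head_grads) {
      mxnet::Symbol data = mxnet::Symbol::CreateVariable("data");
      mxnet::Symbol net = mxnet::Symbol::Create(
          mxnet::OperatorProperty::Create("FullyConnected"));
      // Keyword compose may be incomplete; weight/bias stay as free arguments.
      net.Compose(std::unordered_map<std::string, mxnet::Symbol>{{"data", data}},
                  "fc1");
      std::vector<mxnet::OpReqType> reqs(in_args.size(), mxnet::kWriteTo);
      mxnet::Executor *exec =
          mxnet::Executor::Bind(net, mxnet::Context(), in_args, arg_grads, reqs);
      exec->Forward();             // results readable through exec->heads()
      exec->Backward(head_grads);  // gradients written to arg_grads per reqs
      delete exec;
    }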
diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h deleted file mode 100644 index b39939cb1425..000000000000 --- a/include/mxnet/tensor_blob.h +++ /dev/null @@ -1,53 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file tensor_blob.h - * \brief tensor blob used to hold static memory used by - */ -#ifndef MXNET_TENSOR_BLOB_H_ -#define MXNET_TENSOR_BLOB_H_ -#include - -namespace mxnet { -/*! \brief context information about the execution enviroment */ -struct Context { - /*! \brief the device type we run the op can be cpu::kDevMask or gpu::kDevMask */ - int dev_mask; - /*! \brief device id we are going to run it on */ - int dev_id; - /*! \brief constructor */ - Context() : dev_mask(cpu::kDevMask), dev_id(0) {} - /*! - * \brief constructor of context - * \param dev_mask the device mask - * \param dev_id the device id - */ - Context(int dev_mask, int dev_id) - : dev_mask(dev_mask), dev_id(dev_id) {} - /*! - * \brief check if current context equals another one - * \param b another context to compare - * \return whether dev mask and id are same - */ - inline bool operator==(const Context &b) const { - return dev_mask == b.dev_mask && dev_id == b.dev_id; - } -}; - - -/*! - * \brief execution context provides the information needed - * in runtime to actually execute the operation - */ -struct RunContext { - /*! - * \brief the stream of the device, can be NULL or Stream* in GPU mode - */ - void *stream; -}; - -/*! \brief dynamic shape type */ -typedef mshadow::TShape TShape; -/*! \brief storage container type */ -typedef mshadow::TBlob TBlob; -} // namespace mxnet -#endif // MXNET_TENSOR_BLOB_H_ diff --git a/make/config.mk b/make/config.mk index dccb959c2f36..48587a4f9114 100644 --- a/make/config.mk +++ b/make/config.mk @@ -49,7 +49,7 @@ PS_PATH = NONE PS_THIRD_PATH = NONE # whether compile with rabit -USE_RABIT_PS = 1 +USE_RABIT_PS = 0 RABIT_PATH = rabit # use openmp iterator diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 6f4146d162e3..031b18ab862f 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -1,4 +1,5 @@ # coding: utf-8 +# pylint: disable=invalid-name, protected-access """Symbol support of mxnet""" from __future__ import absolute_import @@ -37,37 +38,113 @@ def __init__(self, handle): """ self.handle = handle + def __del__(self): + check_call(_LIB.MXSymbolFree(self.handle)) + + def __copy__(self): + return self.__deepcopy__() + + def __deepcopy__(self, memo=None): + handle = SymbolHandle() + check_call(_LIB.MXSymbolCopy(self.handle, + ctypes.byref(handle))) + return Symbol(handle) + def __call__(self, *args, **kwargs): - """Compose Symbols + """Invoke symbol as function on inputs. + + Parameters + ---------- + args: + provide positional arguments + + kwargs: + provide keyword arguments + Returns + ------- + the resulting symbol + """ + s = self.__deepcopy__() + s._compose(*args, **kwargs) + return s + + def _compose(self, *args, **kwargs): + """Compose symbol on inputs. + + This call mutates the current symbol. Parameters ---------- args: provide positional arguments + kwargs: provide keyword arguments Returns ------- the resulting symbol """ - assert (len(args) == 0 or len(kwargs) == 0) + name = kwargs.pop('name', None) + if name: + name = c_str(name) + if len(args) != 0 and len(kwargs) != 0: + raise TypeError('compose only accepts input Symbols \ + either as positional or keyword arguments, not both') + for arg in args: - assert isinstance(arg, Symbol) - for _, val in kwargs: - assert isinstance(val, Symbol) + if not isinstance(arg, Symbol): + raise TypeError('Compose expects `Symbol` as arguments') + for _, val in kwargs.items(): + if not isinstance(val, Symbol): + raise TypeError('Compose expects `Symbol` as arguments') + num_args = len(args) + len(kwargs) if len(kwargs) != 0: keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) - args = c_array(SymbolHandle, kwargs.values()) + args = c_array(SymbolHandle, [s.handle for s in kwargs.values()]) else: keys = None - args = c_array(SymbolHandle, args) - - out = SymbolHandle() - check_call(_LIB.MXSymbolCompose( - self.handle, - num_args, - keys, - args, - ctypes.byref(out))) - return Symbol(out) + args = c_array(SymbolHandle, [s.handle for s in args]) + check_call(_LIB.MXSymbolCompose( \ + self.handle, name, num_args, keys, args)) + + def list_arguments(self): + """List all the arguments in the symbol. + + Returns + ------- + args : list of string + List of all the arguments.
+ """ + size = ctypes.c_uint() + sarr = ctypes.POINTER(ctypes.c_char_p)() + check_call(_LIB.MXSymbolListArguments( \ + self.handle, ctypes.byref(size), ctypes.byref(sarr))) + return [sarr[i] for i in range(size.value)] + + def list_returns(self): + """List all returns in the symbol. + + Returns + ------- + args: list of string + List of all the returns. + """ + size = ctypes.c_uint() + sarr = ctypes.POINTER(ctypes.c_char_p)() + check_call(_LIB.MXSymbolListReturns( \ + self.handle, ctypes.byref(size), ctypes.byref(sarr))) + return [sarr[i] for i in range(size.value)] + + def debug_str(self): + """Get a debug string. + + Returns + ------- + debug_str : string + Debug string of the symbol. + """ + debug_str = ctypes.c_char_p() + check_call(_LIB.MXSymbolPrint( \ + self.handle, ctypes.byref(debug_str))) + return debug_str.value diff --git a/python/mxnet/symbol_creator.py b/python/mxnet/symbol_creator.py index ee8f8bba1525..c81deebaef11 100644 --- a/python/mxnet/symbol_creator.py +++ b/python/mxnet/symbol_creator.py @@ -1,11 +1,12 @@ # coding: utf-8 +# pylint: disable=invalid-name, protected-access, no-self-use """Symbol support of mxnet""" from __future__ import absolute_import import ctypes from .base import _LIB -from .base import c_array, c_str -from .base import mx_uint, SymbolHandle +from .base import c_array, c_str, string_types +from .base import SymbolHandle from .base import check_call from .symbol import Symbol @@ -25,34 +26,53 @@ def __init__(self, name, handle): """ self.name = name self.handle = handle - singleton_ = SymbolHandle() - check_call(_LIB.MXSymbolGetSingleton(self.handle, ctypes.byref(singleton_))) - if singleton_: - self.singleton = Symbol(singleton_) - else: - self.singleton = None - - def __call__(self, **kwargs): + + def __call__(self, *args, **kwargs): """Invoke creator of symbol by passing kwargs Parameters ---------- + name : string + Name of the resulting symbol. + + *args + Positional arguments + **kwargs - provide the params necessary for the symbol creation + Provide the params necessary for the symbol creation. 
+ Returns ------- the resulting symbol """ - keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) - vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()]) + param_keys = [] + param_vals = [] + symbol_kwargs = {} + name = kwargs.pop('name', None) + + for k, v in kwargs.items(): + if isinstance(v, Symbol): + symbol_kwargs[k] = v + else: + param_keys.append(k) + param_vals.append(c_str(str(v))) + + # create atomic symbol + param_keys = c_array(ctypes.c_char_p, param_keys) + param_vals = c_array(ctypes.c_char_p, param_vals) sym_handle = SymbolHandle() - check_call(_LIB.MXSymbolCreateFromAtomicSymbol( - self.handle, - mx_uint(len(kwargs)), - keys, - vals, + check_call(_LIB.MXSymbolCreateAtomicSymbol( + self.handle, len(param_keys), + param_keys, param_vals, ctypes.byref(sym_handle))) - return Symbol(sym_handle) + + if len(args) != 0 and len(symbol_kwargs) != 0: + raise TypeError('%s can only accept input \ + Symbols either as positional or keyword arguments, not both' % self.name) + + s = Symbol(sym_handle) + s._compose(*args, name=name, **symbol_kwargs) + return s class _SymbolCreatorRegistry(object): """Function Registry""" @@ -62,8 +82,50 @@ def __init__(self): check_call(_LIB.MXSymbolListAtomicSymbolCreators(ctypes.byref(size), ctypes.byref(plist))) hmap = {} - name = ctypes.c_char_p() for i in range(size.value): - name = _LIB.MXSymbolGetAtomicSymbolName(plist[i], ctypes.byref(name)) - hmap[name] = _SymbolCreator(name, plist[i]) + name = ctypes.c_char_p() + check_call(_LIB.MXSymbolGetAtomicSymbolName(plist[i], ctypes.byref(name))) + hmap[name.value] = _SymbolCreator(name, plist[i]) self.__dict__.update(hmap) + + def Variable(self, name): + """Create a symbolic variable with specified name. + + Parameters + ---------- + name : str + Name of the variable. + + Returns + ------- + variable : Symbol + The created variable symbol. + """ + if not isinstance(name, string_types): + raise TypeError('Expect a string for variable `name`') + handle = SymbolHandle() + check_call(_LIB.MXSymbolCreateVariable(name, ctypes.byref(handle))) + return Symbol(handle) + + def Group(self, symbols): + """Create a symbolic variable that groups several symbols together. + + Parameters + ---------- + symbols : list + List of symbols to be grouped. + + Returns + ------- + sym : Symbol + The created group symbol. 
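# ---------------------------------------------------------------------------
# Editor's note: sketch of the kwarg split performed above -- Symbol-valued
# keyword arguments become graph inputs, everything else is stringified and
# passed to SetParam. 'nhidden' is the key parsed in src/operator/param.h;
# the value 100 is illustrative.
fc = mx.sym.FullyConnected(data=mx.sym.Variable('x'),  # Symbol -> input
                           nhidden=100,                # str(100) -> SetParam
                           name='fc')
# ---------------------------------------------------------------------------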
+ """ + ihandles = [] + for sym in symbols: + if not isinstance(sym, Symbol): + raise TypeError('Expect Symbols in the list input') + ihandles.append(sym.handle) + handle = SymbolHandle() + check_call(_LIB.MXSymbolCreateGroup( + len(ihandles), c_array(SymbolHandle, ihandles), ctypes.byref(handle))) + return Symbol(handle) diff --git a/python/test_python.py b/python/test_python.py index 5fe4a03ac11a..7aa4c432f1db 100644 --- a/python/test_python.py +++ b/python/test_python.py @@ -1,4 +1,4 @@ -#pylint: skip-file +# pylint: skip-file import mxnet as mx a = mx.narray.create((3000, 4000)) diff --git a/python/test_symbol.py b/python/test_symbol.py new file mode 100644 index 000000000000..6d876fd46fb8 --- /dev/null +++ b/python/test_symbol.py @@ -0,0 +1,27 @@ +# pylint: skip-file +import mxnet as mx + +data = mx.sym.Variable('data') +print data.debug_str() + +fc1 = mx.sym.FullyConnected(data=data, name='fc1', no_bias=0) +fc2 = mx.sym.FullyConnected(data=fc1, name='fc2', no_bias=0) + +print fc2.debug_str() + +print fc2.list_arguments() + +fc3 = mx.sym.FullyConnected(name='fc3') +fc4 = mx.sym.FullyConnected(data=fc3, name='fc4') + +print fc4.debug_str() + +print "-" * 10 +composed_fc4 = fc4(fc3_data=fc2, name='composed') +print composed_fc4.debug_str() + +multi_out = mx.sym.Group([composed_fc4, fc2]) + +print multi_out.debug_str() +print multi_out.list_arguments() +print multi_out.list_returns() diff --git a/src/c_api.cc b/src/c_api.cc index 5452c1be2e3d..d5a1a67d70c6 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -7,11 +7,12 @@ #include #include #include -#include -#include +#include #include +#include #include #include +#include // macro hanlding for threadlocal variables #ifdef __GNUC__ @@ -26,6 +27,18 @@ #message("Warning: Threadlocal is not enabled"); #endif +/*! \brief symbol wrapper to easily hold returning information */ +struct MXAPISymbolWrapper { + /*! \brief the actual symbol */ + mxnet::Symbol sym; + /*! \brief result holder for returning string */ + std::string ret_str; + /*! \brief result holder for returning strings */ + std::vector ret_vec_str; + /*! \brief result holder for returning string pointers */ + std::vector ret_vec_charp; +}; + /*! * \brief helper to store error message in threadlocal storage */ @@ -86,8 +99,15 @@ using namespace mxnet; /*! \brief macro to guard beginning and end section of all functions */ #define API_BEGIN() try { -/*! \brief every function starts with API_BEGIN(); and finishes with API_END(); */ +/*! \brief every function starts with API_BEGIN(); + and finishes with API_END() or API_END_HANDLE_ERROR */ #define API_END() } catch(dmlc::Error &e) { return MXHandleException(e); } return 0; +/*! + * \brief every function starts with API_BEGIN(); + * and finishes with API_END() or API_END_HANDLE_ERROR + * The finally clause contains procedure to cleanup states when an error happens. + */ +#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &e) { Finalize; return MXHandleException(e); } return 0; // NOLINT(*) /*! 
\brief return str message of the last error */ const char *MXGetLastError() { @@ -249,74 +269,147 @@ int MXFuncInvoke(FunctionHandle fun, API_END(); } -int MXSymbolCreateFromAtomicSymbol(AtomicSymbolCreator creator, - int num_param, - const char **keys, - const char **vals, - SymbolHandle *out) { +//-------------------------------------------- +// Part 3: symbolic configuration generation +//-------------------------------------------- + +int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, + AtomicSymbolCreator **out_array) { + API_BEGIN(); + auto &vec = Registry::List(); + *out_size = static_cast(vec.size()); + *out_array = (AtomicSymbolCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*) + API_END(); +} + +int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, + const char **out) { API_BEGIN(); - AtomicSymbolEntry *e = static_cast(creator); - *out = static_cast(new Symbol); - AtomicSymbol *atomic_symbol = (*e)(); + OperatorPropertyEntry *e = static_cast(creator); + *out = e->name.c_str(); + API_END(); +} + +int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, + int num_param, + const char **keys, + const char **vals, + SymbolHandle *out) { + MXAPISymbolWrapper *s = new MXAPISymbolWrapper(); + OperatorProperty *op = nullptr; + + API_BEGIN(); + OperatorPropertyEntry *e = static_cast(creator); + op = (*e)(); for (int i = 0; i < num_param; ++i) { - atomic_symbol->SetParam(keys[i], vals[i]); + op->SetParam(keys[i], vals[i]); } - *static_cast(*out) = Symbol::Create(atomic_symbol); - API_END(); + s->sym = Symbol::Create(op); + *out = s; + API_END_HANDLE_ERROR(delete s; delete op); +} + +int MXSymbolCreateVariable(const char *name, SymbolHandle *out) { + MXAPISymbolWrapper *s = new MXAPISymbolWrapper(); + API_BEGIN(); + s->sym = Symbol::CreateVariable(name); + *out = s; + API_END_HANDLE_ERROR(delete s); +} + +int MXSymbolCreateGroup(mx_uint num_symbols, + SymbolHandle *symbols, + SymbolHandle *out) { + MXAPISymbolWrapper *s = new MXAPISymbolWrapper(); + MXAPISymbolWrapper **sym_arr = (MXAPISymbolWrapper**)symbols; // NOLINT(*) + API_BEGIN(); + std::vector syms; + for (mx_uint i = 0; i < num_symbols; ++i) { + syms.push_back(sym_arr[i]->sym); + } + s->sym = Symbol::CreateGroup(syms); + *out = s; + API_END_HANDLE_ERROR(delete s); } int MXSymbolFree(SymbolHandle symbol) { API_BEGIN(); - delete static_cast(symbol); + delete static_cast(symbol); API_END(); } -int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, - AtomicSymbolCreator **out_array) { +int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out) { + MXAPISymbolWrapper *s = new MXAPISymbolWrapper(); + API_BEGIN(); - auto &vec = Registry::List(); - *out_size = static_cast(vec.size()); - *out_array = (AtomicSymbolCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*) + s->sym = (static_cast(symbol)->sym).Copy(); + *out = s; + API_END_HANDLE_ERROR(delete s); +} + +int MXSymbolPrint(SymbolHandle symbol, const char **out_str) { + MXAPISymbolWrapper *s = static_cast(symbol); + + API_BEGIN(); + std::ostringstream os; + (s->sym).Print(os); + s->ret_str = os.str(); + *out_str = (s->ret_str).c_str(); API_END(); } -int MXSymbolGetSingleton(AtomicSymbolCreator creator, - SymbolHandle *out) { +int MXSymbolListArguments(SymbolHandle symbol, + mx_uint *out_size, + const char ***out_str_array) { + MXAPISymbolWrapper *s = static_cast(symbol); API_BEGIN(); - AtomicSymbolEntry *e = static_cast(creator); - *out = static_cast(e->GetSingletonSymbol()); + s->ret_vec_str = std::move((s->sym).ListArguments()); + s->ret_vec_charp.clear(); + for (size_t i = 0; i < 
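// ---------------------------------------------------------------------------
// Editor's note: sketch of a create/print/free round trip through the C API
// entry points above, as a C++ caller would use them. Each call returns 0 on
// success; error checking is elided for brevity.
void SketchSymbolRoundTrip() {
  SymbolHandle var = nullptr;
  MXSymbolCreateVariable("data", &var);
  const char *dump = nullptr;
  MXSymbolPrint(var, &dump);  // string is owned by the wrapper, so it stays
                              // valid as long as 'var' is alive
  MXSymbolFree(var);          // deletes the MXAPISymbolWrapper
}
// ---------------------------------------------------------------------------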
s->ret_vec_str.size(); ++i) { + s->ret_vec_charp.push_back(s->ret_vec_str[i].c_str()); + } + *out_size = static_cast(s->ret_vec_charp.size()); + *out_str_array = dmlc::BeginPtr(s->ret_vec_charp); API_END(); } -int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, - const char **out) { +int MXSymbolListReturns(SymbolHandle symbol, + mx_uint *out_size, + const char ***out_str_array) { + MXAPISymbolWrapper *s = static_cast(symbol); API_BEGIN(); - AtomicSymbolEntry *e = static_cast(creator); - *out = e->name.c_str(); + s->ret_vec_str = std::move((s->sym).ListReturns()); + s->ret_vec_charp.clear(); + for (size_t i = 0; i < s->ret_vec_str.size(); ++i) { + s->ret_vec_charp.push_back(s->ret_vec_str[i].c_str()); + } + *out_size = static_cast(s->ret_vec_charp.size()); + *out_str_array = dmlc::BeginPtr(s->ret_vec_charp); API_END(); } int MXSymbolCompose(SymbolHandle sym, + const char *name, mx_uint num_args, const char** keys, - SymbolHandle* args, - SymbolHandle* out) { + SymbolHandle* args) { API_BEGIN(); - const Symbol* s = static_cast(sym); - Symbol* ret = new Symbol; - if (keys == NULL) { + std::string s_name; + if (name != nullptr) s_name = name; + + MXAPISymbolWrapper* s = static_cast(sym); + if (keys == nullptr && num_args != 0) { std::vector pos_args; for (mx_uint i = 0; i < num_args; ++i) { - pos_args.push_back(*(Symbol*)(args[i])); // NOLINT(*) + pos_args.push_back(((MXAPISymbolWrapper*)(args[i]))->sym); // NOLINT(*) } - *ret = (*s)(pos_args); + (s->sym).Compose(pos_args, s_name); } else { std::unordered_map kwargs; for (mx_uint i = 0; i < num_args; ++i) { - kwargs[keys[i]] = *(Symbol*)(args[i]); // NOLINT(*) + kwargs[keys[i]] = ((MXAPISymbolWrapper*)(args[i]))->sym; // NOLINT(*) } - *ret = (*s)(kwargs); + (s->sym).Compose(kwargs, s_name); } - *out = ret; API_END(); } diff --git a/src/narray/narray_op.h b/src/narray/narray_op.h index 1ce546ed295d..21a8da782972 100644 --- a/src/narray/narray_op.h +++ b/src/narray/narray_op.h @@ -8,7 +8,7 @@ #include #include #include -#include +#include namespace mxnet { /*! \brief namespace to support all possible NArray operator */ diff --git a/src/operator/fully_connected-inl.h b/src/operator/fully_connected-inl.h new file mode 100644 index 000000000000..5c54d37220ee --- /dev/null +++ b/src/operator/fully_connected-inl.h @@ -0,0 +1,172 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file fully_connect_op-inl.h + * \brief fully connect operator and symbol +*/ +#ifndef MXNET_OPERATOR_FULLY_CONNECTED_INL_H_ +#define MXNET_OPERATOR_FULLY_CONNECTED_INL_H_ + +#include +#include +#include +#include +#include +#include "./operator_common.h" +#include "./param.h" + +namespace mxnet { +namespace op { + +// Declare enumeration of input order to make code more intuitive. +// These enums are only visible within this header +enum FullyConnectedOpInputs {kData, kWeight, kBias}; +enum FullyConnectedOpOutputs {kOut}; + +/** + * \brief This is the implementation of fully connected operator. + * \tparam xpu The device that the op will be executed on. + */ +template +class FullyConnectedOp : public Operator { + public: + explicit FullyConnectedOp(Param p) { + this->param_ = p; + } + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(req[kOut], kWriteTo); + size_t expected = param_.no_bias == 0 ? 
3 : 2; + CHECK_EQ(in_data.size(), expected); + CHECK_EQ(out_data.size(), 1); + // TODO(bing): check the BLAS Handle, be careful + // maybe need blas handle from context + Stream *s = static_cast *>(ctx.stream); + Tensor data = in_data[kData].FlatTo2D(s); + Tensor wmat = in_data[kWeight].get(s); + Tensor out = out_data[kOut].FlatTo2D(s); + out = dot(data, wmat.T()); + if (param_.no_bias == 0) { + Tensor bias = in_data[kBias].get(s); + out += repmat(bias, data.size(0)); + } + } + + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), 1); + size_t expected = param_.no_bias == 0 ? 3 : 2; + CHECK(in_data.size() == expected && in_grad.size() == expected); + CHECK_EQ(req.size(), expected); + // TODO(bing): check the BLAS Handle, be careful + // maybe need blas handle from context + Stream *s = static_cast *>(ctx.stream); + Tensor data = in_data[kData].FlatTo2D(s); + Tensor wmat = in_data[kWeight].get(s); + Tensor grad = out_grad[kOut].FlatTo2D(s); + // backprop + CHECK_NE(req[kWeight], kWriteInplace) << "cannot write weight inplace"; + // gradient of weight + Tensor gwmat = in_grad[kWeight].get(s); + Assign(gwmat, req[kWeight], dot(grad.T(), data)); + // gradient of bias + if (param_.no_bias == 0) { + Tensor gbias = in_grad[kBias].get(s); + Assign(gbias, req[kBias], sum_rows(grad)); + } + // gradient of data + Tensor gdata = in_grad[kData].FlatTo2D(s); + Assign(gdata, req[kData], dot(grad, wmat)); + } + + private: + /** The param of the fully connected layer.*/ + Param param_; +}; // class FullyConnectedOp + +// Decalre Factory function, used for dispatch specialization +template +Operator* CreateFullyConnectedOp(Param param); + +#if DMLC_USE_CXX11 +class FullyConnectedProp : public OperatorProperty { + public: + virtual std::vector ListArguments() const { + if (param_.no_bias == 0) { + return {"data", "weight", "bias"}; + } else { + return {"data", "weight"}; + } + } + + virtual void SetParam(const char *name, const char *val) { + param_.SetParam(name, val); + } + + virtual bool InferShape(std::vector *in_shape, + std::vector *out_shape) const { + using namespace mshadow; + if (param_.no_bias == 0) { + CHECK_EQ(in_shape->size(), 3) << "Input:[data, weight, bias]"; + } else { + CHECK_EQ(in_shape->size(), 2) << "Input:[data, weight]"; + } + CHECK_GT(param_.num_hidden, 0); + const TShape &dshape = (*in_shape)[0]; + CHECK_EQ(dshape.ndim(), 4) << \ + "Input data should be 4D in batch-1-1-hidden"; + CHECK_NE(dshape.ndim(), 0) << "Require data shape to be known"; + ShapeAssignCheck((*in_shape)[kWeight], Shape2(param_.num_hidden, dshape[3])); + if (param_.no_bias == 0) { + ShapeAssignCheck((*in_shape)[kBias], Shape1(param_.num_hidden)); + } + out_shape->clear(); + out_shape->push_back(dshape); + (*out_shape)[0][3] = param_.num_hidden; + return true; + } + + virtual OperatorProperty* Copy() const { + FullyConnectedProp* fc_sym = new FullyConnectedProp(); + fc_sym->param_ = this->param_; + return fc_sym; + } + + virtual std::string TypeString() const { + return "FullyConnecteded"; + } + // decalre dependency and inplace optimization options + virtual std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const { + return {out_grad[kOut], in_data[kData], in_data[kWeight]}; + } + + virtual std::vector > 
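// ---------------------------------------------------------------------------
// Editor's note: the mshadow expressions in Forward/Backward above are the
// standard fully-connected gradients. With Y = X * W^T + 1 * b^T, where X is
// (batch, in), W is (num_hidden, in) and b is (num_hidden):
//   dW = dY^T * X      -> Assign(gwmat, req[kWeight], dot(grad.T(), data))
//   db = sum_rows(dY)  -> Assign(gbias, req[kBias], sum_rows(grad))
//   dX = dY * W        -> Assign(gdata, req[kData], dot(grad, wmat))
// A restatement for the reader; no code is added by this note.
// ---------------------------------------------------------------------------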
BackwardInplaceOption( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &in_grad) const { + return {{in_grad[kData], in_data[kData]}}; + } + + Operator* CreateOperator(Context ctx) const; + + private: + Param param_; +}; // class FullyConnectedSymbol +#endif +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_FULLY_CONNECTED_INL_H_ diff --git a/src/operator/fully_connected.cc b/src/operator/fully_connected.cc new file mode 100644 index 000000000000..362d3c5698aa --- /dev/null +++ b/src/operator/fully_connected.cc @@ -0,0 +1,22 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file fully_connected.cc + * \brief fully connect operator +*/ +#include +#include "./fully_connected-inl.h" +namespace mxnet { +namespace op { +template<> +Operator* CreateFullyConnectedOp(Param param) { + return new FullyConnectedOp(param); +} + +// DO_BIND_DISPATCH comes from static_operator_common.h +Operator* FullyConnectedProp::CreateOperator(Context ctx) const { + DO_BIND_DISPATCH(CreateFullyConnectedOp, param_); +} + +REGISTER_OP_PROPERTY(FullyConnected, FullyConnectedProp); +} // namespace op +} // namespace mxnet diff --git a/src/operator/fully_connected.cu b/src/operator/fully_connected.cu new file mode 100644 index 000000000000..223ef5166cc9 --- /dev/null +++ b/src/operator/fully_connected.cu @@ -0,0 +1,14 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file fully_connected.cu + * \brief fully connect operator +*/ +#include "./fully_connected-inl.h" +namespace mxnet { +namespace op { +template<> +Operator* CreateFullyConnectedOp(Param param) { + return new FullyConnectedOp(param); +} +} // namespace op +} // namespace mxnet diff --git a/src/static_operator/static_operator_common.h b/src/operator/operator_common.h similarity index 53% rename from src/static_operator/static_operator_common.h rename to src/operator/operator_common.h index 0d1553703200..87b581f28278 100644 --- a/src/static_operator/static_operator_common.h +++ b/src/operator/operator_common.h @@ -1,17 +1,17 @@ /*! * Copyright (c) 2015 by Contributors - * \file static_operator_common.h + * \file operator_common.h * \brief common internal header of most operators * this header includes utility functions operator can use - * common type definitions * \author Bing Xu */ -#ifndef MXNET_STATIC_OPERATOR_STATIC_OPERATOR_COMMON_H_ -#define MXNET_STATIC_OPERATOR_STATIC_OPERATOR_COMMON_H_ +#ifndef MXNET_OPERATOR_OPERATOR_COMMON_H_ +#define MXNET_OPERATOR_OPERATOR_COMMON_H_ #include -#include +#include #include + namespace mxnet { namespace op { /*! @@ -24,7 +24,7 @@ namespace op { */ template inline void Assign(OType &out, // NOLINT(*) - GradReqType req, + OpReqType req, const Exp &exp) { switch (req) { case kNullOp: break; @@ -49,25 +49,24 @@ inline void ShapeAssignCheck(TShape &out, const TS &shape) { // NOLINT(*) } } -/*! \brief type of operators */ -enum OpType { - kReLU = 0, - kFullc = 1, - kConv = 2, - kMaxPooling = 3, - kAvgPooling = 4, - kSumPooling = 5, - kFlatten = 6, - kReshape = 7, - kDropout = 8, -}; +// helper macro to implement bind dispatch +#if MXNET_USE_CUDA +#define DO_BIND_DISPATCH(Method, ...) \ + if (ctx.dev_mask == cpu::kDevMask) { \ + return Method(__VA_ARGS__); \ + } else { \ + return Method(__VA_ARGS__); \ + } +#else +#define DO_BIND_DISPATCH(Method, ...) \ + if (ctx.dev_mask == cpu::kDevMask) { \ + return Method(__VA_ARGS__); \ + } else { \ + LOG(FATAL) << "GPU is not enabled"; \ + return nullptr; \ + } +#endif -/*! 
- * \brief device invariant function to create operators - * \param type the type of operator - */ -template -StaticOperator *CreateOperator(OpType type); } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_STATIC_OPERATOR_COMMON_H_ +#endif // MXNET_OPERATOR_OPERATOR_COMMON_H_ diff --git a/src/static_operator/param.h b/src/operator/param.h similarity index 90% rename from src/static_operator/param.h rename to src/operator/param.h index c2829aced8ae..e1f6b4ee58d8 100644 --- a/src/static_operator/param.h +++ b/src/operator/param.h @@ -1,11 +1,13 @@ /*! * Copyright (c) 2015 by Contributors * \file param.h - * \brief operator params + * \brief Common operator parameters * \author Bing Xu */ -#ifndef MXNET_STATIC_OPERATOR_PARAM_H_ -#define MXNET_STATIC_OPERATOR_PARAM_H_ +#ifndef MXNET_OPERATOR_PARAM_H_ +#define MXNET_OPERATOR_PARAM_H_ + +#include namespace mxnet { namespace op { @@ -39,6 +41,12 @@ struct Param { int num_input_node; /*! \brief reserved fields, for future compatibility */ int reserved[64]; + + // constructor + Param() { + memset(this, 0, sizeof(Param)); + } + inline void SetParam(const char *name, const char* val) { if (!strcmp(name, "nhidden")) num_hidden = atoi(val); if (!strcmp(name, "num_input_node")) num_input_node = atoi(val); @@ -68,6 +76,5 @@ struct Param { } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_PARAM_H_ - +#endif // MXNET_OPERATOR_PARAM_H_ diff --git a/src/static_operator/activation_op-inl.h b/src/operator/static_operator/activation_op-inl.h similarity index 88% rename from src/static_operator/activation_op-inl.h rename to src/operator/static_operator/activation_op-inl.h index b1ad0d090706..cfb0b7cec8b5 100644 --- a/src/static_operator/activation_op-inl.h +++ b/src/operator/static_operator/activation_op-inl.h @@ -4,11 +4,11 @@ * \brief activation operator of mxnet */ -#ifndef MXNET_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ -#define MXNET_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ #include -#include +#include #include #include "./static_operator_common.h" @@ -39,6 +39,7 @@ class ActivationOp : public StaticOperator { virtual void Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, + const std::vector &out_data, const std::vector &out_grad, const std::vector &req) { CHECK_EQ(grad_next.size(), 1); @@ -57,4 +58,4 @@ class ActivationOp : public StaticOperator { } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ +#endif // MXNET_OPERATOR_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ diff --git a/src/static_operator/convolution_op-inl.h b/src/operator/static_operator/convolution_op-inl.h similarity index 97% rename from src/static_operator/convolution_op-inl.h rename to src/operator/static_operator/convolution_op-inl.h index 0f7c5ccbb631..fc9b3369f2a6 100644 --- a/src/static_operator/convolution_op-inl.h +++ b/src/operator/static_operator/convolution_op-inl.h @@ -4,10 +4,10 @@ * \brief convolution op * \author Bing Xu */ -#ifndef MXNET_STATIC_OPERATOR_CONVOLUTION_OP_INL_H_ -#define MXNET_STATIC_OPERATOR_CONVOLUTION_OP_INL_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_CONVOLUTION_OP_INL_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_CONVOLUTION_OP_INL_H_ -#include +#include #include #include #include "./static_operator_common.h" @@ -134,6 +134,7 @@ class ConvolutionOp : public StaticOperator { virtual void Backward(RunContext ctx, const std::vector &grad_next, 
const std::vector &in_data, + const std::vector &out_data, const std::vector &out_grad, const std::vector &req) { using namespace mshadow; @@ -266,4 +267,4 @@ class ConvolutionOp : public StaticOperator { }; // class ConvolutionOp } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_CONVOLUTION_OP_INL_H_ +#endif // MXNET_OPERATOR_STATIC_OPERATOR_CONVOLUTION_OP_INL_H_ diff --git a/src/static_operator/dropout_op-inl.h b/src/operator/static_operator/dropout_op-inl.h similarity index 91% rename from src/static_operator/dropout_op-inl.h rename to src/operator/static_operator/dropout_op-inl.h index aba19ad3c88b..23c9f6aab457 100644 --- a/src/static_operator/dropout_op-inl.h +++ b/src/operator/static_operator/dropout_op-inl.h @@ -4,10 +4,10 @@ * \brief dropout operator * \author Bing Xu */ -#ifndef MXNET_STATIC_OPERATOR_DROPOUT_OP_INL_H_ -#define MXNET_STATIC_OPERATOR_DROPOUT_OP_INL_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_DROPOUT_OP_INL_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_DROPOUT_OP_INL_H_ -#include +#include #include #include "./mshadow_op.h" @@ -59,6 +59,7 @@ class DropoutOp : public StaticOperator { virtual void Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, + const std::vector &out_data, const std::vector &out_grad, const std::vector &req) { CHECK_EQ(grad_next.size(), 1); @@ -90,4 +91,4 @@ class DropoutOp : public StaticOperator { }; // class DropoutOp } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_DROPOUT_OP_INL_H_ +#endif // MXNET_OPERATOR_STATIC_OPERATOR_DROPOUT_OP_INL_H_ diff --git a/src/static_operator/mshadow_op.h b/src/operator/static_operator/mshadow_op.h similarity index 92% rename from src/static_operator/mshadow_op.h rename to src/operator/static_operator/mshadow_op.h index 2954b1f81a48..bb33471f168a 100644 --- a/src/static_operator/mshadow_op.h +++ b/src/operator/static_operator/mshadow_op.h @@ -4,8 +4,8 @@ * \brief extra mshadow operation for mxnet * \author Bing Xu */ -#ifndef MXNET_STATIC_OPERATOR_MSHADOW_OP_H_ -#define MXNET_STATIC_OPERATOR_MSHADOW_OP_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_MSHADOW_OP_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_MSHADOW_OP_H_ #include #include @@ -102,5 +102,5 @@ struct square_root { } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_MSHADOW_OP_H_ +#endif // MXNET_OPERATOR_STATIC_OPERATOR_MSHADOW_OP_H_ diff --git a/src/static_operator/pooling_op-inl.h b/src/operator/static_operator/pooling_op-inl.h similarity index 95% rename from src/static_operator/pooling_op-inl.h rename to src/operator/static_operator/pooling_op-inl.h index e4bf344f7e5a..8c6014a8c2cf 100644 --- a/src/static_operator/pooling_op-inl.h +++ b/src/operator/static_operator/pooling_op-inl.h @@ -4,10 +4,10 @@ * \brief pooling operator * \author Bing Xu */ -#ifndef MXNET_STATIC_OPERATOR_POOLING_OP_INL_H_ -#define MXNET_STATIC_OPERATOR_POOLING_OP_INL_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_POOLING_OP_INL_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_POOLING_OP_INL_H_ -#include +#include #include #include #include "./param.h" @@ -88,6 +88,7 @@ class PoolingOp : public StaticOperator { virtual void Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, + const std::vector &out_data, const std::vector &out_grad, const std::vector &req) { CHECK_EQ(grad_next.size(), 1); @@ -149,4 +150,4 @@ class PoolingOp : public StaticOperator { } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_POOLING_OP_INL_H_ +#endif // 
MXNET_OPERATOR_STATIC_OPERATOR_POOLING_OP_INL_H_ diff --git a/src/static_operator/reshape_op-inl.h b/src/operator/static_operator/reshape_op-inl.h similarity index 89% rename from src/static_operator/reshape_op-inl.h rename to src/operator/static_operator/reshape_op-inl.h index eb05a460573d..ba966a62a29f 100644 --- a/src/static_operator/reshape_op-inl.h +++ b/src/operator/static_operator/reshape_op-inl.h @@ -4,10 +4,10 @@ * \brief * \author Bing Xu */ -#ifndef MXNET_STATIC_OPERATOR_RESHAPE_OP_INL_H_ -#define MXNET_STATIC_OPERATOR_RESHAPE_OP_INL_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_RESHAPE_OP_INL_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_RESHAPE_OP_INL_H_ -#include +#include #include namespace mxnet { @@ -52,6 +52,7 @@ class ReshapeOp : public StaticOperator { virtual void Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, + const std::vector &out_data, const std::vector &out_grad, const std::vector &req) { CHECK_EQ(grad_next.size(), 1); @@ -72,4 +73,4 @@ class ReshapeOp : public StaticOperator { } // namespace op } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_RESHAPE_OP_INL_H_ +#endif // MXNET_OPERATOR_STATIC_OPERATOR_RESHAPE_OP_INL_H_ diff --git a/src/operator/static_operator_wrapper.cc b/src/operator/static_operator_wrapper.cc deleted file mode 100644 index 97ed3b307291..000000000000 --- a/src/operator/static_operator_wrapper.cc +++ /dev/null @@ -1,97 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file static_operator.cc - * \brief the implementation of static operator - * \author Naiyan Wang - */ -#include -#include -#include -#include -#include -#include -#include -#include - -namespace mxnet { -namespace op { -/*! - * \brief StaticOperatorWrapper that wraps a static_operator - * This class do not need to be seen by others, so it sit in cc file. 
- * \sa Operator, StaticOperator - */ -class StaticOperatorWrapper: public Operator { - public: - StaticOperatorWrapper(StaticOperator* op, Context ctx) - : op_(op), ctx_(ctx) {} - - virtual ~StaticOperatorWrapper() { - delete op_; - } - - virtual int DescribeProperty() const { - return op_->DescribeProperty(); - } - - virtual void Forward(Option opt, - RunContext ctx, - const std::vector &in_data, - const std::vector &out_data) { - std::vector used_var; - std::vector mutate_var; - std::vector in; - std::vector out; - for (size_t i = 0; i < in_data.size(); ++i) { - used_var.push_back(in_data[i].var()); - in.push_back(in_data[i].data()); - } - for (size_t i = 0; i < out_data.size(); ++i) { - mutate_var.push_back(out_data[i].var()); - out.push_back(out_data[i].data()); - } - DAGEngine::Get()->Push([this, opt, ctx, in, out](RunContext ctx) { - op_->Forward(opt, ctx, in, out); - }, ctx_, used_var, mutate_var); - } - - virtual void Backward(RunContext ctx, - const std::vector &grad_next, - const std::vector &in_data, - const std::vector &out_grad, - const std::vector &req) { - std::vector used_var; - std::vector mutate_var; - std::vector grad_in; - std::vector grad_out; - std::vector data; - for (size_t i = 0; i < grad_next.size(); ++i) { - used_var.push_back(grad_next[i].var()); - grad_in.push_back(grad_next[i].data()); - } - for (size_t i = 0; i < in_data.size(); ++i) { - used_var.push_back(in_data[i].var()); - data.push_back(in_data[i].data()); - } - for (size_t i = 0; i < out_grad.size(); ++i) { - mutate_var.push_back(out_grad[i].var()); - grad_out.push_back(out_grad[i].data()); - } - DAGEngine::Get()->Push([this, ctx, grad_in, grad_out, data, req](RunContext ctx) { - op_->Backward(ctx, grad_in, data, grad_out, req); - }, ctx_, used_var, mutate_var); - } - - private: - /* \brief the static operator */ - StaticOperator* op_; - /** \brief the global context denots the device info. */ - Context ctx_; -}; -} // namespace op - -// implements CreateWrapper -Operator *Operator::CreateWrapper(StaticOperator *op, Context ctx) { - return new op::StaticOperatorWrapper(op, ctx); -} - -} // namespace mxnet diff --git a/src/registry.cc b/src/registry.cc index d51907dbdcd7..42fef1df3423 100644 --- a/src/registry.cc +++ b/src/registry.cc @@ -6,7 +6,7 @@ #include #include #include -#include +#include namespace mxnet { @@ -30,25 +30,7 @@ template NArrayFunctionEntry &Registry::Register(const std: template Registry *Registry::Get(); #endif -Symbol *AtomicSymbolEntry::GetSingletonSymbol() { - if (singleton_symbol) { - return singleton_symbol; - } else if (body && !use_param) { - singleton_symbol = new Symbol; - *singleton_symbol = Symbol::Create(body()); - return singleton_symbol; - } else { - return NULL; - } -} - -AtomicSymbolEntry::~AtomicSymbolEntry() { - if (singleton_symbol) { - delete singleton_symbol; - } -} - -template AtomicSymbolEntry &Registry::Register(const std::string& name); -template Registry *Registry::Get(); +template OperatorPropertyEntry &Registry::Register(const std::string& name); +template Registry *Registry::Get(); } // namespace mxnet diff --git a/src/static_operator/fully_connect_op-inl.h b/src/static_operator/fully_connect_op-inl.h deleted file mode 100644 index 062e17c8ea98..000000000000 --- a/src/static_operator/fully_connect_op-inl.h +++ /dev/null @@ -1,163 +0,0 @@ -/*! 
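// ---------------------------------------------------------------------------
// Editor's note: the wrapper deleted above illustrates the engine contract
// the refactor keeps -- every NArray touched is declared to the DAG engine as
// read (used_var) or written (mutate_var) before the closure is pushed. A
// condensed sketch of that pattern; the vector element type is assumed, since
// the patch text elides template arguments.
#include <vector>

void SketchEnginePush(mxnet::Context ctx,
                      const std::vector<mxnet::NArray> &reads,
                      const std::vector<mxnet::NArray> &writes) {
  std::vector<mxnet::DAGEngine::Variable> used_var, mutate_var;
  for (const auto &r : reads) used_var.push_back(r.var());
  for (const auto &w : writes) mutate_var.push_back(w.var());
  mxnet::DAGEngine::Get()->Push(
      [](mxnet::RunContext rctx) { /* run the kernel on rctx.stream */ },
      ctx, used_var, mutate_var);
}
// ---------------------------------------------------------------------------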
- * Copyright (c) 2015 by Contributors - * \file fully_connect_op-inl.h - * \brief fully connect operator and symbol -*/ -#ifndef MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ -#define MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ - -#include -#include -#include -#include -#include -#include "./static_operator_common.h" -#include "./param.h" - -namespace mxnet { -namespace op { -/** - * \brief This is the implementation of fully connected layer. - * - * \tparam xpu The device that the op will be executed on. - */ -template -class FullyConnectOp : public StaticOperator { - public: - /*! - * \brief constructor with parameters. Used in Bind() in corresponding symbol. - */ - explicit FullyConnectOp(Param p) { - this->param_ = p; - } - - virtual void Forward(Option opt, - RunContext ctx, - const std::vector &in_data, - const std::vector &out_data) { - using namespace mshadow; - using namespace mshadow::expr; - size_t expected = param_.no_bias == 0 ? 3 : 2; - CHECK_EQ(in_data.size(), expected); - CHECK_EQ(out_data.size(), 1); - // TODO(bing): check the BLAS Handle, be careful - // maybe need blas handle from context - Stream *s = static_cast *>(ctx.stream); - Tensor data = in_data[0].FlatTo2D(s); - Tensor wmat = in_data[1].get(s); - Tensor out = out_data[0].FlatTo2D(s); - out = dot(data, wmat.T()); - if (param_.no_bias == 0) { - Tensor bias = in_data[2].get(s); - out += repmat(bias, data.size(0)); - } - } - - virtual void Backward(RunContext ctx, - const std::vector &grad_next, - const std::vector &in_data, - const std::vector &out_grad, - const std::vector &req) { - using namespace mshadow; - using namespace mshadow::expr; - CHECK_EQ(grad_next.size(), 1); - size_t expected = param_.no_bias == 0 ? 3 : 2; - CHECK(in_data.size() == expected && out_grad.size() == expected); - CHECK_EQ(req.size(), expected); - // TODO(bing): check the BLAS Handle, be careful - // maybe need blas handle from context - Stream *s = static_cast *>(ctx.stream); - Tensor data = in_data[0].FlatTo2D(s); - Tensor wmat = in_data[1].get(s); - Tensor grad = grad_next[0].FlatTo2D(s); - // backprop - CHECK_NE(req[1], kWriteInplace) << "cannot write weight inplace"; - // gradient of weight - Tensor gwmat = out_grad[1].get(s); - Assign(gwmat, req[1], dot(grad.T(), data)); - // gradient of bias - if (param_.no_bias == 0) { - Tensor gbias = out_grad[2].get(s); - Assign(gbias, req[2], sum_rows(grad)); - } - // gradient of data - Tensor gdata = out_grad[0].FlatTo2D(s); - Assign(gdata, req[0], dot(grad, wmat)); - } - - private: - /** The param of the fully connected layer.*/ - Param param_; -}; // class FullyConnectOp - -/** - * @brief The symbol part of the fully connected layer. 
- */ -class FullyConnectSymbol : public AtomicSymbol { - public: - virtual std::vector DescribeArguments() const { - std::string ret[] = {"data", "weight", "bias"}; - if (param_.no_bias == 0) { - return std::vector(ret, ret + 3); - } else { - return std::vector(ret, ret + 2); - } - } - - virtual void SetParam(const char *name, const char *val) { - param_.SetParam(name, val); - } - - virtual bool InferShape(std::vector *in_shape, - std::vector *out_shape) const { - using namespace mshadow; - if (param_.no_bias == 0) { - CHECK_EQ(in_shape->size(), 3) << "Input:[data, weight, bias]"; - } else { - CHECK_EQ(in_shape->size(), 2) << "Input:[data, weight]"; - } - CHECK_GT(param_.num_hidden, 0); - const TShape &dshape = (*in_shape)[0]; - CHECK_EQ(dshape.ndim(), 4) << \ - "Input data should be 4D in batch-1-1-hidden"; - CHECK_NE(dshape.ndim(), 0) << "Require data shape to be known"; - ShapeAssignCheck((*in_shape)[1], Shape2(param_.num_hidden, dshape[3])); - if (param_.no_bias == 0) { - ShapeAssignCheck((*in_shape)[2], Shape1(param_.num_hidden)); - } - out_shape->clear(); - out_shape->push_back(dshape); - (*out_shape)[0][3] = param_.num_hidden; - return true; - } - - virtual AtomicSymbol* Copy() const { - FullyConnectSymbol* fc_sym = new FullyConnectSymbol(); - fc_sym->param_ = this->param_; - return fc_sym; - } - - virtual std::string TypeString() const { - return "FullyConnected"; - } - - /** - * @brief This is the template function of bind() implementation. - * - * @param ctx The device context - * @return A device dependent static operator can be used for execution. - */ - template - StaticOperator* Bind_(Context ctx) const; - // the real bind - StaticOperator* Bind(Context ctx) const; - - private: - /** The param of the fully connected layer.*/ - Param param_; -}; // class FullyConnectSymbol - -} // namespace op -} // namespace mxnet - -#endif // MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ diff --git a/src/static_operator/fully_connect_op.cc b/src/static_operator/fully_connect_op.cc deleted file mode 100644 index 9f3cad3292b0..000000000000 --- a/src/static_operator/fully_connect_op.cc +++ /dev/null @@ -1,33 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file fully_connect_sym.cc - * \brief fully connect operator symbol -*/ -#include -#include "../static_operator/fully_connect_op-inl.h" -namespace mxnet { -namespace op { -template<> -StaticOperator* FullyConnectSymbol::Bind_(Context ctx) const { - return new FullyConnectOp(param_); -} - -// put this after the template specialization -StaticOperator* FullyConnectSymbol::Bind(Context ctx) const { - if (ctx.dev_mask == cpu::kDevMask) { - return Bind_(ctx); - } else { - #if MXNET_USE_CUDA - return Bind_(ctx); - #else - LOG(FATAL) << "GPU is not enabled"; - return NULL; - #endif - } -} - -// register the symbol -REGISTER_ATOMIC_SYMBOL(FullyConnected, FullyConnectSymbol); - -} // namespace op -} // namespace mxnet diff --git a/src/static_operator/fully_connect_op.cu b/src/static_operator/fully_connect_op.cu deleted file mode 100644 index 8e3efbcaddfd..000000000000 --- a/src/static_operator/fully_connect_op.cu +++ /dev/null @@ -1,14 +0,0 @@ -/*! 
- * Copyright (c) 2015 by Contributors - * \file fully_connect_sym.cu - * \brief fully connect operator symbol -*/ -#include "../static_operator/fully_connect_op-inl.h" -namespace mxnet { -namespace op { -template<> -StaticOperator* FullyConnectSymbol::Bind_(Context ctx) const { - return new FullyConnectOp(param_); -} -} // namespace op -} // namespace mxnet diff --git a/src/static_operator/static_operator-inl.h b/src/static_operator/static_operator-inl.h deleted file mode 100644 index 99776b3db621..000000000000 --- a/src/static_operator/static_operator-inl.h +++ /dev/null @@ -1,49 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file static_operator-inl.h - * \brief static device invarient code to create operators - * \author Bing Xu -*/ -#ifndef MXNET_STATIC_OPERATOR_STATIC_OPERATOR_INL_H_ -#define MXNET_STATIC_OPERATOR_STATIC_OPERATOR_INL_H_ -#include -#include -#include -#include "./mshadow_op.h" -#include "./activation_op-inl.h" -#include "./convolution_op-inl.h" -#include "./pooling_op-inl.h" -#include "./reshape_op-inl.h" -#include "./dropout_op-inl.h" - -namespace mxnet { -namespace op { -/*! - * \brief device invariant function to create operators - * \param type the type of operator - * \tparam xpu the device type we are at - */ -template -inline StaticOperator *CreateOperator_(OpType type, mshadow::Random *prnd) { - switch (type) { - case kReLU: - return new ActivationOp(); - case kConv: - return new ConvolutionOp(); - case kMaxPooling: - return new PoolingOp(); - case kAvgPooling: - return new PoolingOp(); - case kFlatten: - return new ReshapeOp(); - case kReshape: - return new ReshapeOp(); - case kDropout: - return new DropoutOp(prnd); - default: LOG(FATAL) << "unknown OpType"; - } - return NULL; -} -} // namespace op -} // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_STATIC_OPERATOR_INL_H_ diff --git a/src/static_operator/static_operator.cc b/src/static_operator/static_operator.cc deleted file mode 100644 index 4a2a121532dd..000000000000 --- a/src/static_operator/static_operator.cc +++ /dev/null @@ -1,44 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file static_operator.cc - * \brief - * \author: Bing Xu - */ -#include -#include -#include -#include -#include "./static_operator_common.h" - -namespace mxnet { -namespace op { -/** - * @brief return a OpType based on string description - * - * @param type the string description of operators - * @return the OpType indicated the type of operators - */ -OpType GetOpType(const char *type) { - if (!strcmp(type, "relu")) return kReLU; - if (!strcmp(type, "fullc")) return kFullc; - LOG(FATAL) << "unknown operator type " << type; - return kReLU; -} -} // namespace op - -StaticOperator *StaticOperator::Create(const char *type, - Context ctx) { - op::OpType otype = op::GetOpType(type); - if (ctx.dev_mask == cpu::kDevMask) { - return op::CreateOperator(otype); - } - if (ctx.dev_mask == gpu::kDevMask) { -#if MXNET_USE_CUDA - return op::CreateOperator(otype); -#else - LOG(FATAL) << "GPU is not enabled"; -#endif - } - return NULL; -} // namespace op -} // namespace mxnet diff --git a/src/static_operator/static_operator_cpu.cc b/src/static_operator/static_operator_cpu.cc deleted file mode 100644 index 5b6ea861213b..000000000000 --- a/src/static_operator/static_operator_cpu.cc +++ /dev/null @@ -1,20 +0,0 @@ -/*! 
- * Copyright (c) 2015 by Contributors - * \file static_operator_cpu.cc - * \brief CPU specialization of operator codes - * \author Bing Xu -*/ -#include "./static_operator-inl.h" - -namespace mxnet { -namespace op { -// todo add managing for prnd -mshadow::Random prnd_cpu(0); - -template<> -StaticOperator *CreateOperator(OpType type) { - return CreateOperator_(type, &prnd_cpu); -} - -} // namespace op -} // namespace mxnet diff --git a/src/static_operator/static_operator_gpu.cu b/src/static_operator/static_operator_gpu.cu deleted file mode 100644 index 580fe65d630d..000000000000 --- a/src/static_operator/static_operator_gpu.cu +++ /dev/null @@ -1,23 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file static_operator_gpu.cu - * \brief GPU specialization of operator code - * \author Bing Xu -*/ -#include -#include -#include "static_operator-inl.h" - -namespace mxnet { -namespace op { - -mshadow::Random prnd_gpu(0); - -template<> -StaticOperator *CreateOperator(OpType type) { - return CreateOperator_(type, &prnd_gpu); -} - -} // namespace op -} // namespace mxnet - diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc new file mode 100644 index 000000000000..5419e26afe86 --- /dev/null +++ b/src/symbol/static_graph.cc @@ -0,0 +1,87 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file static_graph.cc + * \brief static graph of mxnet + */ +#include +#include +#include +#include + +namespace mxnet { +std::vector StaticGraph::TopoSort() const { + std::vector out_degree(nodes.size(), 0); + for (const Node &n : nodes) { + for (const DataEntry &e : n.inputs) { + ++out_degree[e.source_id]; + } + } + std::vector ret(nodes.size()); + auto result = ret.rbegin(); + std::queue queue; + for (size_t i = 0; i < nodes.size(); ++i) { + if (out_degree[i] == 0) { + queue.push(static_cast(i)); + } + } + while (!queue.empty()) { + uint32_t node_id = queue.front(); + queue.pop(); + *result = node_id; + ++result; + for (const DataEntry &e : nodes[node_id].inputs) { + out_degree[e.source_id] -= 1; + if (out_degree[e.source_id] == 0) { + queue.push(e.source_id); + } + } + } + return std::move(ret); +} + +bool StaticGraph::InferNodeShapes(const std::vector &topo_order, + std::vector > *node_out_shapes) const { + for (uint32_t nid : topo_order) { + const Node &node = nodes[nid]; + if (node.op != nullptr) { + std::vector in_shape; + for (const DataEntry &e : node.inputs) { + in_shape.push_back((*node_out_shapes)[e.source_id][e.index]); + } + if (!node.op->InferShape(&in_shape, &(*node_out_shapes)[nid])) return false; + for (size_t i = 0; i < node.inputs.size(); ++i) { + const DataEntry &e = node.inputs[i]; + (*node_out_shapes)[e.source_id][e.index] = in_shape[i]; + } + } + } + return true; +} + +bool StaticGraph::InferShape(std::vector *in_shape, + std::vector *out_shape) const { + std::vector > node_out_shapes(nodes.size()); + for (size_t i = 0; i < nodes.size(); ++i) { + int nout = 1; + if (nodes[i].op != nullptr) { + nout = nodes[i].op->NumReturns(); + } + node_out_shapes[i].resize(nout); + } + CHECK(in_shape->size() == arg_nodes.size()) + << "Wrong number of inputs to infer shape"; + for (size_t i = 0; i < arg_nodes.size(); ++i) { + node_out_shapes[arg_nodes[i]][0] = (*in_shape)[i]; + } + if (!InferNodeShapes(this->TopoSort(), + &node_out_shapes)) return false; + for (size_t i = 0; i < arg_nodes.size(); ++i) { + (*in_shape)[i] = node_out_shapes[arg_nodes[i]][0]; + } + for (size_t i = 0; i < outputs.size(); ++i) { + DataEntry e = outputs[i]; + (*out_shape)[i] = 
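// ---------------------------------------------------------------------------
// Editor's note on TopoSort() above: out_degree[i] counts how many input
// edges reference node i, i.e. its number of consumers, so the queue starts
// from the graph outputs and walks toward the arguments. Because 'result'
// fills the vector from the back (rbegin), the returned order is still
// producers-before-consumers -- the order InferNodeShapes iterates in. A
// restatement for the reader; no code is added by this note.
// ---------------------------------------------------------------------------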
node_out_shapes[e.source_id][e.index]; + } + return true; +} +} // namespace mxnet diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index a81b7ce0cccd..86cf54feabfa 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -4,185 +4,432 @@ * \brief symbol of mxnet */ #include -#include -#include -#include -#include +#include +#include +#include +#include +#include namespace mxnet { +/*! + * \brief Node is represents node of an operator in the symbolic graph. + * + * It stores connection to the inputs to function represented by OperatorProperty + * NOTE on data structure: there are three types of node: + * - Normal node: contains all the necessary elements of a graph. + * - OperatorProperty: the inputs_ is empty, represents an OperatorProperty that has not been applied. + * - Variable: the sym_ is nullptr, represents an named Variable of tensors that can be composed. + */ +struct Symbol::Node { + /*! \brief Operator of this node */ + std::unique_ptr op; + /*! \brief name of the node */ + std::string name; + /*! \brief inputs to this node */ + std::vector inputs; + /*! + * \brief constructor + * \param op the OperatorProperty to construct the Node + * \param name the name of the symbol + */ + explicit Node(OperatorProperty* op = nullptr, const std::string& name = "") + : op(op), name(name) { + } + /*! \return Whether the symbol is atomic */ + inline bool is_atomic() const { + return inputs.size() == 0 && op != nullptr; + } + /*! \return Whether it is unit variable */ + inline bool is_variable() const { + return op == nullptr; + } +}; -void Symbol::FindArgUsers() { - arg_users_.reset(new std::vector >); - // depth first traversing - std::vector > stk; - stk.push_back({head_.get(), 0}); - while (!stk.empty()) { - std::pair& back = stk.back(); - if (back.first->in_symbol_.size() == back.second) { - stk.pop_back(); - } else { - Node* next_level = back.first->in_symbol_[back.second].get(); - if (next_level->sym_) { - stk.push_back({next_level, 0}); - } else { // back uses next_level which is a placeholder - arg_users_->push_back({back.first, back.second}); +/*! 
\return whwther the symbol is atomic */ +inline bool Symbol::is_atomic() const { + return heads_.size() == 1 && heads_[0].source->is_atomic(); +} +// implementation of template functions +template +inline void Symbol::DFSVisit(FVisit fvisit) const { + std::vector stack; + std::unordered_set visited; + // put the head into the graph + for (auto &head : heads_) { + Node *ptr = head.source.get(); + if (visited.count(ptr) == 0) { + stack.push_back(ptr); + visited.insert(ptr); + } + } + while (!stack.empty()) { + Node* back = stack.back(); + stack.pop_back(); + fvisit(back); + for (auto it = back->inputs.rbegin(); it != back->inputs.rend(); ++it) { + Node *ptr = it->source.get(); + if (visited.count(ptr) == 0) { + stack.push_back(ptr); + visited.insert(ptr); } - back.second += 1; } } } +int Symbol::FindDuplicateArgs(std::unordered_map *out) const { + out->clear(); + int max_dup = 1; + this->DFSVisit([out, &max_dup](Node *node) { + if (node->is_variable()) { + auto iter = out->find(node->name); + if (iter == out->end()) { + (*out)[node->name] = 1; + } else { + ++iter->second; + max_dup = std::max(max_dup, iter->second); + } + } + }); + return max_dup; +} + +// public functions Symbol Symbol::Copy() const { - Symbol s; std::unordered_map > old_new; - std::vector stk; - stk.push_back(head_.get()); - // copy nodes - while (!stk.empty()) { - Node* back = stk.back(); - stk.pop_back(); - if (old_new.count(back) == 0) { - if (back->sym_) { - old_new[back] = std::make_shared(back->sym_->Copy(), back->name_); + // use DFSVisit to copy all the nodes + this->DFSVisit([&old_new](Node *node) { + if (node->op == nullptr) { + old_new[node] = std::make_shared(nullptr, node->name); } else { - old_new[back] = std::make_shared(nullptr, back->name_); - } - } - for (const std::shared_ptr& n : back->in_symbol_) { - if (old_new.count(n.get()) == 0) { - stk.push_back(n.get()); + old_new[node] = std::make_shared(node->op->Copy(), node->name); } + }); + // connect nodes of new graph + for (const auto &kv : old_new) { + for (const DataEntry& n : kv.first->inputs) { + Node *ptr = n.source.get(); + kv.second->inputs.push_back(DataEntry(old_new[ptr], n.index)); } } - // connect nodes - for (auto kv : old_new) { - for (const std::shared_ptr& n : kv.first->in_symbol_) { - kv.second->in_symbol_.push_back(old_new[n.get()]); + // set the head + Symbol s; + for (auto &head : heads_) { + s.heads_.push_back(DataEntry(old_new[head.source.get()], head.index)); + } + return s; +} + +void Symbol::Print(std::ostream &os) const { + if (this->is_atomic()) { + os << "AtomicFunction "<< " Type:" << heads_[0].source->op->TypeString() << '\n' + << "Inputs:"; + std::vector args = this->ListArguments(); + for (size_t i = 0; i < args.size(); ++i) { + os << "\targ[" << i << "]=" << args[i] << "\n"; + } + } else { + // use DFSVisit to copy all the nodes + os << "Outputs:\n"; + for (size_t i = 0; i < heads_.size(); ++i) { + os << "\toutput[" << i << "]=" << heads_[i].source->name + << '(' << heads_[i].index << ")\n"; } + this->DFSVisit([&os](Node *node) { + if (node->is_variable()) { + os << "Variable:" << node->name << '\n'; + } else { + os << "Name: " << node->name << " Type:" << node->op->TypeString() << '\n' + << "Inputs:\n"; + for (size_t i = 0; i < node->inputs.size(); ++i) { + os << "\targ[" << i << "]=" << node->inputs[i].source->name + << '(' << node->inputs[i].index << ")\n"; + } + } + }); } - s.head_ = old_new[this->head_.get()]; - // copy arg_users_ - if (arg_users_) { - s.arg_users_.reset(new std::vector >); - 
std::transform(arg_users_->begin(), arg_users_->end(), std::back_inserter(*s.arg_users_), - [&old_new](const std::pair& n) -> std::pair { - return { old_new[n.first].get(), n.second }; - }); +} + +std::vector Symbol::ListArguments() const { + std::vector ret; + if (this->is_atomic()) { + return heads_[0].source->op->ListArguments(); + } else { + this->DFSVisit([&ret](Node *node) { + if (node->is_variable()) { + ret.push_back(node->name); + } + }); + return ret; } - return s; } -Symbol Symbol::operator () (const std::vector& args) const { - Symbol s = this->Copy(); - if (!s.arg_users_) { // if arg_users_ has not been populated - s.FindArgUsers(); +std::vector Symbol::ListReturns() const { + std::vector ret; + for (auto &head : heads_) { + if (head.source->is_variable()) { + ret.push_back(head.source->name); + } else { + // TODO(bing) rethink about output naming + auto &hname = head.source->name; + std::string rname = head.source->op->ListReturns()[head.index]; + if (hname.length() == 0) { + ret.push_back(std::move(rname)); + } else { + ret.push_back(hname + '_' + rname); + } + } + } + return std::move(ret); +} + +Symbol Symbol::operator[] (size_t index) const { + size_t nreturn = NumReturns(); + CHECK_LT(index, nreturn) << "Symbol only accept nonnegative index"; + if (nreturn == 1) { + return *this; + } else { + Symbol s; + s.heads_.push_back(heads_[index]); + return s; } - CHECK_LT(args.size(), s.arg_users_->size()) << "Too many args, requires " << s.arg_users_->size() - << " provided " << args.size(); +} + +void Symbol::Compose(const std::vector& args, + const std::string& name) { + CHECK_EQ(NumReturns(), 1) << "Only composition of value function is supported currently"; + CHECK(!heads_[0].source->is_variable()) << "Variable cannot be composed"; + heads_[0].source->name = name; for (size_t i = 0; i < args.size(); ++i) { - const std::pair& arg_user = (*s.arg_users_)[i]; - arg_user.first->in_symbol_[arg_user.second] = args[i].head_; - CHECK_NE(args[i].index_, -1) << "Argument " << i << " is a tuple, scalar is required"; - arg_user.first->in_index_[arg_user.second] = args[i].index_; + CHECK_NE(args[i].NumReturns(), 1) + << "Argument " << i << " is a tuple, scalar is required"; + } + // positional arguments requires all arguments for now. 
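// ---------------------------------------------------------------------------
// Editor's note: sketch of multi-output indexing via the operator[] defined
// above; 'lhs' and 'rhs' are illustrative symbols, and CreateGroup is
// declared later in this file.
mxnet::Symbol TakeFirstOutput(const mxnet::Symbol &lhs,
                              const mxnet::Symbol &rhs) {
  mxnet::Symbol group = mxnet::Symbol::CreateGroup({lhs, rhs});
  return group[0];  // single-output view; the index is CHECK_LT-ed against
                    // NumReturns()
}
// ---------------------------------------------------------------------------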
+ // TODO(bing) consider partial assignments + if (this->is_atomic()) { + // atomic symbol do not have place holder for all the arguments + std::vector req_args = heads_[0].source->op->ListArguments(); + CHECK_EQ(args.size(), req_args.size()) + << "Incorrect number of arguments, requires " << req_args.size() + << ", provided " << args.size(); + heads_[0].source->inputs.resize(args.size()); + for (size_t i = 0; i < args.size(); ++i) { + heads_[0].source->inputs[i] = args[i].heads_[0]; + } + } else { + // find all the place holders + size_t arg_counter = 0; + std::unordered_map replace_map; + std::vector > replace_plan; + // replace map stores the existing replacement plan for arguments node + this->DFSVisit([&arg_counter, &replace_map, &replace_plan, &args](Node *node) { + // visit all the childs, find possible replacement + for (size_t i = 0; i < node->inputs.size(); ++i) { + DataEntry *e = &(node->inputs[i]); + if (e->source->is_variable()) { + const DataEntry *target = nullptr; + auto iter = replace_map.find(e->source.get()); + if (iter == replace_map.end()) { + if (arg_counter < args.size()) { + target = &(args[arg_counter].heads_[0]); + replace_map[e->source.get()] = target; + } + ++arg_counter; + } else { + target = iter->second; + } + replace_plan.push_back(std::make_pair(e, target)); + } + } + }); + CHECK_EQ(args.size(), arg_counter) + << "Incorrect number of arguments, requires " << arg_counter + << ", provided " << args.size(); + // now run the replacement + for (const auto& kv : replace_plan) { + *(kv.first) = *(kv.second); + } } - s.arg_users_.reset(); - return s; } -Symbol Symbol::operator () (const std::unordered_map& kwargs) const { +void Symbol::Compose(const std::unordered_map& kwargs, + const std::string& name) { + CHECK_EQ(NumReturns(), 1) << "Only composition of value function is supported currently"; + CHECK(!heads_[0].source->is_variable()) << "Variable cannot be composed"; + heads_[0].source->name = name; + for (const auto& kv : kwargs) { + CHECK_EQ(kv.second.NumReturns(), 1) + << "Keyword Argument " << kv.first << " is a tuple, scalar is required"; + } + size_t nmatched = 0; + if (this->is_atomic()) { + // atomic symbol do not have place holder for all the arguments + std::vector req_args = heads_[0].source->op->ListArguments(); + heads_[0].source->inputs.resize(req_args.size()); + for (size_t i = 0; i < req_args.size(); ++i) { + auto iter = kwargs.find(req_args[i]); + if (iter != kwargs.end()) { + heads_[0].source->inputs[i] = iter->second.heads_[0]; + ++nmatched; + } else { + // create a variable node + // TODO(bing): think of naming convention + if (name.length() == 0) { + heads_[0].source->inputs[i] = DataEntry( + std::make_shared(nullptr, req_args[i]), 0); + } else { + heads_[0].source->inputs[i] = DataEntry( + std::make_shared(nullptr, name + '_' + req_args[i]), 0); + } + } + } + // if things goes wrong recover the old state + if (nmatched != kwargs.size()) { + heads_[0].source->inputs.clear(); + } + } else { + // find all the arguments positions + std::unordered_map dup_args; + int max_dup = this->FindDuplicateArgs(&dup_args); + if (max_dup > 1) { + for (const auto& kv : dup_args) { + CHECK_EQ(kv.second, 1) + << " Argument name=\"" << kv.first << "\" occured in " + << kv.second << " places in the Symbol, " + << "Keyword argument call is not supported because this duplication."; + } + } + CHECK_EQ(max_dup, 1); + std::vector > replace_plan; + std::unordered_set visited; + // replace map stores the existing replacement plan for arguments node + 
+    this->DFSVisit([&nmatched, &visited, &kwargs, &replace_plan](Node *node) {
+        // visit all the children, find possible replacements
+        for (size_t i = 0; i < node->inputs.size(); ++i) {
+          DataEntry *e = &(node->inputs[i]);
+          if (e->source->is_variable()) {
+            const DataEntry *target = nullptr;
+            auto iter = kwargs.find(e->source->name);
+            if (iter != kwargs.end()) {
+              target = &(iter->second.heads_[0]);
+              // count how many arguments have been matched.
+              if (visited.count(e->source.get()) == 0) {
+                visited.insert(e->source.get());
+                ++nmatched;
+              }
+              replace_plan.push_back(std::make_pair(e, target));
+            }
+          }
+        }
+      });
+    if (nmatched == kwargs.size()) {
+      for (const auto& kv : replace_plan) {
+        *(kv.first) = *(kv.second);
+      }
+    }
+  }
+  if (nmatched != kwargs.size()) {
+    // Error message handling
+    std::vector<std::string> req_args = this->ListArguments();
+    std::unordered_set<std::string> keys(req_args.begin(), req_args.end());
+    std::ostringstream msg;
+    msg << "\nCandidate arguments:\n";
+    for (size_t i = 0; i < req_args.size(); ++i) {
+      msg << "\t[" << i << ']' << req_args[i] << '\n';
+    }
+    for (const auto& kv : kwargs) {
+      CHECK_NE(keys.count(kv.first), 0)
+          << "Keyword argument " << kv.first << " not found in arguments."
+          << msg.str();
+    }
+  }
+}
+
+Symbol Symbol::operator () (const std::vector<Symbol>& args,
+                            const std::string& name) const {
   Symbol s = this->Copy();
-  if (!s.arg_users_) {  // if arg_users_ has not been populated
-    s.FindArgUsers();
-  }
-  CHECK_LT(kwargs.size(), s.arg_users_->size()) << "Too many args, requires "
-      << s.arg_users_->size() << " provided " << kwargs.size();
-  for (size_t i = 0; i < s.arg_users_->size(); ++i) {
-    const std::pair<Node*, int>& arg_user = (*s.arg_users_)[i];
-    const std::string& name = arg_user.first->name_;
-    if (!(name == "") && kwargs.count(name) != 0) {
-      const Symbol& bind = kwargs.at(name);
-      arg_user.first->in_symbol_[arg_user.second] = bind.head_;
-      CHECK_NE(bind.index_, -1) << "Argument " << name << " is a tuple, scalar is required";
-      arg_user.first->in_index_[arg_user.second] = bind.index_;
-    }
-  }
-  s.arg_users_.reset();
-  // TODO(linmin): report error if kwargs contains non-existing keys
+  s.Compose(args, name);
   return s;
 }

-Symbol Symbol::operator[] (int index) const {
-  CHECK_EQ(index_, -1) << "Current symbol can't be indexed because it returns a scalar.";
-  Symbol s = *this;
-  s.index_ = index;
+Symbol Symbol::operator () (const std::unordered_map<std::string, Symbol>& kwargs,
+                            const std::string& name) const {
+  Symbol s = this->Copy();
+  s.Compose(kwargs, name);
   return s;
 }

-std::vector<std::string> Symbol::ListArgs() {
-  std::vector<std::string> ret;
-  if (!arg_users_) {
-    FindArgUsers();
-  }
-  std::transform(arg_users_->begin(), arg_users_->end(), std::back_inserter(ret),
-                 [&](const std::pair<Node*, int>& n) -> std::string {
-                   return n.first->in_symbol_[n.second]->name_;
-                 });
-  return ret;
+bool Symbol::InferShape(std::vector<TShape> *in_shape,
+                        std::vector<TShape> *out_shape) const {
+  StaticGraph g;
+  this->ToStaticGraph(&g);
+  return g.InferShape(in_shape, out_shape);
 }

-Symbol Symbol::Create(AtomicSymbol *atomic_symbol) {
+Symbol Symbol::Create(OperatorProperty *op) {
+  // use special representation for atomic symbol
+  auto node = std::make_shared<Node>(op, "");
+  size_t nret = op->NumVisibleReturns();
   Symbol s;
-  std::vector<std::string> args = atomic_symbol->DescribeArguments();
-  std::vector<std::string> rets = atomic_symbol->DescribeReturns();
-  // set head_
-  s.head_ = std::make_shared<Node>(atomic_symbol, "");
-  // set index_
-  s.index_ = rets.size() > 1 ? -1 : 0;
-  // set head_->in_index_
-  s.head_->in_index_ = std::vector<int>(args.size(), 0);
-  // set head_->in_symbol_
-  for (auto name : args) {
-    s.head_->in_symbol_.push_back(std::make_shared<Node>(nullptr, name));
-  }
-  // set head_->out_shape_
-  s.head_->out_shape_ = std::vector<TShape>(rets.size());
+  for (uint32_t i = 0; i < nret; ++i) {
+    s.heads_.push_back(DataEntry(node, i));
+  }
   return s;
 }

-Symbol Symbol::Create(const std::string& type_name,
-                      const std::vector<std::pair<std::string, std::string> >& param) {
-  const AtomicSymbolEntry *entry = Registry<AtomicSymbolEntry>::Find(type_name);
-  CHECK_NE(entry, NULL) << type_name << " is not a valid Symbol type";
-  AtomicSymbol* atomic_symbol = (*entry)();
-  for (auto p : param) {
-    atomic_symbol->SetParam(p.first.c_str(), p.second.c_str());
+Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
+  Symbol ret;
+  for (const auto &s : symbols) {
+    ret.heads_.insert(ret.heads_.end(), s.heads_.begin(), s.heads_.end());
   }
-  return Create(atomic_symbol);
+  return std::move(ret);
 }

-StaticGraph Symbol::ToStaticGraph() {
-  StaticGraph graph;
-  Dfs(this->head_, &graph);
-  return graph;
+Symbol Symbol::CreateVariable(const std::string &name) {
+  Symbol s;
+  s.heads_.push_back(DataEntry(std::make_shared<Node>(nullptr, name), 0));
+  return std::move(s);
 }

-CompositeOperator* Symbol::Bind(Context ctx, const std::unordered_map<std::string, NArray>& in) {
-  StaticGraph graph = this->ToStaticGraph();
-  return NULL;
-  // TODO(bing): pass the graph and in to initialize a composite op.
-}
+void Symbol::ToStaticGraph(StaticGraph *out_graph) const {
+  // TODO(bing): Check unique name
+  std::vector<Node*> node_order;
+  std::unordered_map<Node*, uint32_t> node_index;
+  auto &arg_nodes = out_graph->arg_nodes;
+  arg_nodes.clear();

-void Symbol::Dfs(const std::shared_ptr<Node> node, StaticGraph *graph) {
-  int id = graph->FindNodeByName(node->name_, node->sym_);
-  for (size_t i = 0; i < node->in_symbol_.size(); ++i) {
-    std::shared_ptr<Node> parent = node->in_symbol_[i];
-    int parent_id = graph->FindNodeByName(parent->name_, parent->sym_);
-    graph->connected_graph[parent_id].push_back(id);
-    graph->output_index[parent_id].push_back(node->in_index_[i]);
-    if (parent->sym_) {
-      Dfs(parent, graph);
+  this->DFSVisit([&node_order, &node_index, &arg_nodes](Node *n) {
+      uint32_t nid = static_cast<uint32_t>(node_index.size());
+      node_index[n] = nid;
+      if (n->is_variable()) {
+        arg_nodes.push_back(nid);
+      }
+      node_order.push_back(n);
+    });
+  // setup nodes
+  out_graph->nodes.resize(node_index.size());
+  for (uint32_t nid = 0; nid < node_order.size(); ++nid) {
+    if (node_order[nid]->op != nullptr) {
+      out_graph->nodes[nid].op.reset(node_order[nid]->op->Copy());
+    } else {
+      out_graph->nodes[nid].op.reset(nullptr);
+    }
+    out_graph->nodes[nid].name = node_order[nid]->name;
+    auto &inputs = out_graph->nodes[nid].inputs;
+    inputs.clear();
+    for (const DataEntry &src : node_order[nid]->inputs) {
+      StaticGraph::DataEntry e;
+      e.index = src.index;
+      e.source_id = node_index[src.source.get()];
+      inputs.push_back(e);
     }
   }
+  // setup heads
+  out_graph->outputs.clear();
+  for (auto &head : heads_) {
+    StaticGraph::DataEntry e;
+    e.source_id = node_index[head.source.get()];
+    e.index = head.index;
+    out_graph->outputs.push_back(e);
+  }
 }
-
 }  // namespace mxnet
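The two Compose overloads above share one discipline worth noting: the graph is never rewritten while DFSVisit is still walking it. Matches are first recorded in replace_plan (with replace_map ensuring a placeholder consumed in several places is bound to the same argument), and the plan is applied in a second pass. A minimal sketch of that two-phase replacement, with plain structs standing in for DataEntry (every name here is illustrative, not part of the patch):

    #include <cstddef>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // Stand-in for DataEntry: just a slot holding a value.
    struct Entry { int value; };

    int main() {
      std::vector<Entry> inputs = {{-1}, {-1}, {-1}};  // unresolved placeholders
      std::vector<Entry> args   = {{10}, {20}, {30}};  // supplied arguments
      // Phase 1: record what should be overwritten, without mutating anything.
      std::vector<std::pair<Entry*, const Entry*>> plan;
      for (size_t i = 0; i < inputs.size(); ++i) {
        plan.emplace_back(&inputs[i], &args[i]);
      }
      // Phase 2: apply the plan only after the traversal has finished.
      for (const auto& kv : plan) *kv.first = *kv.second;
      for (const auto& e : inputs) std::printf("%d\n", e.value);  // 10 20 30
      return 0;
    }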
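ToStaticGraph follows a similar two-pass shape: the DFS assigns each Node* a dense id in visit order, and only afterwards are the shared_ptr edges and heads re-expressed as integer source_id/index pairs. A self-contained sketch of that pointer-to-id remapping, using a hand-rolled visit order in place of the real DFSVisit (illustrative names only):

    #include <cstdio>
    #include <unordered_map>
    #include <vector>

    struct Node { std::vector<Node*> inputs; };  // stand-in for the symbol Node

    int main() {
      Node a, b, c;                 // a tiny chain: c reads b, b reads a
      b.inputs.push_back(&a);
      c.inputs.push_back(&b);
      std::vector<Node*> order = {&a, &b, &c};  // pretend DFS visit order
      // Pass 1: dense ids in visit order.
      std::unordered_map<Node*, unsigned> index;
      for (Node* n : order) index[n] = static_cast<unsigned>(index.size());
      // Pass 2: rewrite pointer edges as integer ids, as the static graph stores them.
      for (Node* n : order)
        for (Node* src : n->inputs)
          std::printf("node %u <- node %u\n", index[n], index[src]);
      return 0;
    }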
From f932cd22a6a2ddf220d2ea6df9d4e78e376dc18e Mon Sep 17 00:00:00 2001
From: Yutian Li
Date: Sun, 16 Aug 2015 22:27:48 +0800
Subject: [PATCH 15/21] [storage] fix Makefile

---
 Makefile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index fdda325b8240..69a5b7c5c052 100644
--- a/Makefile
+++ b/Makefile
@@ -58,7 +58,7 @@ endif
 BIN = test/api_registry_test test/test_storage
 OBJ = narray_op_cpu.o
 # add threaded engine after it is done
-OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o static_graph.o storage.o
+OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o static_graph.o
 CUOBJ =
 SLIB = lib/libmxnet.so
 ALIB = lib/libmxnet.a

From 364349f6b70ca45c58e42290e0e95c979eba350a Mon Sep 17 00:00:00 2001
From: Yutian Li
Date: Sun, 16 Aug 2015 22:29:23 +0800
Subject: [PATCH 16/21] [storage] make Makefile right this time

---
 Makefile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index 69a5b7c5c052..eeb62a31fdaa 100644
--- a/Makefile
+++ b/Makefile
@@ -58,7 +58,7 @@ endif
 BIN = test/api_registry_test test/test_storage
 OBJ = narray_op_cpu.o
 # add threaded engine after it is done
-OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o static_graph.o
+OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o storage.o fully_connect_op_cpu.o static_graph.o
 CUOBJ =
 SLIB = lib/libmxnet.so
 ALIB = lib/libmxnet.a

From 62f98a8b2e8380ff1d2b0035a25b217c4eb7ae19 Mon Sep 17 00:00:00 2001
From: Yutian Li
Date: Sun, 16 Aug 2015 22:37:46 +0800
Subject: [PATCH 17/21] [storage] fix Makefile for dependency

---
 Makefile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index eeb62a31fdaa..74944012df8e 100644
--- a/Makefile
+++ b/Makefile
@@ -58,7 +58,7 @@ endif
 BIN = test/api_registry_test test/test_storage
 OBJ = narray_op_cpu.o
 # add threaded engine after it is done
-OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o storage.o fully_connect_op_cpu.o static_graph.o
+OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o storage.o fully_connected_cpu.o static_graph.o
 CUOBJ =
 SLIB = lib/libmxnet.so
 ALIB = lib/libmxnet.a
From 7cae8ceb5c4f53097ea539db62596b14e64e76bc Mon Sep 17 00:00:00 2001
From: Yutian Li
Date: Mon, 17 Aug 2015 00:04:34 +0800
Subject: [PATCH 18/21] [storage] add document to some common functions

---
 doc/Doxyfile            |  2 +-
 src/common/cuda_utils.h | 10 ++++++++++
 src/common/utils.h      | 18 ++++++++++++++++++
 3 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/doc/Doxyfile b/doc/Doxyfile
index 7688fa1d7bb4..3e232e58fd7e 100644
--- a/doc/Doxyfile
+++ b/doc/Doxyfile
@@ -753,7 +753,7 @@ WARN_LOGFILE =
 # spaces.
 # Note: If this tag is empty the current directory is searched.

-INPUT = include
+INPUT = include src/common

 # This tag can be used to specify the character encoding of the source files
 # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses

diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index a2730481c828..8e5e79473b6c 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -17,6 +17,11 @@ namespace common {

 namespace cuda {

+/*!
+ * \brief Get string representation of cuBLAS errors.
+ * \param error The error.
+ * \return String representation.
+ */
 inline const char* CublasGetErrorString(cublasStatus_t error) {
   switch (error) {
     case CUBLAS_STATUS_SUCCESS:
@@ -43,6 +48,11 @@ inline const char* CublasGetErrorString(cublasStatus_t error) {
   return "Unknown cuBLAS status";
 }

+/*!
+ * \brief Get string representation of cuRAND errors.
+ * \param error The error.
+ * \return String representation.
+ */
 inline const char* CurandGetErrorString(curandStatus_t status) {
   switch (status) {
     case CURAND_STATUS_SUCCESS:

diff --git a/src/common/utils.h b/src/common/utils.h
index e10f240a31cc..00bdc1c1754e 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -35,17 +35,35 @@ struct UniqueIf<T[N]> {

 }  // namespace helper

+/*!
+ * \brief Constructs an object of type `T` and wraps it in a `std:::unique_ptr`.
+ *
+ * Constructs a non-array type `T`. The arguments `args` are passed to the
+ * constructor of `T`. The function does not participate in the overload
+ * resolution if `T` is an array type.
+ */
 template <class T, class... Args>
 typename helper::UniqueIf<T>::SingleObject MakeUnique(Args&&... args) {
   return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
 }

+/*!
+ * \brief Constructs an object of type `T` and wraps it in a `std:::unique_ptr`.
+ *
+ * Constructs an array of unknown bound `T`. The function does not participate
+ * in the overload resolution unless `T` is an array of unknown bound.
+ */
 template <class T>
 typename helper::UniqueIf<T>::UnknownBound MakeUnique(size_t n) {
   using U = typename std::remove_extent<T>::type;
   return std::unique_ptr<T>(new U[n]{});
 }

+/*!
+ * \brief Constructs an object of type `T` and wraps it in a `std:::unique_ptr`.
+ *
+ * Construction of arrays of known bound is disallowed.
+ */
 template <class T, class... Args>
 typename helper::UniqueIf<T>::KnownBound MakeUnique(Args&&...) = delete;

From 1fe4efb94a7ee5d93c82b73454c00df781a7c236 Mon Sep 17 00:00:00 2001
From: Yutian Li
Date: Mon, 17 Aug 2015 00:17:43 +0800
Subject: [PATCH 19/21] [storage] disable Doxygen preprocessing so every doc works

---
 doc/Doxyfile                           | 2 +-
 src/common/concurrent_blocking_queue.h | 3 +++
 src/common/cuda_utils.h                | 4 ++++
 3 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/doc/Doxyfile b/doc/Doxyfile
index 3e232e58fd7e..c49c0267bd1f 100644
--- a/doc/Doxyfile
+++ b/doc/Doxyfile
@@ -1925,7 +1925,7 @@ PERLMOD_MAKEVAR_PREFIX =
 # C-preprocessor directives found in the sources and include files.
 # The default value is: YES.

-ENABLE_PREPROCESSING = YES
+ENABLE_PREPROCESSING = NO

 # If the MACRO_EXPANSION tag is set to YES doxygen will expand all macro names
 # in the source code. If set to NO only conditional compilation will be

diff --git a/src/common/concurrent_blocking_queue.h b/src/common/concurrent_blocking_queue.h
index 14bab00d8280..82e2598816a5 100644
--- a/src/common/concurrent_blocking_queue.h
+++ b/src/common/concurrent_blocking_queue.h
@@ -13,6 +13,9 @@
 #include
 #include

+/*!
+ * \brief Common components.
+ */
 namespace common {

 /*!

diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index 8e5e79473b6c..3a1bf7e4b2a4 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -15,6 +15,10 @@
 #include

 namespace common {
+
+/*!
+ * \brief CUDA utilities.
+ */
 namespace cuda {

 /*!
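Patch 18 documents the UniqueIf/MakeUnique trio, but the SFINAE dispatch is easier to see executed. Below is a condensed, self-contained copy of the pattern (namespace demo is ours, for illustration only) that compiles standalone:

    #include <cstddef>
    #include <memory>
    #include <type_traits>
    #include <utility>

    namespace demo {
    // Only the matching specialization defines the member type used in the
    // return type, so exactly one overload survives for any given T.
    template <class T> struct UniqueIf { using SingleObject = std::unique_ptr<T>; };
    template <class T> struct UniqueIf<T[]> { using UnknownBound = std::unique_ptr<T[]>; };
    template <class T, size_t N> struct UniqueIf<T[N]> { using KnownBound = void; };

    template <class T, class... Args>
    typename UniqueIf<T>::SingleObject MakeUnique(Args&&... args) {
      return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
    }
    template <class T>
    typename UniqueIf<T>::UnknownBound MakeUnique(size_t n) {
      using U = typename std::remove_extent<T>::type;
      return std::unique_ptr<T>(new U[n]{});  // value-initialized elements
    }
    template <class T, class... Args>
    typename UniqueIf<T>::KnownBound MakeUnique(Args&&...) = delete;
    }  // namespace demo

    int main() {
      auto one  = demo::MakeUnique<int>(42);   // single object
      auto many = demo::MakeUnique<int[]>(8);  // array of unknown bound
      // demo::MakeUnique<int[8]>();           // known bound: deleted, won't compile
      return *one + many[0];
    }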
From f5358fe0d55671c097fb9fc3766ac8ef2e227956 Mon Sep 17 00:00:00 2001
From: Yutian Li
Date: Mon, 17 Aug 2015 00:42:29 +0800
Subject: [PATCH 20/21] [storage] document some more to pass lint

---
 doc/Doxyfile             |  4 ++--
 include/mxnet/operator.h |  6 ++++++
 src/common/cuda_utils.h  | 24 +++++++++++++++++++++++-
 src/common/utils.h       | 36 ++++++++++++++++++++++++++++++++----
 4 files changed, 63 insertions(+), 7 deletions(-)

diff --git a/doc/Doxyfile b/doc/Doxyfile
index c49c0267bd1f..41c86905b59f 100644
--- a/doc/Doxyfile
+++ b/doc/Doxyfile
@@ -1925,7 +1925,7 @@ PERLMOD_MAKEVAR_PREFIX =
 # C-preprocessor directives found in the sources and include files.
 # The default value is: YES.

-ENABLE_PREPROCESSING = NO
+ENABLE_PREPROCESSING = YES

 # If the MACRO_EXPANSION tag is set to YES doxygen will expand all macro names
 # in the source code. If set to NO only conditional compilation will be
@@ -1974,7 +1974,7 @@ INCLUDE_FILE_PATTERNS =
 # recursively expanded use the := operator instead of the = operator.
 # This tag requires that the tag ENABLE_PREPROCESSING is set to YES.

-PREDEFINED =
+PREDEFINED = MXNET_USE_CUDA DMLC_USE_CXX11

 # If the MACRO_EXPANSION and EXPAND_ONLY_PREDEF tags are set to YES then this
 # tag can be used to specify a list of macro names that should be expanded. The

diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h
index 0299ef2bf167..938083dbab33 100644
--- a/include/mxnet/operator.h
+++ b/include/mxnet/operator.h
@@ -250,6 +250,8 @@ class OperatorProperty {
    *    return {{out_data[0], in_data[0]}};
    *  }
    * \endcode
+   * \param in_data The input data in forward pass.
+   * \param out_data The output data in forward pass.
    * \return list of pair of integers taken from the inputs vector,
    *         indicating possible in place operations.
    */
@@ -273,6 +275,10 @@ class OperatorProperty {
    *    return {in_grad[0], in_data[0]}};
    *  }
    * \endcode
+   * \param in_data The input data in forward pass.
+   * \param out_data The output data in forward pass.
+   * \param in_grad Gradient of inputs in backward pass.
+   * \param out_grad Gradient of outputs in backward pass.
    * \return list of pair of integers taken from the inputs vector,
    *         indicating possible in place operations.
    */

diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index 3a1bf7e4b2a4..6002da20c1fe 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -54,7 +54,7 @@ inline const char* CublasGetErrorString(cublasStatus_t error) {

 /*!
  * \brief Get string representation of cuRAND errors.
- * \param error The error.
+ * \param status The status.
  * \return String representation.
  */
 inline const char* CurandGetErrorString(curandStatus_t status) {
@@ -92,18 +92,34 @@ inline const char* CurandGetErrorString(curandStatus_t status) {
 }  // namespace cuda
 }  // namespace common

+/*!
+ * \brief Check CUDA error.
+ * \param msg Message to print if an error occurred.
+ */
 #define CHECK_CUDA_ERROR(msg)                                                \
   {                                                                          \
     cudaError_t e = cudaGetLastError();                                      \
     CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
   }

+/*!
+ * \brief Protected CUDA call.
+ * \param func Expression to call.
+ *
+ * It checks for CUDA errors after invocation of the expression.
+ */
 #define CUDA_CALL(func)                                            \
   {                                                                \
     cudaError_t e = (func);                                        \
     CHECK_EQ(e, cudaSuccess) << "CUDA: " << cudaGetErrorString(e); \
   }

+/*!
+ * \brief Protected cuBLAS call.
+ * \param func Expression to call.
+ *
+ * It checks for cuBLAS errors after invocation of the expression.
+ */
 #define CUBLAS_CALL(func)                                          \
   {                                                                \
     cublasStatus_t e = (func);                                     \
     CHECK_EQ(e, CUBLAS_STATUS_SUCCESS)                             \
@@ -111,6 +127,12 @@ inline const char* CurandGetErrorString(curandStatus_t status) {
         << "cuBLAS: " << common::cuda::CublasGetErrorString(e);    \
   }

+/*!
+ * \brief Protected cuRAND call.
+ * \param func Expression to call.
+ *
+ * It checks for cuRAND errors after invocation of the expression.
+ */
 #define CURAND_CALL(func)                                          \
   {                                                                \
     curandStatus_t e = (func);                                     \

diff --git a/src/common/utils.h b/src/common/utils.h
index 00bdc1c1754e..7067c85d28b3 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -16,27 +16,51 @@ namespace common {

 #if DMLC_USE_CXX11

+/*!
+ * \brief Helper functions.
+ */
 namespace helper {

+/*!
+ * \brief Helper for non-array type `T`.
+ */
 template <class T>
 struct UniqueIf {
+  /*!
+   * \brief Type of `T`.
+   */
   using SingleObject = std::unique_ptr<T>;
 };

+/*!
+ * \brief Helper for an array of unknown bound `T`.
+ */
 template <class T>
 struct UniqueIf<T[]> {
+  /*!
+   * \brief Type of `T`.
+   */
   using UnknownBound = std::unique_ptr<T[]>;
 };

+/*!
+ * \brief Helper for an array of known bound `T`.
+ */
 template <class T, size_t N>
 struct UniqueIf<T[N]> {
+  /*!
+   * \brief Type of `T`.
+   */
   using KnownBound = void;
 };

 }  // namespace helper

 /*!
- * \brief Constructs an object of type `T` and wraps it in a `std:::unique_ptr`.
+ * \brief Constructs an object of type `T` and wraps it in a `std::unique_ptr`.
+ * \param args List of arguments with which an instance of `T` will be
+ *             constructed.
+ * \return `std::unique_ptr` of an instance of type `T`.
  *
  * Constructs a non-array type `T`. The arguments `args` are passed to the
  * constructor of `T`. The function does not participate in the overload
 * resolution if `T` is an array type.
@@ -48,7 +72,9 @@ typename helper::UniqueIf<T>::SingleObject MakeUnique(Args&&... args) {
 }

 /*!
- * \brief Constructs an object of type `T` and wraps it in a `std:::unique_ptr`.
+ * \brief Constructs an object of type `T` and wraps it in a `std::unique_ptr`.
+ * \param n The size of the array to construct.
+ * \return `std::unique_ptr` of an instance of type `T`.
  *
  * Constructs an array of unknown bound `T`. The function does not participate
  * in the overload resolution unless `T` is an array of unknown bound.
@@ -60,12 +86,14 @@ typename helper::UniqueIf<T>::UnknownBound MakeUnique(size_t n) {
 }

 /*!
- * \brief Constructs an object of type `T` and wraps it in a `std:::unique_ptr`.
+ * \brief Constructs an object of type `T` and wraps it in a `std::unique_ptr`.
+ * \param args List of arguments with which an instance of `T` will be
+ *             constructed.
  *
  * Construction of arrays of known bound is disallowed.
  */
 template <class T, class... Args>
-typename helper::UniqueIf<T>::KnownBound MakeUnique(Args&&...) = delete;
+typename helper::UniqueIf<T>::KnownBound MakeUnique(Args&&... args) = delete;

 #endif  // DMLC_USE_CXX11
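The newly documented CUDA_CALL family reads best next to a call site. A sketch of the intended usage, assuming a CUDA toolchain, MXNET_USE_CUDA defined, and the header on the include path (the include path and function name are illustrative):

    #include <cuda_runtime.h>
    #include "common/cuda_utils.h"  // path is illustrative

    // Each wrapped call aborts with a readable message instead of silently
    // returning an error code that the caller might forget to inspect.
    void AllocScratch() {
      void* dptr = nullptr;
      CUDA_CALL(cudaMalloc(&dptr, 1024));
      CUDA_CALL(cudaMemset(dptr, 0, 1024));
      CUDA_CALL(cudaFree(dptr));
    }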
From 3a8e0714d51e2f4c46dd92af6ffb72cd559ed21f Mon Sep 17 00:00:00 2001
From: Yutian Li
Date: Mon, 17 Aug 2015 00:51:54 +0800
Subject: [PATCH 21/21] [storage] Doxygen does not support scope operator?

---
 src/common/utils.h | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/src/common/utils.h b/src/common/utils.h
index 7067c85d28b3..f55ebc26535f 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -57,10 +57,11 @@ struct UniqueIf<T[N]> {

 }  // namespace helper

 /*!
- * \brief Constructs an object of type `T` and wraps it in a `std::unique_ptr`.
+ * \brief Constructs an object of type `T` and wraps it in a
+ *        `std``::``unique_ptr`.
  * \param args List of arguments with which an instance of `T` will be
  *             constructed.
- * \return `std::unique_ptr` of an instance of type `T`.
+ * \return `std``::``unique_ptr` of an instance of type `T`.
  *
  * Constructs a non-array type `T`. The arguments `args` are passed to the
  * constructor of `T`. The function does not participate in the overload
@@ -72,9 +73,10 @@ typename helper::UniqueIf<T>::SingleObject MakeUnique(Args&&... args) {
 }

 /*!
- * \brief Constructs an object of type `T` and wraps it in a `std::unique_ptr`.
+ * \brief Constructs an object of type `T` and wraps it in a
+ *        `std``::``unique_ptr`.
  * \param n The size of the array to construct.
- * \return `std::unique_ptr` of an instance of type `T`.
+ * \return `std``::``unique_ptr` of an instance of type `T`.
  *
  * Constructs an array of unknown bound `T`. The function does not participate
  * in the overload resolution unless `T` is an array of unknown bound.
@@ -86,7 +88,8 @@ typename helper::UniqueIf<T>::UnknownBound MakeUnique(size_t n) {
 }

 /*!
- * \brief Constructs an object of type `T` and wraps it in a `std::unique_ptr`.
+ * \brief Constructs an object of type `T` and wraps it in a
+ *        `std``::``unique_ptr`.
  * \param args List of arguments with which an instance of `T` will be
  *             constructed.
  *
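Patch 21's title answers itself in the hunks above: with this Doxygen configuration, std::unique_ptr inside a single backtick span renders unreliably, so the spans are split and the scope operator is left as plain text between them. A doc comment written in that style, for illustration only:

    /*!
     * \brief Returns a `std``::``unique_ptr`; `std` and `unique_ptr` stay
     *        inline code while the scope operator in between is ordinary
     *        text, sidestepping the rendering problem the commit asks about.
     */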