diff --git a/.gitignore b/.gitignore index 1c1cd99ab7a4..1aa2ce47da11 100644 --- a/.gitignore +++ b/.gitignore @@ -49,3 +49,7 @@ Debug *.swp *.swo *.swn + +# Emacs +.clang_complete +.dir-locals.el diff --git a/Makefile b/Makefile index 581674c784a2..74944012df8e 100644 --- a/Makefile +++ b/Makefile @@ -55,10 +55,10 @@ ifneq ($(ADD_LDFLAGS), NONE) endif #BIN = test/test_threaded_engine test/api_registry_test -BIN = test/api_registry_test -OBJ = storage.o narray_op_cpu.o +BIN = test/api_registry_test test/test_storage +OBJ = narray_op_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o fully_connected_cpu.o static_graph.o +OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o storage.o fully_connected_cpu.o static_graph.o CUOBJ = SLIB = lib/libmxnet.so ALIB = lib/libmxnet.a @@ -93,6 +93,7 @@ lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) lib/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ) test/api_registry_test: test/api_registry_test.cc lib/libmxnet.a +test/test_storage: test/test_storage.cc lib/libmxnet.a #test/test_threaded_engine: test/test_threaded_engine.cc api/libmxnet.a $(BIN) : diff --git a/doc/Doxyfile b/doc/Doxyfile index 7688fa1d7bb4..41c86905b59f 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -753,7 +753,7 @@ WARN_LOGFILE = # spaces. # Note: If this tag is empty the current directory is searched. -INPUT = include +INPUT = include src/common # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses @@ -1974,7 +1974,7 @@ INCLUDE_FILE_PATTERNS = # recursively expanded use the := operator instead of the = operator. # This tag requires that the tag ENABLE_PREPROCESSING is set to YES. -PREDEFINED = +PREDEFINED = MXNET_USE_CUDA DMLC_USE_CXX11 # If the MACRO_EXPANSION and EXPAND_ONLY_PREDEF tags are set to YES then this # tag can be used to specify a list of macro names that should be expanded. 
The diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index c2b6ac3bc882..92257b3f0269 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -128,14 +128,14 @@ class NArray { /*! \brief the real data chunk that backs NArray */ struct Chunk { /*! \brief storage handlefrom storage engine */ - StorageManager::Handle shandle; + Storage::Handle shandle; /*! \brief variable from DAG engine */ DAGEngine::Variable var; /*! \brief holds the data content */ TBlob data; /*! * \brief if this is true, this means the data do not come - * from StorageManager, and do not need to be freed + * from Storage, and do not need to be freed */ bool static_data; /*! \brief whether allocation is delayed */ @@ -163,7 +163,7 @@ class NArray { /*! \brief check if delay alloc is on, do alloc if not yet done */ inline void CheckAndAlloc(void) { if (delay_alloc) { - shandle = StorageManager::Get()->Alloc(data.shape_.Size() * sizeof(real_t), shandle.ctx); + shandle = Storage::Get()->Alloc(data.shape_.Size() * sizeof(real_t), shandle.ctx); data = TBlob(static_cast(shandle.dptr), data.shape_, shandle.ctx.dev_mask); delay_alloc = false; } @@ -174,9 +174,9 @@ class NArray { DAGEngine::Get()->PushDelete([](RunContext s) {}, shandle.ctx, var); } else { CHECK(!delay_alloc) << "deleted before allocation"; - StorageManager::Handle h = this->shandle; + Storage::Handle h = this->shandle; DAGEngine::Get()->PushDelete([h](RunContext s) { - StorageManager::Get()->Free(h); + Storage::Get()->Free(h); }, shandle.ctx, var); } } diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 0299ef2bf167..938083dbab33 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -250,6 +250,8 @@ class OperatorProperty { * return {{out_data[0], in_data[0]}}; * } * \endcode + * \param in_data The input data in forward pass. + * \param out_data The output data in forward pass. 
* \return list of pair of integers taken from the inputs vector, * indicating possible in place operations. */ @@ -273,6 +275,10 @@ class OperatorProperty { * return {in_grad[0], in_data[0]}}; * } * \endcode + * \param in_data The input data in forward pass. + * \param out_data The output data in forward pass. + * \param in_grad Gradient of inputs in backward pass. + * \param out_grad Gradient of outputs in backward pass. * \return list of pair of integers taken from the inputs vector, * indicating possible in place operations. */ diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index 2953cbe0d171..575dc4cde1a2 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -1,45 +1,70 @@ /*! - * Copyright (c) 2015 by Contributors + * Copyright (c) 2015 by Contributors * \file storage.h - * \brief the memory allocator that manages the memory across multiple devices + * \brief Storage manager across multiple devices. */ #ifndef MXNET_STORAGE_H_ #define MXNET_STORAGE_H_ -#include "./base.h" -#include "./context.h" + +#include +#include "base.h" +#include "context.h" namespace mxnet { -/*! \brief memory allocator of storage */ -class StorageManager { + +/*! + * \brief Storage manager across multiple devices. + */ +class Storage { public: /*! - * \brief storage handle the represents storage information + * \brief Storage handle. */ struct Handle { - /*! \brief pointer to the data */ - void *dptr; - /*! \brief context information about device and deviceID */ - Context ctx; /*! - * \brief internal handle reserved for manager, - * user should not change or use this + * \brief Pointer to the data. */ - void *handle_; + void* dptr; + /*! + * \brief Size of the storage. + */ + size_t size; + /*! + * \brief Context information about device and ID. + */ + Context ctx; }; /*! 
- * \brief allocate a new contiguous memory for a given size - * \param size the total size of memory in bytes - * \param ctx context information about the device and deviceID - * \return Handle struct + * \brief Allocate a new contiguous memory for a given size. + * \param size Total size of memory in bytes. + * \param ctx Context information about the device and ID. + * \return Handle struct. */ - virtual Handle Alloc(size_t size, Context ctx) = 0; + Handle Alloc(size_t size, Context ctx); /*! - * \brief free the space represened the handle - * \param handle the handle to memory to be freed + * \brief Free storage. + * \param handle Handle struct. */ - virtual void Free(Handle handle) = 0; - /*! \return storage manager singleton */ - static StorageManager *Get(); -}; // class StorageManager + void Free(Handle handle); + /*! + * \brief Destructor. + */ + ~Storage(); + /*! + * \return Storage singleton. + */ + static Storage* Get(); + + private: + /*! + * \brief Hidden constructors. + */ + Storage(); + struct Impl; + std::unique_ptr impl_; + DISALLOW_COPY_AND_ASSIGN(Storage); +}; // class Storage + } // namespace mxnet + #endif // MXNET_STORAGE_H_ diff --git a/src/common/concurrent_blocking_queue.h b/src/common/concurrent_blocking_queue.h index 14bab00d8280..82e2598816a5 100644 --- a/src/common/concurrent_blocking_queue.h +++ b/src/common/concurrent_blocking_queue.h @@ -13,6 +13,9 @@ #include #include +/*! + * \brief Common components. + */ namespace common { /*! diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h new file mode 100644 index 000000000000..6002da20c1fe --- /dev/null +++ b/src/common/cuda_utils.h @@ -0,0 +1,157 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file cuda_utils.h + * \brief CUDA debugging utilities. + */ +#ifndef MXNET_COMMON_CUDA_UTILS_H_ +#define MXNET_COMMON_CUDA_UTILS_H_ + +#include + +#if MXNET_USE_CUDA + +#include +#include +#include + +namespace common { + +/*! + * \brief CUDA utilities. 
+ */ +namespace cuda { + +/*! + * \brief Get string representation of cuBLAS errors. + * \param error The error. + * \return String representation. + */ +inline const char* CublasGetErrorString(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + default: + break; + } + return "Unknown cuBLAS status"; +} + +/*! + * \brief Get string representation of cuRAND errors. + * \param status The status. + * \return String representation. 
+ */ +inline const char* CurandGetErrorString(curandStatus_t status) { + switch (status) { + case CURAND_STATUS_SUCCESS: + return "CURAND_STATUS_SUCCESS"; + case CURAND_STATUS_VERSION_MISMATCH: + return "CURAND_STATUS_VERSION_MISMATCH"; + case CURAND_STATUS_NOT_INITIALIZED: + return "CURAND_STATUS_NOT_INITIALIZED"; + case CURAND_STATUS_ALLOCATION_FAILED: + return "CURAND_STATUS_ALLOCATION_FAILED"; + case CURAND_STATUS_TYPE_ERROR: + return "CURAND_STATUS_TYPE_ERROR"; + case CURAND_STATUS_OUT_OF_RANGE: + return "CURAND_STATUS_OUT_OF_RANGE"; + case CURAND_STATUS_LENGTH_NOT_MULTIPLE: + return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; + case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: + return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; + case CURAND_STATUS_LAUNCH_FAILURE: + return "CURAND_STATUS_LAUNCH_FAILURE"; + case CURAND_STATUS_PREEXISTING_FAILURE: + return "CURAND_STATUS_PREEXISTING_FAILURE"; + case CURAND_STATUS_INITIALIZATION_FAILED: + return "CURAND_STATUS_INITIALIZATION_FAILED"; + case CURAND_STATUS_ARCH_MISMATCH: + return "CURAND_STATUS_ARCH_MISMATCH"; + case CURAND_STATUS_INTERNAL_ERROR: + return "CURAND_STATUS_INTERNAL_ERROR"; + } + return "Unknown cuRAND status"; +} + +} // namespace cuda +} // namespace common + +/*! + * \brief Check CUDA error. + * \param msg Message to print if an error occurred. + */ +#define CHECK_CUDA_ERROR(msg) \ + { \ + cudaError_t e = cudaGetLastError(); \ + CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \ + } + +/*! + * \brief Protected CUDA call. + * \param func Expression to call. + * + * It checks for CUDA errors after invocation of the expression. + */ +#define CUDA_CALL(func) \ + { \ + cudaError_t e = (func); \ + CHECK_EQ(e, cudaSuccess) << "CUDA: " << cudaGetErrorString(e); \ + } + +/*! + * \brief Protected cuBLAS call. + * \param func Expression to call. + * + * It checks for cuBLAS errors after invocation of the expression. 
+ */ +#define CUBLAS_CALL(func) \ + { \ + cublasStatus_t e = (func); \ + CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \ + << "cuBLAS: " << common::cuda::CublasGetErrorString(e); \ + } + +/*! + * \brief Protected cuRAND call. + * \param func Expression to call. + * + * It checks for cuRAND errors after invocation of the expression. + */ +#define CURAND_CALL(func) \ + { \ + curandStatus_t e = (func); \ + CHECK_EQ(e, CURAND_STATUS_SUCCESS) \ + << "cuRAND: " << common::cuda::CurandGetErrorString(e); \ + } + +#endif // MXNET_USE_CUDA + +#if MXNET_USE_CUDNN + +#include + +#define CUDNN_CALL(func) \ + { \ + cudnnStatus_t e = (func); \ + CHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e); \ + } + +#endif // MXNET_USE_CUDNN + +#endif // MXNET_COMMON_CUDA_UTILS_H_ diff --git a/src/common/utils.h b/src/common/utils.h new file mode 100644 index 000000000000..f55ebc26535f --- /dev/null +++ b/src/common/utils.h @@ -0,0 +1,105 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file utils.h + * \brief Basic utility functions. + */ +#ifndef MXNET_COMMON_UTILS_H_ +#define MXNET_COMMON_UTILS_H_ + +#if DMLC_USE_CXX11 +#include +#include +#include +#endif // DMLC_USE_CXX11 + +namespace common { + +#if DMLC_USE_CXX11 + +/*! + * \brief Helper functions. + */ +namespace helper { + +/*! + * \brief Helper for non-array type `T`. + */ +template +struct UniqueIf { + /*! + * \brief Type of `T`. + */ + using SingleObject = std::unique_ptr; +}; + +/*! + * \brief Helper for an array of unknown bound `T`. + */ +template +struct UniqueIf { + /*! + * \brief Type of `T`. + */ + using UnknownBound = std::unique_ptr; +}; + +/*! + * \brief Helper for an array of known bound `T`. + */ +template +struct UniqueIf { + /*! + * \brief Type of `T`. + */ + using KnownBound = void; +}; + +} // namespace helper + +/*! + * \brief Constructs an object of type `T` and wraps it in a + * `std``::``unique_ptr`. + * \param args List of arguments with which an instance of `T` will be + * constructed. 
+ * \return `std``::``unique_ptr` of an instance of type `T`. + * + * Constructs a non-array type `T`. The arguments `args` are passed to the + * constructor of `T`. The function does not participate in the overload + * resolution if `T` is an array type. + */ +template +typename helper::UniqueIf::SingleObject MakeUnique(Args&&... args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +/*! + * \brief Constructs an object of type `T` and wraps it in a + * `std``::``unique_ptr`. + * \param n The size of the array to construct. + * \return `std``::``unique_ptr` of an instance of type `T`. + * + * Constructs an array of unknown bound `T`. The function does not participate + * in the overload resolution unless `T` is an array of unknown bound. + */ +template +typename helper::UniqueIf::UnknownBound MakeUnique(size_t n) { + using U = typename std::remove_extent::type; + return std::unique_ptr(new U[n]{}); +} + +/*! + * \brief Constructs an object of type `T` and wraps it in a + * `std``::``unique_ptr`. + * \param args List of arguments with which an instance of `T` will be + * constructed. + * + * Construction of arrays of known bound is disallowed. + */ +template +typename helper::UniqueIf::KnownBound MakeUnique(Args&&... args) = delete; + +#endif // DMLC_USE_CXX11 + +} // namespace common + +#endif // MXNET_COMMON_UTILS_H_ diff --git a/src/storage/cpu_device_storage.h b/src/storage/cpu_device_storage.h new file mode 100644 index 000000000000..0b69b6b4fd8d --- /dev/null +++ b/src/storage/cpu_device_storage.h @@ -0,0 +1,49 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file cpu_device_storage.h + * \brief CPU storage implementation. + */ +#ifndef MXNET_STORAGE_CPU_DEVICE_STORAGE_H_ +#define MXNET_STORAGE_CPU_DEVICE_STORAGE_H_ + +#include +#include +#include "mxnet/base.h" + +namespace mxnet { +namespace storage { + +/*! + * \brief CPU storage implementation. + */ +class CPUDeviceStorage { + public: + /*! + * \brief Aligned allocation on CPU. 
+ * \param size Size to allocate. + * \return Pointer to the storage. + */ + inline static void* Alloc(size_t size); + /*! + * \brief Deallocation. + * \param ptr Pointer to deallocate. + */ + inline static void Free(void* ptr); + + private: + /*! + * \brief Alignment of allocation. + */ + static constexpr size_t alignment_ = 16; +}; // class CPUDeviceStorage + +inline void* CPUDeviceStorage::Alloc(size_t size) { + return CHECK_NOTNULL(memalign(alignment_, size)); +} + +inline void CPUDeviceStorage::Free(void* ptr) { free(ptr); } + +} // namespace storage +} // namespace mxnet + +#endif // MXNET_STORAGE_CPU_DEVICE_STORAGE_H_ diff --git a/src/storage/gpu_device_storage.h b/src/storage/gpu_device_storage.h new file mode 100644 index 000000000000..22956d192893 --- /dev/null +++ b/src/storage/gpu_device_storage.h @@ -0,0 +1,57 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file gpu_device_storage.h + * \brief GPU storage implementation. + */ +#ifndef MXNET_STORAGE_GPU_DEVICE_STORAGE_H_ +#define MXNET_STORAGE_GPU_DEVICE_STORAGE_H_ + +#include "mxnet/base.h" +#include "../common/cuda_utils.h" +#if MXNET_USE_CUDA +#include +#endif // MXNET_USE_CUDA + +namespace mxnet { +namespace storage { + +/*! + * \brief GPU storage implementation. + */ +class GPUDeviceStorage { + public: + /*! + * \brief Allocation. + * \param size Size to allocate. + * \return Pointer to the storage. + */ + inline static void* Alloc(size_t size); + /*! + * \brief Deallocation. + * \param ptr Pointer to deallocate. 
+ */ + inline static void Free(void* ptr); +}; // class GPUDeviceStorage + +inline void* GPUDeviceStorage::Alloc(size_t size) { + void* ret = nullptr; // well-defined result even when no allocation is performed +#if MXNET_USE_CUDA + CUDA_CALL(cudaMalloc(&ret, size)); +#else // MXNET_USE_CUDA + LOG(FATAL) << "Please compile with CUDA enabled"; +#endif // MXNET_USE_CUDA + return ret; +} + +inline void GPUDeviceStorage::Free(void* ptr) { +#if MXNET_USE_CUDA + CUDA_CALL(cudaFree(ptr)); +#else // MXNET_USE_CUDA + LOG(FATAL) << "Please compile with CUDA enabled"; +#endif // MXNET_USE_CUDA +} + +} // namespace storage +} // namespace mxnet + +#endif // MXNET_STORAGE_GPU_DEVICE_STORAGE_H_ diff --git a/src/storage/naive_storage_manager.h b/src/storage/naive_storage_manager.h new file mode 100644 index 000000000000..a476f5ea2acc --- /dev/null +++ b/src/storage/naive_storage_manager.h @@ -0,0 +1,49 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file naive_storage_manager.h + * \brief Naive storage manager. + */ +#ifndef MXNET_STORAGE_NAIVE_STORAGE_MANAGER_H_ +#define MXNET_STORAGE_NAIVE_STORAGE_MANAGER_H_ + +#include "storage_manager.h" +#include "mxnet/base.h" + +namespace mxnet { +namespace storage { + +/*! + * \brief Naive storage manager. + */ +template +class NaiveStorageManager final : public StorageManager { + public: + /*! + * \brief Default constructor. + */ + NaiveStorageManager() = default; + /*! + * \brief Default destructor. 
+ */ + ~NaiveStorageManager() = default; + void* Alloc(size_t size) override; + void Free(void* ptr, size_t) override; + + private: + DISALLOW_COPY_AND_ASSIGN(NaiveStorageManager); +}; // class NaiveStorageManager + +template +void* NaiveStorageManager::Alloc(size_t size) { + return DeviceStorage::Alloc(size); +} + +template +void NaiveStorageManager::Free(void* ptr, size_t) { + DeviceStorage::Free(ptr); +} + +} // namespace storage +} // namespace mxnet + +#endif // MXNET_STORAGE_NAIVE_STORAGE_MANAGER_H_ diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h new file mode 100644 index 000000000000..c7e1e0cde3c2 --- /dev/null +++ b/src/storage/pooled_storage_manager.h @@ -0,0 +1,78 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file pooled_storage_manager.h + * \brief Storage manager with a memory pool. + */ +#ifndef MXNET_STORAGE_POOLED_STORAGE_MANAGER_H_ +#define MXNET_STORAGE_POOLED_STORAGE_MANAGER_H_ + +#include +#include +#include "storage_manager.h" +#include "mxnet/base.h" + +namespace mxnet { +namespace storage { + +/*! + * \brief Storage manager with a memory pool. + */ +template +class PooledStorageManager final : public StorageManager { + public: + /*! + * \brief Default constructor. + */ + PooledStorageManager() = default; + /*! + * \brief Default destructor. 
+ */ + ~PooledStorageManager() = default; + void* Alloc(size_t size) override; + void Free(void* ptr, size_t size) override; + + private: + void ReleaseAll(); + size_t used_memory_ = 0; + std::unordered_map> memory_pool_; + DISALLOW_COPY_AND_ASSIGN(PooledStorageManager); +}; // class PooledStorageManager + +template +void* PooledStorageManager::Alloc(size_t size) { + auto&& reuse_it = memory_pool_.find(size); + if (reuse_it == memory_pool_.end() || reuse_it->second.size() == 0) { + if (kThreshold <= used_memory_) { + ReleaseAll(); + } + used_memory_ += size; + return DeviceStorage::Alloc(size); + } else { + auto&& reuse_pool = reuse_it->second; + auto ret = reuse_pool.back(); + reuse_pool.pop_back(); + return ret; + } +} + +template +void PooledStorageManager::Free(void* ptr, + size_t size) { + auto&& reuse_pool = memory_pool_[size]; + reuse_pool.push_back(ptr); +} + +template +void PooledStorageManager::ReleaseAll() { + for (auto&& i : memory_pool_) { + for (auto&& j : i.second) DeviceStorage::Free(j); + used_memory_ -= i.first * i.second.size(); + } + // Forget the freed pointers so Alloc can never hand them out again. + memory_pool_.clear(); +} + +} // namespace storage +} // namespace mxnet + +#endif // MXNET_STORAGE_POOLED_STORAGE_MANAGER_H_ diff --git a/src/storage/storage.cc b/src/storage/storage.cc index 3b539551f02d..98c7981b6a88 100644 --- a/src/storage/storage.cc +++ b/src/storage/storage.cc @@ -1,41 +1,93 @@ -// Copyright (c) 2015 by Contributors +/*! 
+ * Copyright (c) 2015 by Contributors + */ +#include "mxnet/storage.h" #include -#include +#include +#include +#include "storage_manager.h" +#include "naive_storage_manager.h" +#include "pooled_storage_manager.h" +#include "cpu_device_storage.h" +#include "gpu_device_storage.h" +#include "../common/cuda_utils.h" +#include "../common/utils.h" + namespace mxnet { -class NaiveStorageManager : public StorageManager { - public: - virtual Handle Alloc(size_t size, Context ctx); - virtual void Free(Handle handle); -}; - -StorageManager::Handle -NaiveStorageManager::Alloc(size_t size, Context ctx) { + +struct Storage::Impl { + static constexpr size_t kPoolThreshold = 4096 * 1024 * 1024ul; + static constexpr size_t kMaxNumberOfDevices = 3; + static constexpr size_t kMaxNumberOfDeviceIDs = 16; + + template + using CurrentStorageManager = + storage::PooledStorageManager; + + static void ActivateDevice(Context ctx) { + switch (ctx.dev_mask) { + case cpu::kDevMask: + break; + case gpu::kDevMask: +#if MXNET_USE_CUDA + CUDA_CALL(cudaSetDevice(ctx.dev_id)); +#else // MXNET_USE_CUDA + LOG(FATAL) << "Please compile with CUDA enabled"; +#endif // MXNET_USE_CUDA + break; + default: + LOG(FATAL) << "Unimplemented device"; + } + } + + // std::unordered_map< + // int, std::unordered_map>> + // storage_managers; + std::array, + kMaxNumberOfDeviceIDs>, + kMaxNumberOfDevices> storage_managers; +}; // struct Storage::Impl + +Storage::Handle Storage::Alloc(size_t size, Context ctx) { Handle hd; hd.ctx = ctx; - hd.handle_ = NULL; - if (ctx.dev_mask == cpu::kDevMask) { - hd.dptr = calloc(size, sizeof(real_t)); - // cudaMallocHost(&hd.dptr, size); - } else { -#if MXNET_USE_CUDA - cudaMalloc(&hd.dptr, size); -#endif + auto&& device = impl_->storage_managers.at(ctx.dev_mask); + auto&& device_id_it = device.at(ctx.dev_id); + // Allocate device if necessary. 
+ if (!device_id_it) { + switch (ctx.dev_mask) { + case cpu::kDevMask: + device_id_it = common::MakeUnique< + Storage::Impl::CurrentStorageManager>(); + break; + case gpu::kDevMask: + device_id_it = common::MakeUnique< + Storage::Impl::CurrentStorageManager>(); + break; + default: + LOG(FATAL) << "Unimplemented device"; + } } + Impl::ActivateDevice(ctx); + hd.dptr = device_id_it->Alloc(size); + hd.size = size; return hd; } -void NaiveStorageManager::Free(StorageManager::Handle handle) { - if (handle.ctx.dev_mask == cpu::kDevMask) { - free(handle.dptr); - handle.dptr = NULL; - // cudaFreeHost(handle.dptr); - } else { -#if MXNET_USE_CUDA - cudaFree(handle.dptr); -#endif - } + +void Storage::Free(Storage::Handle handle) { + Impl::ActivateDevice(handle.ctx); + impl_->storage_managers.at(handle.ctx.dev_mask) + .at(handle.ctx.dev_id) + ->Free(handle.dptr, handle.size); } -StorageManager *StorageManager::Get() { - static NaiveStorageManager inst; + +Storage::~Storage() = default; + +Storage* Storage::Get() { + static Storage inst; return &inst; } + +Storage::Storage() : impl_{new Impl{}} {} + } // namespace mxnet diff --git a/src/storage/storage_manager.h b/src/storage/storage_manager.h new file mode 100644 index 000000000000..3d264ab278ca --- /dev/null +++ b/src/storage/storage_manager.h @@ -0,0 +1,40 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file storage_manager.h + * \brief Storage manager. + */ +#ifndef MXNET_STORAGE_STORAGE_MANAGER_H_ +#define MXNET_STORAGE_STORAGE_MANAGER_H_ + +#include + +namespace mxnet { +namespace storage { + +/*! + * \brief Storage manager interface. + */ +class StorageManager { + public: + /*! + * \brief Allocation. + * \param size Size to allocate. + * \return Pointer to the storage. + */ + virtual void* Alloc(size_t size) = 0; + /*! + * \brief Deallocation. + * \param ptr Pointer to deallocate. + * \param size Size of the storage. + */ + virtual void Free(void* ptr, size_t size) = 0; + /*! + * \brief Destructor. 
+ */ + virtual ~StorageManager() = default; +}; // class StorageManager + +} // namespace storage +} // namespace mxnet + +#endif // MXNET_STORAGE_STORAGE_MANAGER_H_ diff --git a/test/test_storage.cc b/test/test_storage.cc new file mode 100644 index 000000000000..33995a055dc5 --- /dev/null +++ b/test/test_storage.cc @@ -0,0 +1,39 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file test_storage.cc + * \brief Test for storage. + */ +#include +#include +#include "mxnet/storage.h" + +int main() { + constexpr size_t kSize = 1024; + auto&& storage = mxnet::Storage::Get(); + mxnet::Context context_cpu{}; + auto&& handle = storage->Alloc(kSize, context_cpu); + assert(handle.ctx == context_cpu); + assert(handle.size == kSize); + auto ptr = handle.dptr; + storage->Free(handle); + handle = storage->Alloc(kSize, context_cpu); + assert(handle.ctx == context_cpu); + assert(handle.size == kSize); + assert(handle.dptr == ptr); + printf("Success on CPU!\n"); + +#if MXNET_USE_CUDA + mxnet::Context context_gpu{mxnet::gpu::kDevMask, 0}; + handle = storage->Alloc(kSize, context_gpu); + assert(handle.ctx == context_gpu); + assert(handle.size == kSize); + ptr = handle.dptr; + storage->Free(handle); + handle = storage->Alloc(kSize, context_gpu); + assert(handle.ctx == context_gpu); + assert(handle.size == kSize); + assert(handle.dptr == ptr); + printf("Success on GPU!\n"); +#endif // MXNET_USE_CUDA + return 0; +}