From 84bf795386613b53a333084f8c93db1475543859 Mon Sep 17 00:00:00 2001 From: Yutian Li Date: Tue, 11 Aug 2015 00:57:40 +0800 Subject: [PATCH] [storage] putting things together --- Makefile | 2 +- include/mxnet/narray.h | 10 +-- include/mxnet/storage.h | 56 +++++++++++----- src/storage/cpu_storage.h | 2 +- src/storage/gpu_storage.h | 2 +- src/storage/naive_storage_manager.h | 6 +- src/storage/pooled_storage_manager.h | 13 ++-- src/storage/storage.cc | 99 +++++++++++++++++++++------- src/storage/storage_manager.h | 4 +- 9 files changed, 132 insertions(+), 62 deletions(-) diff --git a/Makefile b/Makefile index e830c3eac7df..5ddb285f8c8e 100644 --- a/Makefile +++ b/Makefile @@ -58,7 +58,7 @@ endif BIN = test/api_registry_test OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o cpu_storage.o gpu_storage.o +OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o cpu_storage.o gpu_storage.o storage.o CUOBJ = SLIB = lib/libmxnet.so ALIB = lib/libmxnet.a diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index 4e7b4448e667..9a73870c003b 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -126,14 +126,14 @@ class NArray { /*! \brief the real data chunk that backs NArray */ struct Chunk { /*! \brief storage handlefrom storage engine */ - StorageManager::Handle shandle; + Storage::Handle shandle; /*! \brief variable from DAG engine */ DAGEngine::Variable var; /*! \brief holds the data content */ TBlob data; /*! * \brief if this is true, this means the data do not come - * from StorageManager, and do not need to be freed + * from Storage, and do not need to be freed */ bool static_data; /*! \brief whether allocation is delayed */ @@ -161,7 +161,7 @@ class NArray { /*! 
\brief check if delay alloc is on, do alloc if not yet done */ inline void CheckAndAlloc(void) { if (delay_alloc) { - shandle = StorageManager::Get()->Alloc(data.shape_.Size() * sizeof(real_t), shandle.ctx); + shandle = Storage::Get()->Alloc(data.shape_.Size() * sizeof(real_t), shandle.ctx); data = TBlob(static_cast<real_t*>(shandle.dptr), data.shape_, shandle.ctx.dev_mask); delay_alloc = false; } @@ -172,9 +172,9 @@ class NArray { DAGEngine::Get()->PushDelete([](RunContext s) {}, shandle.ctx, var); } else { CHECK(!delay_alloc) << "deleted before allocation"; - StorageManager::Handle h = this->shandle; + Storage::Handle h = this->shandle; DAGEngine::Get()->PushDelete([h](RunContext s) { - StorageManager::Get()->Free(h); + Storage::Get()->Free(h); }, shandle.ctx, var); } } diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index a900537baff3..d9f4a9f17de0 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -1,49 +1,69 @@ /*! * Copyright (c) 2015 by Contributors * \file storage.h - * \brief the memory allocator that manages the memory across multiple devices + * \brief Storage manager across multiple devices. */ #ifndef MXNET_STORAGE_H_ #define MXNET_STORAGE_H_ + +#include <memory> #include "./base.h" #include "./tensor_blob.h" namespace mxnet { -/*! \brief memory allocator of storage */ -class StorageManager { +/*! + * \brief Storage manager across multiple devices. + */ +class Storage { public: /*! - * \brief storage handle the represents storage information + * \brief Storage handle. */ struct Handle { - /*! \brief pointer to the data */ + /*! + * \brief Pointer to the data. + */ void* dptr; - /*! \brief context information about device and deviceID */ + /*! + * \brief Size of the storage. + */ + size_t size; + /*! + * \brief Context information about device and ID. + */ Context ctx; }; /*! 
- * \brief allocate a new contiguous memory for a given size - * \param size the total size of memory in bytes - * \param ctx context information about the device and deviceID - * \return Handle struct + * \brief Allocate a new contiguous memory for a given size. + * \param size Total size of memory in bytes. + * \param ctx Context information about the device and ID. + * \return Handle struct. */ Handle Alloc(size_t size, Context ctx); /*! - * \brief free the space represened the handle - * \param handle the handle to memory to be freed + * \brief Free storage. + * \param handle Handle struct. */ void Free(Handle handle); - /*! \return storage manager singleton */ - static StorageManager* Get(); + /*! + * \brief Destructor. + */ + ~Storage(); + /*! + * \return Storage singleton. + */ + static Storage* Get(); private: /*! - * \brief disabled constructors + * \brief Hidden constructors. */ - StorageManager() {} - DISALLOW_COPY_AND_ASSIGN(StorageManager); -}; // class StorageManager + Storage(); + struct Impl; + std::unique_ptr<Impl> impl_; + DISALLOW_COPY_AND_ASSIGN(Storage); +}; // class Storage } // namespace mxnet diff --git a/src/storage/cpu_storage.h b/src/storage/cpu_storage.h index 03aa0eee46ab..3c031959d9be 100644 --- a/src/storage/cpu_storage.h +++ b/src/storage/cpu_storage.h @@ -33,7 +33,7 @@ class CpuStorage { * \brief Alignment of allocation. */ static constexpr size_t alignment_ = 16; -}; +}; // class CpuStorage } // namespace storage } // namespace mxnet diff --git a/src/storage/gpu_storage.h b/src/storage/gpu_storage.h index 08a9befe41aa..5e0293f2edd6 100644 --- a/src/storage/gpu_storage.h +++ b/src/storage/gpu_storage.h @@ -27,7 +27,7 @@ class GpuStorage { * \param ptr Pointer to deallocate. 
*/ static void Free(void* ptr); -}; +}; // class GpuStorage } // namespace storage } // namespace mxnet diff --git a/src/storage/naive_storage_manager.h b/src/storage/naive_storage_manager.h index d145cada3f43..3611b775ce74 100644 --- a/src/storage/naive_storage_manager.h +++ b/src/storage/naive_storage_manager.h @@ -31,15 +31,15 @@ class NaiveStorageManager final : public StorageManager { private: DISALLOW_COPY_AND_ASSIGN(NaiveStorageManager); -}; +}; // class NaiveStorageManager template -void* NaiveStorageManager::Alloc(size_t size) { +void* NaiveStorageManager::Alloc(size_t size) { return DeviceStorage::Alloc(size); } template -void NaiveStorageManager::Free(void* ptr) { +void NaiveStorageManager::Free(void* ptr, size_t) { DeviceStorage::Free(ptr); } diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h index fb168391dc5e..b3a45b895985 100644 --- a/src/storage/pooled_storage_manager.h +++ b/src/storage/pooled_storage_manager.h @@ -36,14 +36,14 @@ class PooledStorageManager final : public StorageManager { size_t used_memory_ = 0; std::unordered_map> memory_pool_; DISALLOW_COPY_AND_ASSIGN(PooledStorageManager); -}; +}; // class PooledStorageManager -templace +template void* PooledStorageManager::Alloc(size_t size) { auto&& reuse_it = memory_pool_.find(size); if (reuse_it == memory_pool_.end() || reuse_it->second.size() == 0) { if (kThreshold <= used_memory_) { - ReleaseAll(); + ReleaseAll(); } used_memory_ += size; return DeviceStorage::Alloc(size); @@ -55,8 +55,9 @@ void* PooledStorageManager::Alloc(size_t size) { } } -templace -void PooledStorageManager::Free(void* ptr, size_t size) { +template +void PooledStorageManager::Free(void* ptr, + size_t size) { auto&& reuse_pool = memory_pool_[size]; reuse_pool.push_back(ptr); } @@ -65,7 +66,7 @@ template void PooledStorageManager::ReleaseAll() { for (auto&& i : memory_pool_) { for (auto&& j : i.second) { - DeviceStorage::Free(i.second); + DeviceStorage::Free(j); used_memory_ -= 
i.first; } } diff --git a/src/storage/storage.cc b/src/storage/storage.cc index 12e25e68ddaf..a5bfca9c9046 100644 --- a/src/storage/storage.cc +++ b/src/storage/storage.cc @@ -1,44 +1,93 @@ /*! * Copyright (c) 2015 by Contributors */ +#include "mxnet/storage.h" #include -#include +#include +#include "./storage_manager.h" +#include "./naive_storage_manager.h" +#include "./pooled_storage_manager.h" +#include "./cpu_storage.h" +#include "./gpu_storage.h" +#include "mxnet/cuda_utils.h" + namespace mxnet { -// class NaiveStorageManager : public StorageManager { -// public: -// virtual Handle Alloc(size_t size, Context ctx); -// virtual void Free(Handle handle); -// }; +struct Storage::Impl { + static constexpr size_t kPoolThreshold = 4096 * 1024 * 1024ul; + + template + using CurrentStorageManager = storage::PooledStorageManager; + + static void ActivateDevice(Context ctx) { + switch (ctx.dev_mask) { + case cpu::kDevMask: + break; + case gpu::kDevMask: +#if MXNET_USE_CUDA + CUDA_CALL(cudaSetDevice(ctx.dev_id)); +#else // MXNET_USE_CUDA + LOG(FATAL) << "Please compile with CUDA enabled"; +#endif // MXNET_USE_CUDA + break; + default: + LOG(FATAL) << "Unimplemented device"; + } + } + + std::unordered_map< + int, std::unordered_map>> + storage_managers; +}; // struct Storage::Impl -StorageManager::Handle StorageManager::Alloc(size_t size, Context ctx) { +Storage::Handle Storage::Alloc(size_t size, Context ctx) { Handle hd; hd.ctx = ctx; - if (ctx.dev_mask == cpu::kDevMask) { - hd.dptr = calloc(size, sizeof(real_t)); - } else { -#if MXNET_USE_CUDA - cudaMalloc(&hd.dptr, size); -#endif + auto&& device = impl_->storage_managers[ctx.dev_mask]; + auto&& device_id_it = device.find(ctx.dev_id); + // Allocate device if necessary. 
+ if (device_id_it == device.end()) { + switch (ctx.dev_mask) { + case cpu::kDevMask: + device_id_it = + device.emplace(std::make_pair( + ctx.dev_id, + std::unique_ptr<storage::StorageManager>{ + new Storage::Impl::CurrentStorageManager< + storage::CpuStorage>{}})).first; + break; + case gpu::kDevMask: + device_id_it = + device.emplace(std::make_pair( + ctx.dev_id, + std::unique_ptr<storage::StorageManager>{ + new Storage::Impl::CurrentStorageManager< + storage::GpuStorage>{}})).first; + break; + default: + LOG(FATAL) << "Unimplemented device"; + } } + Impl::ActivateDevice(ctx); + hd.dptr = device_id_it->second->Alloc(size); + hd.size = size; return hd; } -void StorageManager::Free(StorageManager::Handle handle) { - if (handle.ctx.dev_mask == cpu::kDevMask) { - free(handle.dptr); - handle.dptr = NULL; - // cudaFreeHost(handle.dptr); - } else { -#if MXNET_USE_CUDA - cudaFree(handle.dptr); -#endif - } +void Storage::Free(Storage::Handle handle) { + Impl::ActivateDevice(handle.ctx); + impl_->storage_managers.at(handle.ctx.dev_mask) + .at(handle.ctx.dev_id) + ->Free(handle.dptr, handle.size); } -StorageManager* StorageManager::Get() { - static StorageManager inst; +Storage::~Storage() = default; + +Storage* Storage::Get() { + static Storage inst; return &inst; } +Storage::Storage() : impl_{new Impl{}} {} + } // namespace mxnet diff --git a/src/storage/storage_manager.h b/src/storage/storage_manager.h index 94967a210e80..3d264ab278ca 100644 --- a/src/storage/storage_manager.h +++ b/src/storage/storage_manager.h @@ -27,12 +27,12 @@ class StorageManager { * \param ptr Pointer to deallocate. * \param size Size of the storage. */ - virtual void* Free(void* ptr, size_t size) = 0; + virtual void Free(void* ptr, size_t size) = 0; /*! * \brief Destructor. */ virtual ~StorageManager() = default; -}; +}; // class StorageManager } // namespace storage } // namespace mxnet