diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h index 4c8ae4eb12dd..7726bc6f9273 100644 --- a/src/storage/pooled_storage_manager.h +++ b/src/storage/pooled_storage_manager.h @@ -53,8 +53,11 @@ class GPUPooledStorageManager final : public StorageManager { public: /*! * \brief Default constructor. + * + * \param initial_context context used by this Storage Manager */ - GPUPooledStorageManager() { + explicit GPUPooledStorageManager(Context initial_context) : + initial_context_(initial_context) { reserve_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_RESERVE", 5); page_size_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_PAGE_SIZE", 4096); large_alloc_round_size_ = dmlc::GetEnv("MXNET_GPU_MEM_LARGE_ALLOC_ROUND_SIZE", 2 * 1024 * 1024); @@ -123,6 +126,8 @@ class GPUPooledStorageManager final : public StorageManager { int reserve_; // number of devices const size_t NDEV = 32; + // context used by this Storage Manager + const Context initial_context_; // memory pool std::unordered_map> memory_pool_; DISALLOW_COPY_AND_ASSIGN(GPUPooledStorageManager); @@ -177,6 +182,7 @@ void GPUPooledStorageManager::ReleaseAll() { Storage::Handle handle; handle.dptr = j; handle.size = i.first; + handle.ctx = initial_context_; DirectFreeNoLock(handle); } } @@ -201,8 +207,11 @@ class GPUPooledRoundedStorageManager final : public StorageManager { public: /*! * \brief Default constructor. + * + * \param initial_context context used by this Storage Manager */ - GPUPooledRoundedStorageManager() { + explicit GPUPooledRoundedStorageManager(Context initial_context) : + initial_context_(initial_context) { reserve_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_RESERVE", 5); page_size_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_PAGE_SIZE", 4096); cut_off_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF", 24); @@ -290,6 +299,8 @@ class GPUPooledRoundedStorageManager final : public StorageManager { size_t cut_off_; // percentage of reserved memory int reserve_; + // context used by this Storage Manager + const Context initial_context_; // memory pool std::vector> memory_pool_; DISALLOW_COPY_AND_ASSIGN(GPUPooledRoundedStorageManager); @@ -345,6 +356,7 @@ void GPUPooledRoundedStorageManager::ReleaseAll() { Storage::Handle handle; handle.size = size; handle.dptr = j; + handle.ctx = initial_context_; DirectFreeNoLock(handle); } memory_pool_[i].clear(); diff --git a/src/storage/storage.cc b/src/storage/storage.cc index 7484e699d388..4f15351a594a 100644 --- a/src/storage/storage.cc +++ b/src/storage/storage.cc @@ -104,13 +104,13 @@ void StorageImpl::Alloc(Storage::Handle* handle) { std::string strategy = type; if (strategy == "Round") { - ptr = new storage::GPUPooledRoundedStorageManager(); + ptr = new storage::GPUPooledRoundedStorageManager(handle->ctx); LOG(INFO) << "Using GPUPooledRoundedStorageManager."; } else { if (strategy != "Naive") { LOG(FATAL) << "Unknown memory pool strategy specified: " << strategy << "."; } - ptr = new storage::GPUPooledStorageManager(); + ptr = new storage::GPUPooledStorageManager(handle->ctx); } #else LOG(FATAL) << "Compile with USE_CUDA=1 to enable GPU usage";