Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Tidy up storage allocation and deallocation #14480

Merged
merged 6 commits into from
Mar 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -986,8 +986,8 @@ class NDArray {
#endif
delay_alloc = false;
} else if (shandle.size < dbytes) {
// free storage if necessary and alloc again
if (shandle.size > 0) Storage::Get()->Free(shandle);
// free storage
Storage::Get()->Free(shandle);
// init storage
shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
#if MXNET_USE_MKLDNN == 1
Expand Down Expand Up @@ -1055,8 +1055,8 @@ class NDArray {
}
size_t aux_bytes = shape.Size() * mshadow::mshadow_sizeof(aux_types[i]);
if (aux_handles[i].size < aux_bytes) {
// free storage if necessary and alloc again
if (aux_handles[i].size > 0) Storage::Get()->Free(aux_handles[i]);
// free storage
Storage::Get()->Free(aux_handles[i]);
// init aux storage
aux_handles[i] = Storage::Get()->Alloc(aux_bytes, ctx);
}
Expand Down
8 changes: 4 additions & 4 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ NDArray::Chunk::~Chunk() {
CHECK_EQ(mem.mem->GetDataHandle(), mem.h.dptr);
}
#endif
if (mem.h.size > 0) Storage::Get()->Free(mem.h);
Storage::Get()->Free(mem.h);
for (const auto& aux : mem.aux_h) {
if (aux.size > 0) Storage::Get()->Free(aux);
Storage::Get()->Free(aux);
}
}
}, shandle.ctx, var);
Expand All @@ -134,8 +134,8 @@ void NDArray::Chunk::CheckAndAllocData(const mxnet::TShape &shape, int dtype) {
<< "data is expected to be allocated after aux_data";
auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype);
if (shandle.size < dbytes) {
// free storage if necessary and alloc again
if (shandle.size > 0) Storage::Get()->Free(shandle);
// free storage
Storage::Get()->Free(shandle);
// init storage
shandle = Storage::Get()->Alloc(dbytes, ctx);
#if MXNET_USE_MKLDNN == 1
Expand Down
27 changes: 13 additions & 14 deletions src/resource.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,30 +54,29 @@ struct SpaceAllocator {
host_handle.dptr = nullptr;
host_handle.size = 0;
}

inline void ReleaseAll() {
if (handle.size != 0) {
Storage::Get()->DirectFree(handle);
handle.size = 0;
}
if (host_handle.size != 0) {
Storage::Get()->DirectFree(host_handle);
host_handle.size = 0;
}
Storage::Get()->DirectFree(handle);
handle.dptr = nullptr;
handle.size = 0;

Storage::Get()->DirectFree(host_handle);
host_handle.dptr = nullptr;
host_handle.size = 0;
}

inline void* GetSpace(size_t size) {
if (handle.size >= size) return handle.dptr;
if (handle.size != 0) {
Storage::Get()->DirectFree(handle);
}

Storage::Get()->DirectFree(handle);
handle = Storage::Get()->Alloc(size, ctx);
return handle.dptr;
}

inline void* GetHostSpace(size_t size) {
if (host_handle.size >= size) return host_handle.dptr;
if (host_handle.size != 0) {
Storage::Get()->DirectFree(host_handle);
}

Storage::Get()->DirectFree(host_handle);
host_handle = Storage::Get()->Alloc(size, Context());
return host_handle.dptr;
}
Expand Down
25 changes: 12 additions & 13 deletions src/storage/cpu_device_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,12 @@ class CPUDeviceStorage {
public:
/*!
* \brief Aligned allocation on CPU.
* \param size Size to allocate.
* \return Pointer to the storage.
* \param handle Handle struct.
*/
inline static void* Alloc(Storage::Handle* handle);
inline static void Alloc(Storage::Handle* handle);
/*!
* \brief Deallocation.
* \param ptr Pointer to deallocate.
* \param handle Handle struct.
*/
inline static void Free(Storage::Handle handle);

Expand All @@ -63,25 +62,25 @@ class CPUDeviceStorage {
#endif
}; // class CPUDeviceStorage

inline void* CPUDeviceStorage::Alloc(Storage::Handle* handle) {
inline void CPUDeviceStorage::Alloc(Storage::Handle* handle) {
handle->dptr = nullptr;
const size_t size = handle->size;
void* ptr;
if (size == 0) return;

#if _MSC_VER
ptr = _aligned_malloc(size, alignment_);
if (ptr == NULL) LOG(FATAL) << "Failed to allocate CPU Memory";
handle->dptr = _aligned_malloc(size, alignment_);
if (handle->dptr == nullptr) LOG(FATAL) << "Failed to allocate CPU Memory";
#else
int ret = posix_memalign(&ptr, alignment_, size);
int ret = posix_memalign(&handle->dptr, alignment_, size);
if (ret != 0) LOG(FATAL) << "Failed to allocate CPU Memory";
#endif
return ptr;
}

inline void CPUDeviceStorage::Free(Storage::Handle handle) {
void * ptr = handle.dptr;
#if _MSC_VER
_aligned_free(ptr);
_aligned_free(handle.dptr);
#else
free(ptr);
free(handle.dptr);
#endif
}

Expand Down
7 changes: 6 additions & 1 deletion src/storage/cpu_shared_storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,18 @@ class CPUSharedStorageManager final : public StorageManager {
}; // class CPUSharedStorageManager

void CPUSharedStorageManager::Alloc(Storage::Handle* handle) {
if (handle->size == 0) {
handle->dptr = nullptr;
yuxihu marked this conversation as resolved.
Show resolved Hide resolved
return;
}

std::lock_guard<std::recursive_mutex> lock(mutex_);
std::uniform_int_distribution<> dis(0, std::numeric_limits<int>::max());
int fid = -1;
std::string filename;
bool is_new = false;
size_t size = handle->size + alignment_;
void *ptr = nullptr;
void* ptr = nullptr;
#ifdef _WIN32
CheckAndRealFree();
HANDLE map_handle = nullptr;
Expand Down
19 changes: 9 additions & 10 deletions src/storage/gpu_device_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,43 +43,42 @@ class GPUDeviceStorage {
public:
/*!
* \brief Allocation.
* \param size Size to allocate.
* \return Pointer to the storage.
* \param handle Handle struct.
*/
inline static void* Alloc(Storage::Handle* handle);
inline static void Alloc(Storage::Handle* handle);
/*!
* \brief Deallocation.
* \param ptr Pointer to deallocate.
* \param handle Handle struct.
*/
inline static void Free(Storage::Handle handle);
}; // class GPUDeviceStorage

inline void* GPUDeviceStorage::Alloc(Storage::Handle* handle) {
inline void GPUDeviceStorage::Alloc(Storage::Handle* handle) {
handle->dptr = nullptr;
const size_t size = handle->size;
void* ret = nullptr;
if (size == 0) return;

#if MXNET_USE_CUDA
mxnet::common::cuda::DeviceStore device_store(handle->ctx.real_dev_id(), true);
#if MXNET_USE_NCCL
std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU));
#endif // MXNET_USE_NCCL
cudaError_t e = cudaMalloc(&ret, size);
cudaError_t e = cudaMalloc(&handle->dptr, size);
if (e != cudaSuccess && e != cudaErrorCudartUnloading)
LOG(FATAL) << "CUDA: " << cudaGetErrorString(e);
#else // MXNET_USE_CUDA
LOG(FATAL) << "Please compile with CUDA enabled";
#endif // MXNET_USE_CUDA
return ret;
}

inline void GPUDeviceStorage::Free(Storage::Handle handle) {
#if MXNET_USE_CUDA
void * ptr = handle.dptr;
mxnet::common::cuda::DeviceStore device_store(handle.ctx.real_dev_id(), true);
#if MXNET_USE_NCCL
std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU));
#endif // MXNET_USE_NCCL
// throw special exception for caller to catch.
cudaError_t err = cudaFree(ptr);
cudaError_t err = cudaFree(handle.dptr);
// ignore unloading error, as memory has already been recycled
if (err != cudaSuccess && err != cudaErrorCudartUnloading) {
LOG(FATAL) << "CUDA: " << cudaGetErrorString(err);
Expand Down
2 changes: 1 addition & 1 deletion src/storage/naive_storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class NaiveStorageManager final : public StorageManager {

template <class DeviceStorage>
void NaiveStorageManager<DeviceStorage>::Alloc(Storage::Handle* handle) {
handle->dptr = DeviceStorage::Alloc(handle);
DeviceStorage::Alloc(handle);
}

template <class DeviceStorage>
Expand Down
21 changes: 10 additions & 11 deletions src/storage/pinned_memory_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

/*!
* Copyright (c) 2015 by Contributors
* \file cpu_device_storage.h
* \file pinned_memory_storage.h
* \brief CPU storage with pinned memory
*/
#ifndef MXNET_STORAGE_PINNED_MEMORY_STORAGE_H_
Expand All @@ -38,37 +38,36 @@ class PinnedMemoryStorage {
public:
/*!
* \brief Allocation.
* \param size Size to allocate.
* \return Pointer to the storage.
* \param handle Handle struct.
*/
inline static void* Alloc(Storage::Handle* handle);
inline static void Alloc(Storage::Handle* handle);

/*!
* \brief Deallocation.
* \param ptr Pointer to deallocate.
* \param handle Handle struct.
*/
inline static void Free(Storage::Handle handle);
};

inline void* PinnedMemoryStorage::Alloc(Storage::Handle* handle) {
void* ret = nullptr;
inline void PinnedMemoryStorage::Alloc(Storage::Handle* handle) {
handle->dptr = nullptr;
const size_t size = handle->size;
if (size == 0) return;

#if MXNET_USE_NCCL
std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
#endif
mxnet::common::cuda::DeviceStore device_store(handle->ctx.real_dev_id(), true);
// make the memory available across all devices
CUDA_CALL(cudaHostAlloc(&ret, size, cudaHostAllocPortable));
return ret;
CUDA_CALL(cudaHostAlloc(&handle->dptr, size, cudaHostAllocPortable));
}

inline void PinnedMemoryStorage::Free(Storage::Handle handle) {
void * ptr = handle.dptr;
#if MXNET_USE_NCCL
std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
#endif
mxnet::common::cuda::DeviceStore device_store(handle.ctx.real_dev_id(), true);
cudaError_t err = cudaFreeHost(ptr);
cudaError_t err = cudaFreeHost(handle.dptr);
// ignore unloading error, as memory has already been recycled
if (err != cudaSuccess && err != cudaErrorCudartUnloading) {
LOG(FATAL) << "CUDA: " << cudaGetErrorString(err);
Expand Down
20 changes: 20 additions & 0 deletions src/storage/pooled_storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ class GPUPooledStorageManager final : public StorageManager {
}; // class GPUPooledStorageManager

void GPUPooledStorageManager::Alloc(Storage::Handle* handle) {
// Set dptr to nullptr when handle size is 0.
if (handle->size == 0) {
handle->dptr = nullptr;
return;
}

std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
size_t size = RoundAllocSize(handle->size);
auto&& reuse_it = memory_pool_.find(size);
Expand All @@ -155,6 +161,10 @@ void GPUPooledStorageManager::Alloc(Storage::Handle* handle) {
}

void GPUPooledStorageManager::Free(Storage::Handle handle) {
// Do nothing if dptr is nullptr. Otherwise, nullptr may be reused
// which can cause illegal memory access error.
if (handle.dptr == nullptr) return;

std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
size_t size = RoundAllocSize(handle.size);
auto&& reuse_pool = memory_pool_[size];
Expand Down Expand Up @@ -286,6 +296,12 @@ class GPUPooledRoundedStorageManager final : public StorageManager {
}; // class GPUPooledRoundedStorageManager

void GPUPooledRoundedStorageManager::Alloc(Storage::Handle* handle) {
// Set dptr to nullptr when handle size is 0.
if (handle->size == 0) {
handle->dptr = nullptr;
return;
}

std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
int bucket = get_bucket(handle->size);
size_t size = get_size(bucket);
Expand All @@ -312,6 +328,10 @@ void GPUPooledRoundedStorageManager::Alloc(Storage::Handle* handle) {
}

void GPUPooledRoundedStorageManager::Free(Storage::Handle handle) {
// Do nothing if dptr is nullptr. Otherwise, nullptr may be reused
// which can cause illegal memory access error.
if (handle.dptr == nullptr) return;

std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
int bucket = get_bucket(handle.size);
auto&& reuse_pool = memory_pool_[bucket];
Expand Down
8 changes: 8 additions & 0 deletions src/storage/storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ void StorageImpl::Alloc(Storage::Handle* handle) {
}

void StorageImpl::Free(Storage::Handle handle) {
// Do nothing if dtpr is nullptr because the handle may have already
// been freed or have not been allocated memory yet.
if (handle.dptr == nullptr) return;

const Context &ctx = handle.ctx;
auto&& device = storage_managers_.at(ctx.dev_type);
std::shared_ptr<storage::StorageManager> manager = device.Get(
Expand All @@ -140,6 +144,10 @@ void StorageImpl::Free(Storage::Handle handle) {
}

void StorageImpl::DirectFree(Storage::Handle handle) {
// Do nothing if dtpr is nullptr because the handle may have already
// been freed or have not been allocated memory yet.
if (handle.dptr == nullptr) return;

const Context &ctx = handle.ctx;
auto&& device = storage_managers_.at(ctx.dev_type);
std::shared_ptr<storage::StorageManager> manager = device.Get(
Expand Down
11 changes: 4 additions & 7 deletions src/storage/storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,17 @@ class StorageManager {
public:
/*!
* \brief Allocation.
* \param size Size to allocate.
* \return Pointer to the storage.
* \param handle Handle struct.
*/
virtual void Alloc(Storage::Handle* handle) = 0;
/*!
* \brief Deallocation.
* \param ptr Pointer to deallocate.
* \param size Size of the storage.
* \param handle Handle struct.
*/
virtual void Free(Storage::Handle handle) = 0;
/*!
* \brief Direct de-allocation.
* \param ptr Pointer to deallocate.
* \param size Size of the storage.
* \brief Direct deallocation.
* \param handle Handle struct.
*/
virtual void DirectFree(Storage::Handle handle) = 0;
/*!
Expand Down
8 changes: 3 additions & 5 deletions tests/cpp/include/test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,9 @@ class BlobMemory {
return handle_.dptr;
}
void Free() {
if (handle_.dptr) {
Storage *storage = mxnet::Storage::Get();
storage->DirectFree(handle_);
handle_.dptr = nullptr;
}
mxnet::Storage::Get()->DirectFree(handle_);
handle_.dptr = nullptr;
yuxihu marked this conversation as resolved.
Show resolved Hide resolved
handle_.size = 0;
}
size_t Size() const {
return handle_.size;
Expand Down
Loading