Skip to content

Commit

Permalink
Tidy up storage allocation and deallocation (apache#14480)
Browse files Browse the repository at this point in the history
* free memory when dptr is not nullptr

* skip memory allocation when handle size is 0

* update comments

* update Alloc in naive storage manager

* address comments

* add unit test for size 0 allocation
  • Loading branch information
yuxihu committed Apr 22, 2019
1 parent 9b6e84e commit 464d74d
Show file tree
Hide file tree
Showing 13 changed files with 109 additions and 70 deletions.
8 changes: 4 additions & 4 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -979,8 +979,8 @@ class NDArray {
#endif
delay_alloc = false;
} else if (shandle.size < dbytes) {
// free storage if necessary and alloc again
if (shandle.size > 0) Storage::Get()->Free(shandle);
// free storage
Storage::Get()->Free(shandle);
// init storage
shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
#if MXNET_USE_MKLDNN == 1
Expand Down Expand Up @@ -1048,8 +1048,8 @@ class NDArray {
}
size_t aux_bytes = shape.Size() * mshadow::mshadow_sizeof(aux_types[i]);
if (aux_handles[i].size < aux_bytes) {
// free storage if necessary and alloc again
if (aux_handles[i].size > 0) Storage::Get()->Free(aux_handles[i]);
// free storage
Storage::Get()->Free(aux_handles[i]);
// init aux storage
aux_handles[i] = Storage::Get()->Alloc(aux_bytes, ctx);
}
Expand Down
8 changes: 4 additions & 4 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ NDArray::Chunk::~Chunk() {
CHECK_EQ(mem.mem->GetDataHandle(), mem.h.dptr);
}
#endif
if (mem.h.size > 0) Storage::Get()->Free(mem.h);
Storage::Get()->Free(mem.h);
for (const auto& aux : mem.aux_h) {
if (aux.size > 0) Storage::Get()->Free(aux);
Storage::Get()->Free(aux);
}
}
}, shandle.ctx, var);
Expand All @@ -134,8 +134,8 @@ void NDArray::Chunk::CheckAndAllocData(const TShape &shape, int dtype) {
<< "data is expected to be allocated after aux_data";
auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype);
if (shandle.size < dbytes) {
// free storage if necessary and alloc again
if (shandle.size > 0) Storage::Get()->Free(shandle);
// free storage
Storage::Get()->Free(shandle);
// init storage
shandle = Storage::Get()->Alloc(dbytes, ctx);
#if MXNET_USE_MKLDNN == 1
Expand Down
27 changes: 13 additions & 14 deletions src/resource.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,30 +53,29 @@ struct SpaceAllocator {
host_handle.dptr = nullptr;
host_handle.size = 0;
}

inline void ReleaseAll() {
if (handle.size != 0) {
Storage::Get()->DirectFree(handle);
handle.size = 0;
}
if (host_handle.size != 0) {
Storage::Get()->DirectFree(host_handle);
host_handle.size = 0;
}
Storage::Get()->DirectFree(handle);
handle.dptr = nullptr;
handle.size = 0;

Storage::Get()->DirectFree(host_handle);
host_handle.dptr = nullptr;
host_handle.size = 0;
}

inline void* GetSpace(size_t size) {
if (handle.size >= size) return handle.dptr;
if (handle.size != 0) {
Storage::Get()->DirectFree(handle);
}

Storage::Get()->DirectFree(handle);
handle = Storage::Get()->Alloc(size, ctx);
return handle.dptr;
}

inline void* GetHostSpace(size_t size) {
if (host_handle.size >= size) return host_handle.dptr;
if (host_handle.size != 0) {
Storage::Get()->DirectFree(host_handle);
}

Storage::Get()->DirectFree(host_handle);
host_handle = Storage::Get()->Alloc(size, Context());
return host_handle.dptr;
}
Expand Down
25 changes: 12 additions & 13 deletions src/storage/cpu_device_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,12 @@ class CPUDeviceStorage {
public:
/*!
* \brief Aligned allocation on CPU.
* \param size Size to allocate.
* \return Pointer to the storage.
* \param handle Handle struct.
*/
inline static void* Alloc(Storage::Handle* handle);
inline static void Alloc(Storage::Handle* handle);
/*!
* \brief Deallocation.
* \param ptr Pointer to deallocate.
* \param handle Handle struct.
*/
inline static void Free(Storage::Handle handle);

Expand All @@ -63,25 +62,25 @@ class CPUDeviceStorage {
#endif
}; // class CPUDeviceStorage

inline void* CPUDeviceStorage::Alloc(Storage::Handle* handle) {
inline void CPUDeviceStorage::Alloc(Storage::Handle* handle) {
handle->dptr = nullptr;
const size_t size = handle->size;
void* ptr;
if (size == 0) return;

#if _MSC_VER
ptr = _aligned_malloc(size, alignment_);
if (ptr == NULL) LOG(FATAL) << "Failed to allocate CPU Memory";
handle->dptr = _aligned_malloc(size, alignment_);
if (handle->dptr == nullptr) LOG(FATAL) << "Failed to allocate CPU Memory";
#else
int ret = posix_memalign(&ptr, alignment_, size);
int ret = posix_memalign(&handle->dptr, alignment_, size);
if (ret != 0) LOG(FATAL) << "Failed to allocate CPU Memory";
#endif
return ptr;
}

inline void CPUDeviceStorage::Free(Storage::Handle handle) {
void * ptr = handle.dptr;
#if _MSC_VER
_aligned_free(ptr);
_aligned_free(handle.dptr);
#else
free(ptr);
free(handle.dptr);
#endif
}

Expand Down
7 changes: 6 additions & 1 deletion src/storage/cpu_shared_storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,18 @@ class CPUSharedStorageManager final : public StorageManager {
}; // class CPUSharedStorageManager

void CPUSharedStorageManager::Alloc(Storage::Handle* handle) {
if (handle->size == 0) {
handle->dptr = nullptr;
return;
}

std::lock_guard<std::recursive_mutex> lock(mutex_);
std::uniform_int_distribution<> dis(0, std::numeric_limits<int>::max());
int fid = -1;
std::string filename;
bool is_new = false;
size_t size = handle->size + alignment_;
void *ptr = nullptr;
void* ptr = nullptr;
#ifdef _WIN32
CheckAndRealFree();
HANDLE map_handle = nullptr;
Expand Down
19 changes: 9 additions & 10 deletions src/storage/gpu_device_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,43 +43,42 @@ class GPUDeviceStorage {
public:
/*!
* \brief Allocation.
* \param size Size to allocate.
* \return Pointer to the storage.
* \param handle Handle struct.
*/
inline static void* Alloc(Storage::Handle* handle);
inline static void Alloc(Storage::Handle* handle);
/*!
* \brief Deallocation.
* \param ptr Pointer to deallocate.
* \param handle Handle struct.
*/
inline static void Free(Storage::Handle handle);
}; // class GPUDeviceStorage

inline void* GPUDeviceStorage::Alloc(Storage::Handle* handle) {
inline void GPUDeviceStorage::Alloc(Storage::Handle* handle) {
handle->dptr = nullptr;
const size_t size = handle->size;
void* ret = nullptr;
if (size == 0) return;

#if MXNET_USE_CUDA
mxnet::common::cuda::DeviceStore device_store(handle->ctx.real_dev_id(), true);
#if MXNET_USE_NCCL
std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU));
#endif // MXNET_USE_NCCL
cudaError_t e = cudaMalloc(&ret, size);
cudaError_t e = cudaMalloc(&handle->dptr, size);
if (e != cudaSuccess && e != cudaErrorCudartUnloading)
LOG(FATAL) << "CUDA: " << cudaGetErrorString(e);
#else // MXNET_USE_CUDA
LOG(FATAL) << "Please compile with CUDA enabled";
#endif // MXNET_USE_CUDA
return ret;
}

inline void GPUDeviceStorage::Free(Storage::Handle handle) {
#if MXNET_USE_CUDA
void * ptr = handle.dptr;
mxnet::common::cuda::DeviceStore device_store(handle.ctx.real_dev_id(), true);
#if MXNET_USE_NCCL
std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU));
#endif // MXNET_USE_NCCL
// throw special exception for caller to catch.
cudaError_t err = cudaFree(ptr);
cudaError_t err = cudaFree(handle.dptr);
// ignore unloading error, as memory has already been recycled
if (err != cudaSuccess && err != cudaErrorCudartUnloading) {
LOG(FATAL) << "CUDA: " << cudaGetErrorString(err);
Expand Down
2 changes: 1 addition & 1 deletion src/storage/naive_storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class NaiveStorageManager final : public StorageManager {

template <class DeviceStorage>
void NaiveStorageManager<DeviceStorage>::Alloc(Storage::Handle* handle) {
handle->dptr = DeviceStorage::Alloc(handle);
DeviceStorage::Alloc(handle);
}

template <class DeviceStorage>
Expand Down
21 changes: 10 additions & 11 deletions src/storage/pinned_memory_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

/*!
* Copyright (c) 2015 by Contributors
* \file cpu_device_storage.h
* \file pinned_memory_storage.h
* \brief CPU storage with pinned memory
*/
#ifndef MXNET_STORAGE_PINNED_MEMORY_STORAGE_H_
Expand All @@ -38,37 +38,36 @@ class PinnedMemoryStorage {
public:
/*!
* \brief Allocation.
* \param size Size to allocate.
* \return Pointer to the storage.
* \param handle Handle struct.
*/
inline static void* Alloc(Storage::Handle* handle);
inline static void Alloc(Storage::Handle* handle);

/*!
* \brief Deallocation.
* \param ptr Pointer to deallocate.
* \param handle Handle struct.
*/
inline static void Free(Storage::Handle handle);
};

inline void* PinnedMemoryStorage::Alloc(Storage::Handle* handle) {
void* ret = nullptr;
inline void PinnedMemoryStorage::Alloc(Storage::Handle* handle) {
handle->dptr = nullptr;
const size_t size = handle->size;
if (size == 0) return;

#if MXNET_USE_NCCL
std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
#endif
mxnet::common::cuda::DeviceStore device_store(handle->ctx.real_dev_id(), true);
// make the memory available across all devices
CUDA_CALL(cudaHostAlloc(&ret, size, cudaHostAllocPortable));
return ret;
CUDA_CALL(cudaHostAlloc(&handle->dptr, size, cudaHostAllocPortable));
}

inline void PinnedMemoryStorage::Free(Storage::Handle handle) {
void * ptr = handle.dptr;
#if MXNET_USE_NCCL
std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
#endif
mxnet::common::cuda::DeviceStore device_store(handle.ctx.real_dev_id(), true);
cudaError_t err = cudaFreeHost(ptr);
cudaError_t err = cudaFreeHost(handle.dptr);
// ignore unloading error, as memory has already been recycled
if (err != cudaSuccess && err != cudaErrorCudartUnloading) {
LOG(FATAL) << "CUDA: " << cudaGetErrorString(err);
Expand Down
20 changes: 20 additions & 0 deletions src/storage/pooled_storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ class GPUPooledStorageManager final : public StorageManager {
}; // class GPUPooledStorageManager

void GPUPooledStorageManager::Alloc(Storage::Handle* handle) {
// Set dptr to nullptr when handle size is 0.
if (handle->size == 0) {
handle->dptr = nullptr;
return;
}

std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
size_t size = RoundAllocSize(handle->size);
auto&& reuse_it = memory_pool_.find(size);
Expand All @@ -155,6 +161,10 @@ void GPUPooledStorageManager::Alloc(Storage::Handle* handle) {
}

void GPUPooledStorageManager::Free(Storage::Handle handle) {
// Do nothing if dptr is nullptr. Otherwise, nullptr may be reused
// which can cause illegal memory access error.
if (handle.dptr == nullptr) return;

std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
size_t size = RoundAllocSize(handle.size);
auto&& reuse_pool = memory_pool_[size];
Expand Down Expand Up @@ -286,6 +296,12 @@ class GPUPooledRoundedStorageManager final : public StorageManager {
}; // class GPUPooledRoundedStorageManager

void GPUPooledRoundedStorageManager::Alloc(Storage::Handle* handle) {
// Set dptr to nullptr when handle size is 0.
if (handle->size == 0) {
handle->dptr = nullptr;
return;
}

std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
int bucket = get_bucket(handle->size);
size_t size = get_size(bucket);
Expand All @@ -312,6 +328,10 @@ void GPUPooledRoundedStorageManager::Alloc(Storage::Handle* handle) {
}

void GPUPooledRoundedStorageManager::Free(Storage::Handle handle) {
// Do nothing if dptr is nullptr. Otherwise, nullptr may be reused
// which can cause illegal memory access error.
if (handle.dptr == nullptr) return;

std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
int bucket = get_bucket(handle.size);
auto&& reuse_pool = memory_pool_[bucket];
Expand Down
8 changes: 8 additions & 0 deletions src/storage/storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ void StorageImpl::Alloc(Storage::Handle* handle) {
}

void StorageImpl::Free(Storage::Handle handle) {
// Do nothing if dtpr is nullptr because the handle may have already
// been freed or have not been allocated memory yet.
if (handle.dptr == nullptr) return;

const Context &ctx = handle.ctx;
auto&& device = storage_managers_.at(ctx.dev_type);
std::shared_ptr<storage::StorageManager> manager = device.Get(
Expand All @@ -140,6 +144,10 @@ void StorageImpl::Free(Storage::Handle handle) {
}

void StorageImpl::DirectFree(Storage::Handle handle) {
// Do nothing if dtpr is nullptr because the handle may have already
// been freed or have not been allocated memory yet.
if (handle.dptr == nullptr) return;

const Context &ctx = handle.ctx;
auto&& device = storage_managers_.at(ctx.dev_type);
std::shared_ptr<storage::StorageManager> manager = device.Get(
Expand Down
11 changes: 4 additions & 7 deletions src/storage/storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,17 @@ class StorageManager {
public:
/*!
* \brief Allocation.
* \param size Size to allocate.
* \return Pointer to the storage.
* \param handle Handle struct.
*/
virtual void Alloc(Storage::Handle* handle) = 0;
/*!
* \brief Deallocation.
* \param ptr Pointer to deallocate.
* \param size Size of the storage.
* \param handle Handle struct.
*/
virtual void Free(Storage::Handle handle) = 0;
/*!
* \brief Direct de-allocation.
* \param ptr Pointer to deallocate.
* \param size Size of the storage.
* \brief Direct deallocation.
* \param handle Handle struct.
*/
virtual void DirectFree(Storage::Handle handle) = 0;
/*!
Expand Down
8 changes: 3 additions & 5 deletions tests/cpp/include/test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,9 @@ class BlobMemory {
return handle_.dptr;
}
void Free() {
if (handle_.dptr) {
Storage *storage = mxnet::Storage::Get();
storage->DirectFree(handle_);
handle_.dptr = nullptr;
}
mxnet::Storage::Get()->DirectFree(handle_);
handle_.dptr = nullptr;
handle_.size = 0;
}
size_t Size() const {
return handle_.size;
Expand Down
Loading

0 comments on commit 464d74d

Please sign in to comment.