Skip to content

Commit

Permalink
Fix memory leak for size-zero ndarray (apache#14365)
Browse files Browse the repository at this point in the history
* free memory for size zero storage handle

* skip adding nullptr into GPU memory pool

* set context for aux handle

* set context for aux handle once it is created
  • Loading branch information
yuxihu authored and eric-haibin-lin committed Mar 18, 2019
1 parent d671528 commit 3ab1dec
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
15 changes: 7 additions & 8 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -909,9 +909,6 @@ class NDArray {
// aux_handles always reflect the correct number of aux data
for (size_t i = 0; i < aux_shapes.size(); i++) {
CheckAndAllocAuxData(i, aux_shapes[i]);
// this line is needed in case when aux_shapes[i].Size() = 0
// aux_handles[i] will not be updated and take only default value.
aux_handles[i].ctx = ctx;
}
if (!delay_alloc) {
CheckAndAllocData(storage_shape, dtype);
Expand Down Expand Up @@ -986,8 +983,8 @@ class NDArray {
#endif
delay_alloc = false;
} else if (shandle.size < dbytes) {
// free storage if necessary and alloc again
if (shandle.size > 0) Storage::Get()->Free(shandle);
// free storage
Storage::Get()->Free(shandle);
// init storage
shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
#if MXNET_USE_MKLDNN == 1
Expand Down Expand Up @@ -1052,12 +1049,14 @@ class NDArray {
<< "storage type cannot be kDefaultStorage in CheckAndAllocAuxData";
if (aux_handles.size() <= i) {
aux_handles.resize(i + 1);
// set context for the newly created aux handle
aux_handles[i].ctx = ctx;
}
size_t aux_bytes = shape.Size() * mshadow::mshadow_sizeof(aux_types[i]);
if (aux_handles[i].size < aux_bytes) {
// free storage if necessary and alloc again
if (aux_handles[i].size > 0) Storage::Get()->Free(aux_handles[i]);
// init aux storage
// free storage
Storage::Get()->Free(aux_handles[i]);
// init storage
aux_handles[i] = Storage::Get()->Alloc(aux_bytes, ctx);
}
// init shape
Expand Down
8 changes: 4 additions & 4 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ NDArray::Chunk::~Chunk() {
CHECK_EQ(mem.mem->GetDataHandle(), mem.h.dptr);
}
#endif
if (mem.h.size > 0) Storage::Get()->Free(mem.h);
Storage::Get()->Free(mem.h);
for (const auto& aux : mem.aux_h) {
if (aux.size > 0) Storage::Get()->Free(aux);
Storage::Get()->Free(aux);
}
}
}, shandle.ctx, var);
Expand All @@ -134,8 +134,8 @@ void NDArray::Chunk::CheckAndAllocData(const mxnet::TShape &shape, int dtype) {
<< "data is expected to be allocated after aux_data";
auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype);
if (shandle.size < dbytes) {
// free storage if necessary and alloc again
if (shandle.size > 0) Storage::Get()->Free(shandle);
// free storage
Storage::Get()->Free(shandle);
// init storage
shandle = Storage::Get()->Alloc(dbytes, ctx);
#if MXNET_USE_MKLDNN == 1
Expand Down
8 changes: 8 additions & 0 deletions src/storage/pooled_storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ void GPUPooledStorageManager::Alloc(Storage::Handle* handle) {
}

void GPUPooledStorageManager::Free(Storage::Handle handle) {
// Do nothing if dptr is nullptr. Otherwise, nullptr may be reused
// which can cause illegal memory access error.
if (handle.dptr == nullptr) return;

std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
size_t size = RoundAllocSize(handle.size);
auto&& reuse_pool = memory_pool_[size];
Expand Down Expand Up @@ -312,6 +316,10 @@ void GPUPooledRoundedStorageManager::Alloc(Storage::Handle* handle) {
}

void GPUPooledRoundedStorageManager::Free(Storage::Handle handle) {
// Do nothing if dptr is nullptr. Otherwise, nullptr may be reused
// which can cause illegal memory access error.
if (handle.dptr == nullptr) return;

std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
int bucket = get_bucket(handle.size);
auto&& reuse_pool = memory_pool_[bucket];
Expand Down

0 comments on commit 3ab1dec

Please sign in to comment.