Skip to content

Commit

Permalink
[MEMORY] retry GPU memory allocation if fragmented (apache#16194)
Browse files Browse the repository at this point in the history
  • Loading branch information
szha authored and larroy committed Sep 28, 2019
1 parent 8147279 commit c068f3f
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions src/storage/pooled_storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,16 @@ void GPUPooledStorageManager::Alloc(Storage::Handle* handle) {

void* ret = nullptr;
cudaError_t e = cudaMalloc(&ret, size);
if (e != cudaSuccess && e != cudaErrorCudartUnloading) {
LOG(FATAL) << "cudaMalloc failed: " << cudaGetErrorString(e);
if (e != cudaSuccess) {
if (e == cudaErrorMemoryAllocation) {
ReleaseAll();
e = cudaMalloc(&ret, size);
if (e != cudaSuccess && e != cudaErrorCudartUnloading) {
LOG(FATAL) << "cudaMalloc retry failed: " << cudaGetErrorString(e);
}
} else if (e != cudaErrorCudartUnloading) {
LOG(FATAL) << "cudaMalloc failed: " << cudaGetErrorString(e);
}
}
used_memory_ += size;
handle->dptr = ret;
Expand Down Expand Up @@ -328,8 +336,16 @@ void GPUPooledRoundedStorageManager::Alloc(Storage::Handle* handle) {

void* ret = nullptr;
cudaError_t e = cudaMalloc(&ret, size);
if (e != cudaSuccess && e != cudaErrorCudartUnloading) {
LOG(FATAL) << "cudaMalloc failed: " << cudaGetErrorString(e);
if (e != cudaSuccess) {
if (e == cudaErrorMemoryAllocation) {
ReleaseAll();
e = cudaMalloc(&ret, size);
if (e != cudaSuccess && e != cudaErrorCudartUnloading) {
LOG(FATAL) << "cudaMalloc retry failed: " << cudaGetErrorString(e);
}
} else if (e != cudaErrorCudartUnloading) {
LOG(FATAL) << "cudaMalloc failed: " << cudaGetErrorString(e);
}
}
used_memory_ += size;
handle->dptr = ret;
Expand Down

0 comments on commit c068f3f

Please sign in to comment.