diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h index 91eb536ec7bd..6e54ddd7e52a 100644 --- a/src/storage/pooled_storage_manager.h +++ b/src/storage/pooled_storage_manager.h @@ -153,8 +153,16 @@ void GPUPooledStorageManager::Alloc(Storage::Handle* handle) { void* ret = nullptr; cudaError_t e = cudaMalloc(&ret, size); - if (e != cudaSuccess && e != cudaErrorCudartUnloading) { - LOG(FATAL) << "cudaMalloc failed: " << cudaGetErrorString(e); + if (e != cudaSuccess) { + if (e == cudaErrorMemoryAllocation) { + ReleaseAll(); + e = cudaMalloc(&ret, size); + if (e != cudaSuccess && e != cudaErrorCudartUnloading) { + LOG(FATAL) << "cudaMalloc retry failed: " << cudaGetErrorString(e); + } + } else if (e != cudaErrorCudartUnloading) { + LOG(FATAL) << "cudaMalloc failed: " << cudaGetErrorString(e); + } } used_memory_ += size; handle->dptr = ret; @@ -328,8 +336,16 @@ void GPUPooledRoundedStorageManager::Alloc(Storage::Handle* handle) { void* ret = nullptr; cudaError_t e = cudaMalloc(&ret, size); - if (e != cudaSuccess && e != cudaErrorCudartUnloading) { - LOG(FATAL) << "cudaMalloc failed: " << cudaGetErrorString(e); + if (e != cudaSuccess) { + if (e == cudaErrorMemoryAllocation) { + ReleaseAll(); + e = cudaMalloc(&ret, size); + if (e != cudaSuccess && e != cudaErrorCudartUnloading) { + LOG(FATAL) << "cudaMalloc retry failed: " << cudaGetErrorString(e); + } + } else if (e != cudaErrorCudartUnloading) { + LOG(FATAL) << "cudaMalloc failed: " << cudaGetErrorString(e); + } } used_memory_ += size; handle->dptr = ret;