Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add unpooled gpu memory type (#14716)
Browse files Browse the repository at this point in the history
* Add unpooled gpu memory type

* Adding missing header

* undo bad rebase change
  • Loading branch information
vladoovtcharov authored and szha committed Apr 30, 2019
1 parent 3a46980 commit 07aef13
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
6 changes: 5 additions & 1 deletion docs/faq/env_var.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,20 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
* MXNET_GPU_MEM_POOL_RESERVE
- Values: Int ```(default=5)```
- The percentage of GPU memory to reserve for things other than the GPU array, such as kernel launch or cudnn handle space.
- If you see a strange out-of-memory error from the kernel launch, after multiple iterations, try setting this to a larger value.
- If you see a strange out-of-memory error from the kernel launch, after multiple iterations, try setting this to a larger value.

* MXNET_GPU_MEM_POOL_TYPE
- Values: String ```(default=Naive)```
- The type of memory pool.
- Choices:
- Naive: A simple memory pool that allocates memory for the exact requested size and cache memory buffers. If a buffered memory chunk matches the size of a new request, the chunk from the memory pool will be returned and reused.
- Round: A memory pool that always rounds the requested memory size and allocates memory of the rounded size. MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF defines how to round up a memory size. Caching and allocating buffered memory works in the same way as the naive memory pool.
- Unpooled: No memory pool is used.

* MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF
- Values: Int ```(default=24)```
- The cutoff threshold that decides the rounding strategy. Let's denote the threshold as T. If the memory size is smaller than `2 ** T` (by default, it's 2 ** 24 = 16MB), it rounds to the smallest `2 ** n` that is larger than the requested memory size; if the memory size is larger than `2 ** T`, it rounds to the next k * 2 ** T.

* MXNET_GPU_MEM_LARGE_ALLOC_ROUND_SIZE
- Values: Int ```(default=2097152)```
- When using the naive pool type, memory allocations larger than this threshhold are rounded up to a multiple of this value.
Expand Down
10 changes: 6 additions & 4 deletions src/storage/storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "./pooled_storage_manager.h"
#include "./cpu_shared_storage_manager.h"
#include "./cpu_device_storage.h"
#include "./gpu_device_storage.h"
#include "./pinned_memory_storage.h"
#include "../common/lazy_alloc_array.h"
#include "../profiler/storage_profiler.h"
Expand Down Expand Up @@ -106,11 +107,12 @@ void StorageImpl::Alloc(Storage::Handle* handle) {
if (strategy == "Round") {
ptr = new storage::GPUPooledRoundedStorageManager(handle->ctx);
LOG(INFO) << "Using GPUPooledRoundedStorageManager.";
} else {
if (strategy != "Naive") {
LOG(FATAL) << "Unknown memory pool strategy specified: " << strategy << ".";
}
} else if (strategy == "Naive") {
ptr = new storage::GPUPooledStorageManager(handle->ctx);
} else if (strategy == "Unpooled") {
ptr = new storage::NaiveStorageManager<storage::GPUDeviceStorage>();
} else {
LOG(FATAL) << "Unknown memory pool strategy specified: " << strategy << ".";
}
#else
LOG(FATAL) << "Compile with USE_CUDA=1 to enable GPU usage";
Expand Down

0 comments on commit 07aef13

Please sign in to comment.