Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Allow clearing gpu cache #14252

Merged
merged 9 commits into from
May 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2736,6 +2736,12 @@ MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid,
MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape,
mx_uint ndim, int dtype, NDArrayHandle *out);

/*!
 * \brief Release all unreferenced memory from the memory pool of the device's storage manager.
 *
 * Pooled storage managers cache freed device memory for reuse; this call drops
 * every cached block that is not referenced by a live NDArray. For non-pooled
 * storage managers this is a no-op.
 * \param dev_type device type (e.g. CPU/GPU enum value) identifying which device class to empty
 * \param dev_id the device id of the specific device
 */
MXNET_DLL int MXStorageEmptyCache(int dev_type, int dev_id);

/*!
* \brief Reconstruct NDArray from shared memory handle
Expand Down
8 changes: 8 additions & 0 deletions include/mxnet/storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ class Storage {
* \param handle Handle struct.
*/
virtual void DirectFree(Handle handle) = 0;
/*!
 * \brief Release all memory for a device if using a pooled storage manager.
 *
 * This releases all cached memory held by pooled storage managers such as
 * GPUPooledStorageManager and GPUPooledRoundedStorageManager.
 * For non-pooled memory managers this has no effect.
 * \param ctx the context (device) whose storage manager should release its pool
 */
virtual void ReleaseAll(Context ctx) = 0;
/*!
* \brief Destructor.
*/
Expand Down
18 changes: 18 additions & 0 deletions python/mxnet/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,24 @@ def default_ctx(cls, val):
cls._default_ctx.value = val
#pylint: enable=no-self-argument

def empty_cache(self):
    """Empties the memory cache for the current context's device.

    MXNet utilizes a memory pool to avoid excessive allocations.
    Calling empty_cache will empty the memory pool of the context's
    device. This will only free the memory of unreferenced data.

    Examples
    --------
    >>> ctx = mx.gpu(0)
    >>> arr = mx.nd.ones((200,200), ctx=ctx)
    >>> del arr
    >>> ctx.empty_cache() # forces release of memory allocated for arr
    """
    # Identify this context's device for the C API and forward the call.
    dev_type = ctypes.c_int(self.device_typeid)
    dev_id = ctypes.c_int(self.device_id)
    check_call(_LIB.MXStorageEmptyCache(dev_type, dev_id))

# initialize the default context in Context
Context._default_ctx.value = Context('cpu', 0)

Expand Down
7 changes: 7 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1526,3 +1526,10 @@ int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,

API_END();
}

int MXStorageEmptyCache(int dev_type, int dev_id) {
API_BEGIN();
Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
Storage::Get()->ReleaseAll(ctx);
API_END();
}
6 changes: 4 additions & 2 deletions src/storage/pooled_storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class GPUPooledStorageManager final : public StorageManager {
DirectFreeNoLock(handle);
}

void ReleaseAll() override;

private:
void DirectFreeNoLock(Storage::Handle handle) {
mxnet::common::cuda::DeviceStore device_store(handle.ctx.real_dev_id(), true);
Expand Down Expand Up @@ -115,7 +117,6 @@ class GPUPooledStorageManager final : public StorageManager {
}

private:
void ReleaseAll();
// used memory
size_t used_memory_ = 0;
// page size
Expand Down Expand Up @@ -250,6 +251,8 @@ class GPUPooledRoundedStorageManager final : public StorageManager {
DirectFreeNoLock(handle);
}

void ReleaseAll() override;

private:
inline int div_pow2_round_up(size_t s, int divisor_log2) {
// (1025, 10) -> 2
Expand Down Expand Up @@ -284,7 +287,6 @@ class GPUPooledRoundedStorageManager final : public StorageManager {
}

private:
void ReleaseAll();
// number of devices
const int NDEV = 32;
// log2 of maximum page size. 16GB
Expand Down
12 changes: 12 additions & 0 deletions src/storage/storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "./pooled_storage_manager.h"
#include "./cpu_shared_storage_manager.h"
#include "./cpu_device_storage.h"
#include "./gpu_device_storage.h"
#include "./pinned_memory_storage.h"
#include "../common/lazy_alloc_array.h"
#include "../profiler/storage_profiler.h"
Expand All @@ -38,6 +39,7 @@ class StorageImpl : public Storage {
void Alloc(Handle* handle) override;
void Free(Handle handle) override;
void DirectFree(Handle handle) override;
void ReleaseAll(Context ctx) override;
void SharedIncrementRefCount(Handle handle) override;
StorageImpl() {}
virtual ~StorageImpl() = default;
Expand Down Expand Up @@ -160,6 +162,16 @@ void StorageImpl::DirectFree(Storage::Handle handle) {
profiler_.OnFree(handle);
}

void StorageImpl::ReleaseAll(Context ctx) {
  // Look up the storage manager for this device. The failure callback fires
  // only when no allocation was ever made on the device, so there is no pool
  // to release. The original message ("Cannot Free space to a device you have
  // not allocated") was copy-pasted from the Free path and misdescribed this
  // operation; it is corrected here.
  auto&& device = storage_managers_.at(ctx.dev_type);
  std::shared_ptr<storage::StorageManager> manager = device.Get(
      ctx.real_dev_id(), []() {
        LOG(FATAL) << "Cannot release memory for a device that has not been allocated on";
        return nullptr;
      });
  // Pooled managers drop every cached block; non-pooled managers no-op.
  manager->ReleaseAll();
}

void StorageImpl::SharedIncrementRefCount(Storage::Handle handle) {
CHECK_EQ(handle.ctx.dev_type, Context::kCPUShared);
auto&& device = storage_managers_.at(Context::kCPUShared);
Expand Down
8 changes: 8 additions & 0 deletions src/storage/storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ class StorageManager {
* \param handle Handle struct.
*/
virtual void DirectFree(Storage::Handle handle) = 0;
/*!
 * \brief Release all cached memory if this is a pooled storage manager.
 *
 * This releases all pooled memory in storage managers such as
 * GPUPooledStorageManager and GPUPooledRoundedStorageManager.
 * The default implementation is a deliberate no-op so that
 * non-pooled storage managers need not override it.
 */
virtual void ReleaseAll() {}
/*!
* \brief Destructor.
*/
Expand Down