This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add mx.context.gpu_memory_info() to python api for flexible tests.
DickJC123 committed Oct 10, 2018
1 parent 822e59f commit 7ad40a2
Showing 4 changed files with 35 additions and 11 deletions.
14 changes: 7 additions & 7 deletions include/mxnet/base.h
@@ -225,11 +225,11 @@ struct Context {
/*!
* \brief get the free and total available memory on a GPU
* \param dev the GPU number to query
* \param free_mem pointer to the integer holding free GPU memory
* \param total_mem pointer to the integer holding total GPU memory
* \param free_mem pointer to the size_t holding free GPU memory
* \param total_mem pointer to the size_t holding total GPU memory
* \return No return value
*/
inline static void GetGPUMemoryInformation(int dev, int *free, int *total);
inline static void GetGPUMemoryInformation(int dev, size_t *free, size_t *total);
/*!
* Create a pinned CPU context.
* \param dev_id the device id for corresponding GPU.
@@ -334,8 +334,8 @@ inline int32_t Context::GetGPUCount() {
#endif
}

inline void Context::GetGPUMemoryInformation(int dev, int *free_mem,
int *total_mem) {
inline void Context::GetGPUMemoryInformation(int dev, size_t *free_mem,
size_t *total_mem) {
#if MXNET_USE_CUDA

size_t memF, memT;
@@ -354,8 +354,8 @@ inline void Context::GetGPUMemoryInformation(int dev, int *free_mem,
e = cudaSetDevice(curDevice);
CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);

*free_mem = static_cast<int>(memF);
*total_mem = static_cast<int>(memT);
*free_mem = memF;
*total_mem = memT;

#else
LOG(FATAL)
6 changes: 3 additions & 3 deletions include/mxnet/c_api.h
@@ -442,11 +442,11 @@ MXNET_DLL int MXGetGPUCount(int* out);
/*!
* \brief get the free and total available memory on a GPU
* \param dev the GPU number to query
* \param free_mem pointer to the integer holding free GPU memory
* \param total_mem pointer to the integer holding total GPU memory
* \param free_mem pointer to the size_t holding free GPU memory
* \param total_mem pointer to the size_t holding total GPU memory
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem);
MXNET_DLL int MXGetGPUMemoryInformation(int dev, size_t *free_mem, size_t *total_mem);

/*!
* \brief get the MXNet library version as an integer
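
The move from int * to size_t * in this C API matters because free and total memory on current GPUs can exceed what a signed 32-bit int can hold (anything above 2 GiB), so the narrower type could truncate or overflow. Below is a minimal ctypes sketch of calling the updated entry point directly; the lib handle and library path are illustrative stand-ins for how mxnet.base._LIB loads libmxnet, not part of this change.

import ctypes

# Illustrative handle; the real Python package loads this as mxnet.base._LIB.
lib = ctypes.CDLL("libmxnet.so")

def query_gpu_memory(dev_id=0):
    """Call MXGetGPUMemoryInformation with size_t-sized output arguments."""
    free_mem = ctypes.c_size_t()
    total_mem = ctypes.c_size_t()
    ret = lib.MXGetGPUMemoryInformation(ctypes.c_int(dev_id),
                                        ctypes.byref(free_mem),
                                        ctypes.byref(total_mem))
    if ret != 0:
        raise RuntimeError("MXGetGPUMemoryInformation returned %d" % ret)
    return free_mem.value, total_mem.value
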
24 changes: 24 additions & 0 deletions python/mxnet/context.py
@@ -258,6 +258,30 @@ def num_gpus():
check_call(_LIB.MXGetGPUCount(ctypes.byref(count)))
return count.value

def gpu_memory_info(device_id=0):
    """Query CUDA for the free and total bytes of GPU global memory.

    Parameters
    ----------
    device_id : int, optional
        The device id of the GPU device.

    Raises
    ------
    Will raise an exception on any CUDA error.

    Returns
    -------
    (free, total) : (int, int)
        The number of free and total bytes of GPU global memory.
    """
    free = ctypes.c_uint64()
    total = ctypes.c_uint64()
    dev_id = ctypes.c_int(device_id)
    check_call(_LIB.MXGetGPUMemoryInformation(dev_id, ctypes.byref(free), ctypes.byref(total)))
    return (free.value, total.value)

def current_context():
"""Returns the current context.
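
A short usage sketch of the new helper, in the spirit of the "flexible tests" mentioned in the commit message; the 4 GiB requirement below is a hypothetical threshold, and the call needs a CUDA-enabled build with at least one visible GPU.

import mxnet as mx

free_bytes, total_bytes = mx.context.gpu_memory_info(device_id=0)
print("GPU 0: %.1f GiB free of %.1f GiB total"
      % (free_bytes / 1024**3, total_bytes / 1024**3))

# Skip a memory-hungry test case when the device cannot accommodate it.
required = 4 * 1024**3  # hypothetical 4 GiB working-set size
if free_bytes < required:
    print("Skipping test: not enough free GPU memory")
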
2 changes: 1 addition & 1 deletion src/c_api/c_api.cc
@@ -122,7 +122,7 @@ int MXGetGPUCount(int* out) {
API_END();
}

int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem) {
int MXGetGPUMemoryInformation(int dev, size_t *free_mem, size_t *total_mem) {
API_BEGIN();
Context::GetGPUMemoryInformation(dev, free_mem, total_mem);
API_END();
