
CudnnFind() usage improvements #12804

Merged · 14 commits · Oct 26, 2018
Changes from 2 commits
14 changes: 7 additions & 7 deletions include/mxnet/base.h
@@ -225,11 +225,11 @@ struct Context {
   /*!
    * \brief get the free and total available memory on a GPU
    * \param dev the GPU number to query
-   * \param free_mem pointer to the integer holding free GPU memory
-   * \param total_mem pointer to the integer holding total GPU memory
+   * \param free_mem pointer to the size_t holding free GPU memory
+   * \param total_mem pointer to the size_t holding total GPU memory
    * \return No return value
    */
-  inline static void GetGPUMemoryInformation(int dev, int *free, int *total);
+  inline static void GetGPUMemoryInformation(int dev, size_t *free, size_t *total);
   /*!
    * Create a pinned CPU context.
    * \param dev_id the device id for corresponding GPU.
@@ -334,8 +334,8 @@ inline int32_t Context::GetGPUCount() {
 #endif
 }

-inline void Context::GetGPUMemoryInformation(int dev, int *free_mem,
-                                             int *total_mem) {
+inline void Context::GetGPUMemoryInformation(int dev, size_t *free_mem,
+                                             size_t *total_mem) {
 #if MXNET_USE_CUDA

   size_t memF, memT;
@@ -354,8 +354,8 @@ inline void Context::GetGPUMemoryInformation(int dev, int *free_mem,
   e = cudaSetDevice(curDevice);
   CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);

-  *free_mem = static_cast<int>(memF);
-  *total_mem = static_cast<int>(memT);
+  *free_mem = memF;
+  *total_mem = memT;

 #else
   LOG(FATAL)
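The point of the int → size_t switch above: a 32-bit int holds at most 2^31 - 1 bytes (about 2 GiB), so the old static_cast<int> of cudaMemGetInfo()'s size_t results silently truncated on any modern GPU. A quick sketch of the truncation (hypothetical 8 GiB card; Python arithmetic stands in for the 32-bit narrowing):

    total_mem = 8 * 1024**3          # 8 GiB card: 2**33 bytes
    low32 = total_mem & 0xFFFFFFFF   # what narrowing to a 32-bit int keeps
    print(total_mem, low32)          # 8589934592 0 -- the total 'vanishes'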
6 changes: 3 additions & 3 deletions include/mxnet/c_api.h
@@ -442,11 +442,11 @@ MXNET_DLL int MXGetGPUCount(int* out);
 /*!
  * \brief get the free and total available memory on a GPU
  * \param dev the GPU number to query
- * \param free_mem pointer to the integer holding free GPU memory
- * \param total_mem pointer to the integer holding total GPU memory
+ * \param free_mem pointer to the size_t holding free GPU memory
+ * \param total_mem pointer to the size_t holding total GPU memory
  * \return 0 when success, -1 when failure happens
  */
-MXNET_DLL int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem);
+MXNET_DLL int MXGetGPUMemoryInformation(int dev, size_t *free_mem, size_t *total_mem);

 /*!
  * \brief get the MXNet library version as an integer
24 changes: 24 additions & 0 deletions python/mxnet/context.py
@@ -258,6 +258,30 @@ def num_gpus():
     check_call(_LIB.MXGetGPUCount(ctypes.byref(count)))
     return count.value

+def gpu_memory_info(device_id=0):
+    """Query CUDA for the free and total bytes of GPU global memory.
+
+    Parameters
+    ----------
+    device_id : int, optional
+        The device id of the GPU device.
+
+    Raises
+    ------
+    Will raise an exception on any CUDA error.
+
+    Returns
+    -------
+    (free, total) : (int, int)

@blac2kite commented on Oct 24, 2018:

    Minor: 'total' - is it referring to total used, total available, or the total size of the physical GPU? Also, aren't they 64-bit integers? So maybe 'long' would be more appropriate. Since we are exposing this API in Python, it'd be a good idea to document it well.

@DickJC123 (Contributor, Author) replied:

    I prefer to leave this as is. Regarding int vs. long, I'm not a Python wizard, but ints are 'plain integers' and longs have unlimited precision:

        $ python
        Python 2.7.12 (default, Dec  4 2017, 14:50:18)
        [GCC 5.4.0 20160609] on linux2
        Type "help", "copyright", "credits" or "license" for more information.
        >>> import sys
        >>> x = sys.maxsize
        >>> x
        9223372036854775807
        >>> type(x)
        <type 'int'>
        >>> y = 2*x
        >>> y
        18446744073709551614L
        >>> type(y)
        <type 'long'>

    And unfortunately, there's no real short answer to what 'total' memory means. We're wrapping the CUDA call cudaMemGetInfo(), and the NVIDIA documentation says:

        Returns in *free and *total respectively, the free and total amount of memory available for allocation by the device in bytes.

    Let's say you've got a GPU with published memory T. The GPU driver puts some control structures, like the page table, in that memory; call that driver overhead D. Finally, your GPU may be driving a monitor, so a window manager is using the GPU with overhead W. So what does the API return for 'total' in this scenario? The answer is T - D. The long answer, then, is: 'total' means the total memory available to both your MXNet process and other processes that may be using the GPU. I don't know a way to suggest this succinctly without introducing more confusion.
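A concrete numeric sketch of that answer, with made-up overheads (T, D, and W as defined in the reply above are assumptions, not measured values):

    GiB = 1024**3
    MiB = 1024**2

    T = 8 * GiB       # published GPU memory (hypothetical card)
    D = 200 * MiB     # driver overhead: page tables, control structures (assumed)
    W = 300 * MiB     # window manager / display usage (assumed)

    total = T - D     # what cudaMemGetInfo() reports as 'total'
    free = T - D - W  # 'free', before this process allocates anything
    print(total, free)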

+        The free and total GPU memory in bytes.
+    """
+    free = ctypes.c_uint64()
+    total = ctypes.c_uint64()
+    dev_id = ctypes.c_int(device_id)
+    check_call(_LIB.MXGetGPUMemoryInformation(dev_id, ctypes.byref(free), ctypes.byref(total)))
+    return (free.value, total.value)

 def current_context():
     """Returns the current context.
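A quick usage sketch of the new Python call (output values are hypothetical; they depend on the GPU and on other processes using it):

    import mxnet as mx

    free, total = mx.context.gpu_memory_info(0)  # bytes, returned as (free, total)
    print("GPU 0: {:.2f} GiB free of {:.2f} GiB total".format(
        free / 1024**3, total / 1024**3))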
2 changes: 1 addition & 1 deletion src/c_api/c_api.cc
@@ -122,7 +122,7 @@ int MXGetGPUCount(int* out) {
   API_END();
 }

-int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem) {
+int MXGetGPUMemoryInformation(int dev, size_t *free_mem, size_t *total_mem) {
   API_BEGIN();
   Context::GetGPUMemoryInformation(dev, free_mem, total_mem);
   API_END();
48 changes: 46 additions & 2 deletions tests/python/gpu/test_gluon_gpu.py
@@ -25,12 +25,14 @@
 import mxnet as mx
 import numpy as np
 import unittest
+import math
 from nose.tools import assert_raises
 from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal
 from mxnet.base import MXNetError
 from mxnet import autograd
 from numpy.testing import assert_allclose

+
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
 from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled
@@ -57,7 +59,7 @@ def check_rnn_layer(layer):
     for g, c in zip(gs, cs):
         assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)

-
+@with_seed()
 def check_rnn_layer_w_rand_inputs(layer):
     layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)])
     x = mx.nd.uniform(shape=(10, 16, 30))
@@ -186,7 +188,7 @@ def _syncParameters(bn1, bn2, ctx):
         input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0)
         assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=1e-3, rtol=1e-3)

-
+@with_seed()
 def test_sync_batchnorm():
     def get_num_devices():
         for i in range(100):
@@ -203,6 +205,7 @@ def get_num_devices():
         _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)),
                                 num_devices=ndev, cuda=True)

+
 @with_seed()
 def test_symbol_block_fp16():
     # Test case to verify if initializing the SymbolBlock from a model with params

@@ -233,6 +236,47 @@ def test_symbol_block_fp16():
             break
     assert np.dtype(net_fp16.params[param_name].dtype) == np.dtype(np.float16)

+
+@with_seed()
+def test_large_models():
+    ctx = default_context()
+    # Create model
+    net = gluon.nn.HybridSequential()
+
+    largest_num_features = 256
+    with net.name_scope():
+        net.add(nn.Conv2D(128, 3))
+        net.add(nn.LeakyReLU(0.1))
+        net.add(nn.Conv2D(largest_num_features, 3))
+        net.add(nn.LeakyReLU(0.1))
+        net.add(nn.Conv2D(1, 3))
+
+    net.hybridize()
+    net.initialize(mx.init.Normal(sigma=0.01), ctx=ctx)
+    mx.nd.waitall()
DickJC123 marked this conversation as resolved.

+    # The idea is to create models with large tensors of (say) 20% of the total memory.
+    # This in the past has given cudnnFind() trouble when it needed to allocate similar I/O's
+    # from the area carved out by the MXNET_GPU_MEM_POOL_RESERVE setting (by default 5%).
+    def tensor_size(memory_fraction):
+        bytes_per_float = 4
+        (free_mem_bytes, total_mem_bytes) = mx.context.gpu_memory_info(ctx.device_id)
+        big_tensor_size = total_mem_bytes * memory_fraction
+        sz = int(math.sqrt(big_tensor_size / largest_num_features / bytes_per_float))
+        return (sz // 100) * 100
+
+    start_size = tensor_size(0.20)
+    num_trials = 4
+    for i in range(num_trials):
+        sz = start_size - 10 * i

@KellenSunderland (Contributor) commented on Oct 23, 2018:

    Edit: ok, I see what's going on here. Sorry for the confusion.

    Is it intended that we test sz at 1200x1200, 1190x1190, 1180x1180, 1170x1170, 1160x1160, 1150x1150, 1140x1140, 1130x1130, 1120x1120, and 1110x1110? Is there a good reason to test these values?

@DickJC123 (Contributor, Author) replied:

    We don't want to test the same size, since that would be a hit in the algo cache and cudnnFind() would not be re-run. The idea is to test different sizes that are roughly the size chosen initially. I observed that with the progression of these models, the MXNet storage manager would slowly allocate up to 95% of memory by default (although much of it ends up in its free store). At that point, if MXNet's storage manager is asked for a big allocation that doesn't match the size of something in the free store, it triggers a beneficial ReleaseAll():

        if (free <= total * reserve_ / 100 || size > free - total * reserve_ / 100)
          ReleaseAll();

    However, if cudnnFind() is called at that point instead, the large allocations it wants to make for I/Os and workspace cause an out-of-memory or 'no algo found' error, despite the fact that unused allocations are sitting in the free store. I'm using the existing Alloc/DirectFree API to possibly trigger a ReleaseAll() prior to calling cuDNN within the algo-setting callback.
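A small sketch of that release condition with hypothetical numbers (reserve_ corresponds to MXNET_GPU_MEM_POOL_RESERVE, default 5; this mirrors the quoted logic, it is not MXNet's actual pool code):

    def should_release_all(free, total, reserve, size):
        # Dump the pool's cached blocks when free memory dips into the reserved
        # slice, or when the request can't be met without touching it.
        return free <= total * reserve // 100 or size > free - total * reserve // 100

    GiB = 1024**3
    MiB = 1024**2
    # 8 GiB card, 5% reserve: a 600 MiB request with only 512 MiB free -> release.
    print(should_release_all(free=512 * MiB, total=8 * GiB, reserve=5, size=600 * MiB))  # True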

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And to get back to your question more succinctly, there's nothing magic about the initial 1200x1200 starting point. It was calculated based on the 8GB global memory of the CI machines. I verified that the test initially failed, then passed, on 16GB and 32GB GPUs (with different sizes in each case).


@KellenSunderland (Contributor) replied:

    Ah, that makes total sense: you're just using the dimensions to make sure we're not pulling the algo straight from the registry and bypassing the auto-tuning. Gotcha, thanks Dick.
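The arithmetic behind that starting point can be checked directly; a sketch assuming the 8GB CI card (taken here as 8 GiB):

    import math

    total_mem_bytes = 8 * 1024**3   # assumed CI GPU memory
    largest_num_features = 256      # widest Conv2D in the test model
    bytes_per_float = 4

    big_tensor_size = total_mem_bytes * 0.20
    sz = int(math.sqrt(big_tensor_size / largest_num_features / bytes_per_float))
    print((sz // 100) * 100)        # -> 1200, matching the 1200x1200 start above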

+        (height, width) = (sz, sz)
+        print("Testing model with input = {}x{}".format(height, width))
+        data_in = nd.random_uniform(low=0, high=255, shape=(1, 3, height, width),
+                                    ctx=ctx, dtype="float32")
+        # Evaluate model
+        net(data_in).asnumpy()
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()