Add test_gluon_gpu.py:test_large_models to show cudnnFind headroom issue.
DickJC123 committed Oct 11, 2018
1 parent 7ad40a2 commit 94614a5
Showing 1 changed file with 46 additions and 2 deletions.
tests/python/gpu/test_gluon_gpu.py: 46 additions, 2 deletions
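
One way to invoke just the new test, as a sketch (this assumes nose and a CUDA-capable GPU are available; the "file.py:function" selector is standard nose syntax):

    # Probes cudnnFind() headroom against the default MXNET_GPU_MEM_POOL_RESERVE (5%).
    import nose
    nose.run(argv=["nosetests", "-v", "tests/python/gpu/test_gluon_gpu.py:test_large_models"])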
@@ -25,12 +25,14 @@
import mxnet as mx
import numpy as np
import unittest
import math
from nose.tools import assert_raises
from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal
from mxnet.base import MXNetError
from mxnet import autograd
from numpy.testing import assert_allclose


curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled
@@ -57,7 +59,7 @@ def check_rnn_layer(layer):
    for g, c in zip(gs, cs):
        assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)


@with_seed()
def check_rnn_layer_w_rand_inputs(layer):
    layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)])
    x = mx.nd.uniform(shape=(10, 16, 30))
@@ -186,7 +188,7 @@ def _syncParameters(bn1, bn2, ctx):
    input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0)
    assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=1e-3, rtol=1e-3)


@with_seed()
def test_sync_batchnorm():
    def get_num_devices():
        for i in range(100):
@@ -203,6 +205,7 @@ def get_num_devices():
        _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)),
                                num_devices=ndev, cuda=True)


@with_seed()
def test_symbol_block_fp16():
    # Test case to verify if initializing the SymbolBlock from a model with params
@@ -233,6 +236,47 @@ def test_symbol_block_fp16():
            break
    assert np.dtype(net_fp16.params[param_name].dtype) == np.dtype(np.float16)


@with_seed()
def test_large_models():
    ctx = default_context()
    # Create model
    net = gluon.nn.HybridSequential()

    largest_num_features = 256
    with net.name_scope():
        net.add(nn.Conv2D(128, 3))
        net.add(nn.LeakyReLU(0.1))
        net.add(nn.Conv2D(largest_num_features, 3))
        net.add(nn.LeakyReLU(0.1))
        net.add(nn.Conv2D(1, 3))

    net.hybridize()
    net.initialize(mx.init.Normal(sigma=0.01), ctx=ctx)
    mx.nd.waitall()

    # The idea is to create models with large tensors of (say) 20% of the total memory.
    # This in the past has given cudnnFind() trouble when it needed to allocate similar I/Os
    # from the area carved out by the MXNET_GPU_MEM_POOL_RESERVE setting (by default 5%).
    def tensor_size(memory_fraction):
        bytes_per_float = 4
        (free_mem_bytes, total_mem_bytes) = mx.context.gpu_memory_info(ctx.device_id)

[Review comment] marcoabreu (Contributor), Oct 12, 2018:
Love this dynamic size! Could we maybe print the used values here for reproducibility?

[Review reply] DickJC123 (Author), via email, Oct 12, 2018.
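        # A sketch of what the review suggestion above might look like (an editor
        # assumption, not part of this commit): log the queried memory values so a
        # given run can be reproduced.
        # print("gpu_memory_info({}): free = {} bytes, total = {} bytes".format(
        #     ctx.device_id, free_mem_bytes, total_mem_bytes))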
        big_tensor_size = total_mem_bytes * memory_fraction
        sz = int(math.sqrt(big_tensor_size / largest_num_features / bytes_per_float))
        return (sz // 100) * 100
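
    # Worked example of tensor_size (assuming a hypothetical 16 GB GPU and
    # memory_fraction=0.20): big_tensor_size = 16e9 * 0.20 = 3.2e9 bytes,
    # sz = int(sqrt(3.2e9 / 256 / 4)) = int(sqrt(3125000)) = 1767,
    # which rounds down to a multiple of 100 and returns 1700.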

    start_size = tensor_size(0.20)
    num_trials = 4
    for i in range(num_trials):
        sz = start_size - 10 * i
        (height, width) = (sz, sz)
        print("Testing model with input = {}x{}".format(height, width))
        data_in = nd.random_uniform(low=0, high=255, shape=(1, 3, height, width),
                                    ctx=ctx, dtype="float32")
        # Evaluate model
        net(data_in).asnumpy()


if __name__ == '__main__':
import nose
nose.runmodule()
