diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 38a2733001f3..c549f6f59940 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1998,3 +1998,20 @@ def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='defa if compare_states: compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol) assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol) + +class EnvManager: + def __init__(self, key, val): + self._key = key + self._next_val = val + self._prev_val = None + + def __enter__(self): + self._prev_val = os.environ.get(self._key) + os.environ[self._key] = self._next_val + + def __exit__(self, ptype, value, trace): + if self._prev_val: + os.environ[self._key] = self._prev_val + else: + del os.environ[self._key] + diff --git a/tests/python/gpu/test_device.py b/tests/python/gpu/test_device.py index 66772dc86c21..f11bb8cd883f 100644 --- a/tests/python/gpu/test_device.py +++ b/tests/python/gpu/test_device.py @@ -19,35 +19,23 @@ import numpy as np import unittest import os +import logging + +from mxnet.test_utils import EnvManager shapes = [(10), (100), (1000), (10000), (100000), (2,2), (2,3,4,5,6,7,8)] keys = [1,2,3,4,5,6,7] -num_gpus = len(mx.test_utils.list_gpus()) +num_gpus = mx.test_utils.list_gpus() if num_gpus > 8 : - print("The machine has {} gpus. We will run the test on 8 gpus.".format(num_gpus)) - print("There is a limit for all PCI-E hardware on creating number of P2P peers. The limit is 8.") + logging.warn("The machine has {} gpus. We will run the test on 8 gpus.".format(num_gpus)) + logging.warn("There is a limit for all PCI-E hardware on creating number of P2P peers. The limit is 8.") num_gpus = 8; gpus = range(1, 1+num_gpus) -class EnvManager: - def __init__(self, key, val): - self._key = key - self._next_val = val - self._prev_val = None - - def __enter__(self): - try: - self._prev_val = os.environ[self._key] - except KeyError: - self._prev_val = '' - os.environ[self._key] = self._next_val - - def __exit__(self, ptype, value, trace): - os.environ[self._key] = self._prev_val - +@unittest.skipIf(mx.context.num_gpus() < 1, "test_device_pushpull needs at least 1 GPU") def test_device_pushpull(): def check_dense_pushpull(kv_type): for shape, key in zip(shapes, keys): @@ -63,20 +51,16 @@ def check_dense_pushpull(kv_type): for x in range(n_gpus): assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0) - envs1 = '1' - key1 = 'MXNET_KVSTORE_TREE_ARRAY_BOUND' - envs2 = ['','1'] - key2 = 'MXNET_KVSTORE_USETREE' - for i in range(2): - for val2 in envs2: - with EnvManager(key2, val2): + kvstore_tree_array_bound = 'MXNET_KVSTORE_TREE_ARRAY_BOUND' + kvstore_usetree_values = ['','1'] + kvstore_usetree = 'MXNET_KVSTORE_USETREE' + for _ in range(2): + for x in kvstore_usetree_values: + with EnvManager(kvstore_usetree, x): check_dense_pushpull('local') check_dense_pushpull('device') - - os.environ[key1] = envs1 - os.environ[key1] = '' - - print ("Passed") + os.environ[kvstore_tree_array_bound] = '1' + del os.environ[kvstore_tree_array_bound] if __name__ == '__main__': test_device_pushpull() diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py index 4232a590a5df..8ff8752f534a 100644 --- a/tests/python/gpu/test_kvstore_gpu.py +++ b/tests/python/gpu/test_kvstore_gpu.py @@ -21,7 +21,7 @@ import mxnet as mx import numpy as np import unittest -from mxnet.test_utils import assert_almost_equal, default_context +from mxnet.test_utils import assert_almost_equal, default_context, EnvManager curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) from common import setup_module, with_seed, teardown @@ -30,22 +30,6 @@ keys = [5, 7, 11] str_keys = ['b', 'c', 'd'] -class EnvManager: - def __init__(self, key, val): - self._key = key - self._next_val = val - self._prev_val = None - - def __enter__(self): - try: - self._prev_val = os.environ[self._key] - except KeyError: - self._prev_val = '' - os.environ[self._key] = self._next_val - - def __exit__(self, ptype, value, trace): - os.environ[self._key] = self._prev_val - def init_kv_with_str(stype='default', kv_type='local'): """init kv """ kv = mx.kv.create(kv_type)