From cd1688f2039396ad7d9eef6b4df441408a47a584 Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Thu, 23 Jul 2020 11:17:10 -0700 Subject: [PATCH] Improve environment variable handling in unittests (#18424) --- include/mxnet/c_api_test.h | 16 ++ python/mxnet/test_utils.py | 113 ++++++++---- python/mxnet/util.py | 34 +++- src/c_api/c_api_test.cc | 22 +++ tests/python/gpu/test_device.py | 16 +- tests/python/gpu/test_fusion.py | 13 +- tests/python/gpu/test_gluon_gpu.py | 18 +- tests/python/gpu/test_kvstore_gpu.py | 10 +- tests/python/gpu/test_operator_gpu.py | 42 ++--- tests/python/unittest/common.py | 33 ++-- tests/python/unittest/test_autograd.py | 6 +- tests/python/unittest/test_base.py | 100 ++++++++--- tests/python/unittest/test_engine.py | 4 +- tests/python/unittest/test_engine_import.py | 14 +- tests/python/unittest/test_executor.py | 66 +++---- tests/python/unittest/test_gluon.py | 23 +-- tests/python/unittest/test_operator.py | 184 +++++++++----------- tests/python/unittest/test_subgraph_op.py | 28 ++- tests/python/unittest/test_symbol.py | 68 +++----- 19 files changed, 475 insertions(+), 335 deletions(-) diff --git a/include/mxnet/c_api_test.h b/include/mxnet/c_api_test.h index b7ba0cef04a3..df7079842657 100644 --- a/include/mxnet/c_api_test.h +++ b/include/mxnet/c_api_test.h @@ -75,6 +75,22 @@ MXNET_DLL int MXRemoveSubgraphPropertyOpNames(const char* prop_name); MXNET_DLL int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name); +/*! + * \brief Get the value of an environment variable as seen by the backend. + * \param name The name of the environment variable + * \param value The returned value of the environment variable + */ +MXNET_DLL int MXGetEnv(const char* name, + const char** value); + +/*! + * \brief Set the value of an environment variable from the backend. + * \param name The name of the environment variable + * \param value The desired value to set the environment variable `name` + */ +MXNET_DLL int MXSetEnv(const char* name, + const char* value); + #ifdef __cplusplus } #endif // __cplusplus diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 3e068604954f..927d85788b39 100755 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -24,6 +24,7 @@ import numbers import sys import os +import platform import errno import logging import bz2 @@ -49,7 +50,7 @@ from .ndarray import array from .symbol import Symbol from .symbol.numpy import _Symbol as np_symbol -from .util import use_np # pylint: disable=unused-import +from .util import use_np, getenv, setenv # pylint: disable=unused-import from .runtime import Features from .numpy_extension import get_cuda_compute_capability @@ -2035,27 +2036,6 @@ def get_bz2_data(data_dir, data_name, url, data_origin_name): bz_file.close() os.remove(data_origin_name) -def set_env_var(key, val, default_val=""): - """Set environment variable - - Parameters - ---------- - - key : str - Env var to set - val : str - New value assigned to the env var - default_val : str, optional - Default value returned if the env var doesn't exist - - Returns - ------- - str - The value of env var before it is set to the new value - """ - prev_val = os.environ.get(key, default_val) - os.environ[key] = val - return prev_val def same_array(array1, array2): """Check whether two NDArrays sharing the same memory block @@ -2080,9 +2060,11 @@ def same_array(array1, array2): array1[:] -= 1 return same(array1.asnumpy(), array2.asnumpy()) + @contextmanager def discard_stderr(): - """Discards error output of a routine if invoked as: + """ + Discards error output of a routine if invoked as: with discard_stderr(): ... @@ -2471,22 +2453,79 @@ def same_symbol_structure(sym1, sym2): return True -class EnvManager(object): - """Environment variable setter and unsetter via with idiom""" - def __init__(self, key, val): - self._key = key - self._next_val = val - self._prev_val = None +@contextmanager +def environment(*args): + """ + Environment variable setter and unsetter via `with` idiom. - def __enter__(self): - self._prev_val = os.environ.get(self._key) - os.environ[self._key] = self._next_val + Takes a specification of env var names and desired values and adds those + settings to the environment in advance of running the body of the `with` + statement. The original environment state is restored afterwards, even + if exceptions are raised in the `with` body. - def __exit__(self, ptype, value, trace): - if self._prev_val: - os.environ[self._key] = self._prev_val - else: - del os.environ[self._key] + Parameters + ---------- + args: + if 2 args are passed: + name, desired_value strings of the single env var to update, or + if 1 arg is passed: + a dict of name:desired_value for env var's to update + + """ + + # On Linux, env var changes made through python's os.environ are seen + # by the backend. On Windows though, the C runtime gets a snapshot + # of the environment that cannot be altered by os.environ. Here we + # check, using a wrapped version of the backend's getenv(), that + # the desired env var value is seen by the backend, and otherwise use + # a wrapped setenv() to establish that value in the backend. + + # Also on Windows, a set env var can never have the value '', since + # the command 'set FOO= ' is used to unset the variable. Perhaps + # as a result, the wrapped dmlc::GetEnv() routine returns the same + # value for unset variables and those set to ''. As a result, we + # ignore discrepancy. + def validate_backend_setting(name, value, can_use_setenv=True): + backend_value = getenv(name) + if value == backend_value or \ + value == '' and backend_value is None and platform.system() == 'Windows': + return + if not can_use_setenv: + raise RuntimeError('Could not set env var {}={} within C Runtime'.format(name, value)) + setenv(name, value) + validate_backend_setting(name, value, can_use_setenv=False) + + # Core routine to alter environment from a dict of env_var_name, env_var_value pairs + def set_environ(env_var_dict): + for env_var_name, env_var_value in env_var_dict.items(): + if env_var_value is None: + os.environ.pop(env_var_name, None) + else: + os.environ[env_var_name] = env_var_value + validate_backend_setting(env_var_name, env_var_value) + + # Create env_var name:value dict from the two calling methods of this routine + if len(args) == 1 and isinstance(args[0], dict): + env_vars = args[0] + else: + assert len(args) == 2, 'Expecting one dict arg or two args: env var name and value' + env_vars = {args[0]: args[1]} + + # Take a snapshot of the existing environment variable state + # for those variables to be changed. get() return None for unset keys. + snapshot = {x: os.environ.get(x) for x in env_vars.keys()} + + # Alter the environment per the env_vars dict + set_environ(env_vars) + + # Now run the wrapped code + try: + yield + finally: + # the backend engines may still be referencing the changed env var state + mx.nd.waitall() + # reinstate original env_var state per the snapshot taken earlier + set_environ(snapshot) def collapse_sum_like(a, shape): diff --git a/python/mxnet/util.py b/python/mxnet/util.py index 54beeb5a875a..aabd5fe9cdfe 100644 --- a/python/mxnet/util.py +++ b/python/mxnet/util.py @@ -21,7 +21,7 @@ import inspect import threading -from .base import _LIB, check_call +from .base import _LIB, check_call, c_str, py_str _np_ufunc_default_kwargs = { @@ -816,3 +816,35 @@ def get_cuda_compute_capability(ctx): raise RuntimeError('cuDeviceComputeCapability failed with error code {}: {}' .format(ret, error_str.value.decode())) return cc_major.value * 10 + cc_minor.value + + +def getenv(name): + """Get the setting of an environment variable from the C Runtime. + + Parameters + ---------- + name : string type + The environment variable name + + Returns + ------- + value : string + The value of the environment variable, or None if not set + """ + ret = ctypes.c_char_p() + check_call(_LIB.MXGetEnv(c_str(name), ctypes.byref(ret))) + return None if ret.value is None else py_str(ret.value) + + +def setenv(name, value): + """Set an environment variable in the C Runtime. + + Parameters + ---------- + name : string type + The environment variable name + value : string type + The desired value to set the environment value to + """ + passed_value = None if value is None else c_str(value) + check_call(_LIB.MXSetEnv(c_str(name), passed_value)) diff --git a/src/c_api/c_api_test.cc b/src/c_api/c_api_test.cc index de4fb7dca18e..e84b0c0b1395 100644 --- a/src/c_api/c_api_test.cc +++ b/src/c_api/c_api_test.cc @@ -106,3 +106,25 @@ int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name) { } API_END(); } + +int MXGetEnv(const char* name, + const char** value) { + API_BEGIN(); + *value = getenv(name); + API_END(); +} + +int MXSetEnv(const char* name, + const char* value) { + API_BEGIN(); +#ifdef _WIN32 + auto value_arg = (value == nullptr) ? "" : value; + _putenv_s(name, value_arg); +#else + if (value == nullptr) + unsetenv(name); + else + setenv(name, value, 1); +#endif + API_END(); +} diff --git a/tests/python/gpu/test_device.py b/tests/python/gpu/test_device.py index cd8145c3deac..8a6fb3ac6101 100644 --- a/tests/python/gpu/test_device.py +++ b/tests/python/gpu/test_device.py @@ -20,8 +20,7 @@ import unittest import os import logging - -from mxnet.test_utils import EnvManager +from mxnet.test_utils import environment shapes = [(10), (100), (1000), (10000), (100000), (2,2), (2,3,4,5,6,7,8)] keys = [1,2,3,4,5,6,7] @@ -51,16 +50,15 @@ def check_dense_pushpull(kv_type): for x in range(n_gpus): assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0) - kvstore_tree_array_bound = 'MXNET_KVSTORE_TREE_ARRAY_BOUND' - kvstore_usetree_values = ['','1'] - kvstore_usetree = 'MXNET_KVSTORE_USETREE' - for _ in range(2): + kvstore_tree_array_bound_values = [None, '1'] + kvstore_usetree_values = [None, '1'] + for y in kvstore_tree_array_bound_values: for x in kvstore_usetree_values: - with EnvManager(kvstore_usetree, x): + with environment({'MXNET_KVSTORE_USETREE': x, + 'MXNET_KVSTORE_TREE_ARRAY_BOUND': y}): check_dense_pushpull('local') check_dense_pushpull('device') - os.environ[kvstore_tree_array_bound] = '1' - del os.environ[kvstore_tree_array_bound] + if __name__ == '__main__': test_device_pushpull() diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py index 61fba10913cc..2803869da01a 100644 --- a/tests/python/gpu/test_fusion.py +++ b/tests/python/gpu/test_fusion.py @@ -15,15 +15,17 @@ # specific language governing permissions and limitations # under the License. +import sys import os import random +import itertools import mxnet as mx import numpy as np from mxnet.test_utils import * curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) -from common import with_seed +from common import setup_module, teardown_module, with_seed def check_fused_symbol(sym, **kwargs): inputs = sym.list_inputs() @@ -43,10 +45,10 @@ def check_fused_symbol(sym, **kwargs): data = {inp : kwargs[inp].astype(dtype) for inp in inputs} for grad_req in ['write', 'add']: type_dict = {inp : dtype for inp in inputs} - os.environ["MXNET_USE_FUSION"] = "0" - orig_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes) - os.environ["MXNET_USE_FUSION"] = "1" - fused_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes) + with environment('MXNET_USE_FUSION', '0'): + orig_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes) + with environment('MXNET_USE_FUSION', '1'): + fused_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes) fwd_orig = orig_exec.forward(is_train=True, **data) out_grads = [mx.nd.ones_like(arr) for arr in fwd_orig] orig_exec.backward(out_grads=out_grads) @@ -227,6 +229,7 @@ def check_other_ops(): arr2 = mx.random.uniform(shape=(2,2,2,3)) check_fused_symbol(mx.sym.broadcast_like(a, b, lhs_axes=[0], rhs_axes=[0]), a=arr1, b=arr2) + def check_leakyrelu_ops(): a = mx.sym.Variable('a') b = mx.sym.Variable('b') diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 52280bf898a5..60a90c9f5c0c 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -22,7 +22,7 @@ import time import mxnet as mx import multiprocessing as mp -from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal, rand_ndarray +from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal, rand_ndarray, environment import mxnet.ndarray as nd import numpy as np import math @@ -555,9 +555,9 @@ def _test_bulking(test_bulking_func): time_per_iteration = mp.Manager().Value('d', 0.0) if not run_in_spawned_process(test_bulking_func, - {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD': seg_sizes[0], - 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD': seg_sizes[1], - 'MXNET_EXEC_BULK_EXEC_TRAIN': seg_sizes[2]}, + {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD': str(seg_sizes[0]), + 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD': str(seg_sizes[1]), + 'MXNET_EXEC_BULK_EXEC_TRAIN': str(seg_sizes[2])}, time_per_iteration): # skip test since the python version can't run it properly. Warning msg was logged. return @@ -631,15 +631,17 @@ def test_gemms_true_fp16(): net.cast('float16') net.initialize(ctx=ctx) net.weight.set_data(weights) - ref_results = net(input) - os.environ["MXNET_FC_TRUE_FP16"] = "1" - results_trueFP16 = net(input) + with environment('MXNET_FC_TRUE_FP16', '0'): + ref_results = net(input) + + with environment('MXNET_FC_TRUE_FP16', '1'): + results_trueFP16 = net(input) + atol = 1e-2 rtol = 1e-2 assert_almost_equal(ref_results.asnumpy(), results_trueFP16.asnumpy(), atol=atol, rtol=rtol) - os.environ["MXNET_FC_TRUE_FP16"] = "0" if __name__ == '__main__': diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py index 1dddc5889643..8473dd3a8010 100644 --- a/tests/python/gpu/test_kvstore_gpu.py +++ b/tests/python/gpu/test_kvstore_gpu.py @@ -21,7 +21,7 @@ import mxnet as mx import numpy as np import unittest -from mxnet.test_utils import assert_almost_equal, default_context, EnvManager +from mxnet.test_utils import assert_almost_equal, default_context, environment curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) from common import setup_module, with_seed, teardown @@ -97,11 +97,11 @@ def check_rsp_pull(kv, ctxs, sparse_pull, is_same_rowid=False, use_slice=False): check_rsp_pull(kv, [mx.gpu(i//2) for i in range(4)], sparse_pull, use_slice=True) check_rsp_pull(kv, [mx.cpu(i) for i in range(4)], sparse_pull, use_slice=True) - envs = ["","1"] - key = "MXNET_KVSTORE_USETREE" + envs = [None, '1'] + key = 'MXNET_KVSTORE_USETREE' for val in envs: - with EnvManager(key, val): - if val is "1": + with environment(key, val): + if val is '1': sparse_pull = False else: sparse_pull = True diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index bcf906a92e44..5fee473554e4 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -28,6 +28,8 @@ import mxnet.ndarray.sparse as mxsps import itertools from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal, assert_allclose +from mxnet.test_utils import check_symbolic_forward, check_symbolic_backward, discard_stderr +from mxnet.test_utils import default_context, rand_shape_2d, rand_ndarray, same, environment from mxnet.base import MXNetError from mxnet import autograd @@ -755,12 +757,12 @@ def _conv_with_num_streams(seed): @unittest.skip("skipping for now due to severe flakiness") @with_seed() def test_convolution_multiple_streams(): - for num_streams in [1, 2]: + for num_streams in ['1', '2']: for engine in ['NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice']: - print("Starting engine %s with %d streams." % (engine, num_streams), file=sys.stderr) + print('Starting engine {} with {} streams.'.format(engine, num_streams), file=sys.stderr) run_in_spawned_process(_conv_with_num_streams, {'MXNET_GPU_WORKER_NSTREAMS' : num_streams, 'MXNET_ENGINE_TYPE' : engine}) - print("Finished engine %s with %d streams." % (engine, num_streams), file=sys.stderr) + print('Finished engine {} with {} streams.'.format(engine, num_streams), file=sys.stderr) # This test is designed to expose an issue with cudnn v7.1.4 algo find() when invoked with large c. @@ -2229,22 +2231,22 @@ def check_proposal_consistency(op, batch_size, with_nms=False): # The following 2 functions launch 0-thread kernels, an error that should be caught and signaled. def kernel_error_check_imperative(): - os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine' - with mx.np_shape(active=True): - a = mx.nd.array([1,2,3],ctx=mx.gpu(0)) - b = mx.nd.array([],ctx=mx.gpu(0)) - c = (a / b).asnumpy() + with environment('MXNET_ENGINE_TYPE', 'NaiveEngine'): + with mx.np_shape(active=True): + a = mx.nd.array([1,2,3],ctx=mx.gpu(0)) + b = mx.nd.array([],ctx=mx.gpu(0)) + c = (a / b).asnumpy() def kernel_error_check_symbolic(): - os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine' - with mx.np_shape(active=True): - a = mx.sym.Variable('a') - b = mx.sym.Variable('b') - c = a / b - f = c.bind(mx.gpu(0), { 'a':mx.nd.array([1,2,3],ctx=mx.gpu(0)), - 'b':mx.nd.array([],ctx=mx.gpu(0))}) - f.forward() - g = f.outputs[0].asnumpy() + with environment('MXNET_ENGINE_TYPE', 'NaiveEngine'): + with mx.np_shape(active=True): + a = mx.sym.Variable('a') + b = mx.sym.Variable('b') + c = a / b + f = c.bind(mx.gpu(0), {'a':mx.nd.array([1,2,3],ctx=mx.gpu(0)), + 'b':mx.nd.array([],ctx=mx.gpu(0))}) + f.forward() + g = f.outputs[0].asnumpy() def test_kernel_error_checking(): # Running tests that may throw exceptions out of worker threads will stop CI testing @@ -2440,9 +2442,9 @@ def test_bulking(): # Create shared variable to return measured time from test process time_per_iteration = mp.Manager().Value('d', 0.0) if not run_in_spawned_process(_test_bulking_in_process, - {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD' : seg_sizes[0], - 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD' : seg_sizes[1], - 'MXNET_EXEC_BULK_EXEC_TRAIN' : seg_sizes[2]}, + {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD' : str(seg_sizes[0]), + 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD' : str(seg_sizes[1]), + 'MXNET_EXEC_BULK_EXEC_TRAIN' : str(seg_sizes[2])}, time_per_iteration): # skip test since the python version can't run it properly. Warning msg was logged. return diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py index 8e4e2e35f0cc..74bd9c8f752b 100644 --- a/tests/python/unittest/common.py +++ b/tests/python/unittest/common.py @@ -23,6 +23,7 @@ import random import shutil from mxnet.base import MXNetError +from mxnet.test_utils import environment curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.append(os.path.join(curr_path, '../common/')) sys.path.insert(0, os.path.join(curr_path, '../../../python')) @@ -208,15 +209,17 @@ def test_new(*args, **kwargs): logger = default_logger() # 'nosetests --logging-level=DEBUG' shows this msg even with an ensuing core dump. test_count_msg = '{} of {}: '.format(i+1,test_count) if test_count > 1 else '' - test_msg = ('{}Setting test np/mx/python random seeds, use MXNET_TEST_SEED={}' - ' to reproduce.').format(test_count_msg, this_test_seed) - logger.log(log_level, test_msg) + pre_test_msg = ('{}Setting test np/mx/python random seeds, use MXNET_TEST_SEED={}' + ' to reproduce.').format(test_count_msg, this_test_seed) + on_err_test_msg = ('{}Error seen with seeded test, use MXNET_TEST_SEED={}' + ' to reproduce.').format(test_count_msg, this_test_seed) + logger.log(log_level, pre_test_msg) try: orig_test(*args, **kwargs) except: # With exceptions, repeat test_msg at WARNING level to be sure it's seen. if log_level < logging.WARNING: - logger.warning(test_msg) + logger.warning(on_err_test_msg) raise finally: # Provide test-isolation for any test having this decorator @@ -329,6 +332,20 @@ def test_new(*args, **kwargs): finally: mx.nd.waitall() mx.cpu().empty_cache() + + +def with_environment(*args_): + """ + Helper function that takes a dictionary of environment variables and their + desired settings and changes the environment in advance of running the + decorated code. The original environment state is reinstated afterwards, + even if exceptions are raised. + """ + def test_helper(orig_test): + @functools.wraps(orig_test) + def test_new(*args, **kwargs): + with environment(*args_): + orig_test(*args, **kwargs) return test_new return test_helper @@ -363,16 +380,10 @@ def run_in_spawned_process(func, env, *args): return False else: seed = np.random.randint(0,1024*1024*1024) - orig_environ = os.environ.copy() - try: - for (key, value) in env.items(): - os.environ[key] = str(value) + with environment(env): # Prepend seed as first arg p = mpctx.Process(target=func, args=(seed,)+args) p.start() p.join() assert p.exitcode == 0, "Non-zero exit code %d from %s()." % (p.exitcode, func.__name__) - finally: - os.environ.clear() - os.environ.update(orig_environ) return True diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py index 61955f034a71..caff307a8bb0 100644 --- a/tests/python/unittest/test_autograd.py +++ b/tests/python/unittest/test_autograd.py @@ -21,7 +21,7 @@ from mxnet.autograd import * from mxnet.test_utils import * from common import setup_module, with_seed, teardown -from mxnet.test_utils import EnvManager +from mxnet.test_utils import environment def grad_and_loss(func, argnum=None): @@ -121,7 +121,7 @@ def check_unary_func(x): autograd_assert(x, func=f_square, grad_func=f_square_grad) uniform = nd.uniform(shape=(4, 5)) stypes = ['default', 'row_sparse', 'csr'] - with EnvManager('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'): + with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'): for stype in stypes: check_unary_func(uniform.tostype(stype)) @@ -140,7 +140,7 @@ def check_binary_func(x, y): uniform_x = nd.uniform(shape=(4, 5)) uniform_y = nd.uniform(shape=(4, 5)) stypes = ['default', 'row_sparse', 'csr'] - with EnvManager('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'): + with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'): for stype_x in stypes: for stype_y in stypes: x = uniform_x.tostype(stype_x) diff --git a/tests/python/unittest/test_base.py b/tests/python/unittest/test_base.py index 3189729e1d10..5744dccf8cd6 100644 --- a/tests/python/unittest/test_base.py +++ b/tests/python/unittest/test_base.py @@ -16,35 +16,93 @@ # under the License. import mxnet as mx +from numpy.testing import assert_equal from mxnet.base import data_dir from nose.tools import * +from mxnet.test_utils import environment +from mxnet.util import getenv +from common import setup_module, teardown, with_environment import os -import unittest import logging import os.path as op import platform -class MXNetDataDirTest(unittest.TestCase): - def setUp(self): - self.mxnet_data_dir = os.environ.get('MXNET_HOME') - if 'MXNET_HOME' in os.environ: - del os.environ['MXNET_HOME'] +def test_environment(): + name1 = 'MXNET_TEST_ENV_VAR_1' + name2 = 'MXNET_TEST_ENV_VAR_2' - def tearDown(self): - if self.mxnet_data_dir: - os.environ['MXNET_HOME'] = self.mxnet_data_dir - else: - if 'MXNET_HOME' in os.environ: - del os.environ['MXNET_HOME'] + # Test that a variable can be set in the python and backend environment + with environment(name1, '42'): + assert_equal(os.environ.get(name1), '42') + assert_equal(getenv(name1), '42') + + # Test dict form of invocation + env_var_dict = {name1: '1', name2: '2'} + with environment(env_var_dict): + for key, value in env_var_dict.items(): + assert_equal(os.environ.get(key), value) + assert_equal(getenv(key), value) + + # Further testing in 'test_with_environment()' + +@with_environment({'MXNET_TEST_ENV_VAR_1': '10', 'MXNET_TEST_ENV_VAR_2': None}) +def test_with_environment(): + name1 = 'MXNET_TEST_ENV_VAR_1' + name2 = 'MXNET_TEST_ENV_VAR_2' + def check_background_values(): + assert_equal(os.environ.get(name1), '10') + assert_equal(getenv(name1), '10') + assert_equal(os.environ.get(name2), None) + assert_equal(getenv(name2), None) + + check_background_values() - def test_data_dir(self,): - prev_data_dir = data_dir() - system = platform.system() - if system != 'Windows': - self.assertEqual(data_dir(), op.join(op.expanduser('~'), '.mxnet')) - os.environ['MXNET_HOME'] = '/tmp/mxnet_data' - self.assertEqual(data_dir(), '/tmp/mxnet_data') - del os.environ['MXNET_HOME'] - self.assertEqual(data_dir(), prev_data_dir) + # This completes the testing of with_environment(), but since we have + # an environment with a couple of known settings, lets use it to test if + # 'with environment()' properly restores to these settings in all cases. + class OnPurposeError(Exception): + """A class for exceptions thrown by this test""" + pass + # Enter an environment with one variable set and check it appears + # to both python and the backend. Then, outside the 'with' block, + # make sure the background environment is seen, regardless of whether + # the 'with' block raised an exception. + def test_one_var(name, value, raise_exception=False): + try: + with environment(name, value): + assert_equal(os.environ.get(name), value) + assert_equal(getenv(name), value) + if raise_exception: + raise OnPurposeError + except OnPurposeError: + pass + finally: + check_background_values() + + # Test various combinations of set and unset env vars. + # Test that the background setting is restored in the presense of exceptions. + for raise_exception in [False, True]: + # name1 is initially set in the environment + test_one_var(name1, '42', raise_exception) + test_one_var(name1, None, raise_exception) + # name2 is initially not set in the environment + test_one_var(name2, '42', raise_exception) + test_one_var(name2, None, raise_exception) + + +def test_data_dir(): + prev_data_dir = data_dir() + system = platform.system() + # Test that data_dir() returns the proper default value when MXNET_HOME is not set + with environment('MXNET_HOME', None): + if system == 'Windows': + assert_equal(data_dir(), op.join(os.environ.get('APPDATA'), 'mxnet')) + else: + assert_equal(data_dir(), op.join(op.expanduser('~'), '.mxnet')) + # Test that data_dir() responds to an explicit setting of MXNET_HOME + with environment('MXNET_HOME', '/tmp/mxnet_data'): + assert_equal(data_dir(), '/tmp/mxnet_data') + # Test that this test has not disturbed the MXNET_HOME value existing before the test + assert_equal(data_dir(), prev_data_dir) diff --git a/tests/python/unittest/test_engine.py b/tests/python/unittest/test_engine.py index 61d94ddbf4ec..ea02551aef00 100644 --- a/tests/python/unittest/test_engine.py +++ b/tests/python/unittest/test_engine.py @@ -19,7 +19,7 @@ import mxnet as mx import os import unittest -from mxnet.test_utils import EnvManager +from mxnet.test_utils import environment def test_bulk(): with mx.engine.bulk(10): @@ -42,7 +42,7 @@ def test_engine_openmp_after_fork(): With GOMP the child always has the same number when calling omp_get_max_threads, with LLVM OMP the child respects the number of max threads set in the parent. """ - with EnvManager('OMP_NUM_THREADS', '42'): + with environment('OMP_NUM_THREADS', '42'): r, w = os.pipe() pid = os.fork() if pid: diff --git a/tests/python/unittest/test_engine_import.py b/tests/python/unittest/test_engine_import.py index 303f3ceb1dee..ed56531b2f45 100644 --- a/tests/python/unittest/test_engine_import.py +++ b/tests/python/unittest/test_engine_import.py @@ -16,6 +16,8 @@ # under the License. import os +from mxnet.test_utils import environment +import unittest try: reload # Python 2 @@ -23,17 +25,15 @@ from importlib import reload +@unittest.skip('test needs improving, current use of reload(mxnet) is ineffective') def test_engine_import(): import mxnet - - engine_types = ['', 'NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice'] + # Temporarily add an illegal entry (that is not caught) to show how the test needs improving + engine_types = [None, 'NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice', 'BogusEngine'] for type in engine_types: - if type: - os.environ['MXNET_ENGINE_TYPE'] = type - else: - os.environ.pop('MXNET_ENGINE_TYPE', None) - reload(mxnet) + with environment('MXNET_ENGINE_TYPE', type): + reload(mxnet) if __name__ == '__main__': diff --git a/tests/python/unittest/test_executor.py b/tests/python/unittest/test_executor.py index 2bc696fd4e43..29fda3b08f55 100644 --- a/tests/python/unittest/test_executor.py +++ b/tests/python/unittest/test_executor.py @@ -18,7 +18,7 @@ import numpy as np import mxnet as mx from common import setup_module, with_seed, teardown -from mxnet.test_utils import assert_almost_equal +from mxnet.test_utils import assert_almost_equal, environment def check_bind_with_uniform(uf, gf, dim, sf=None, lshape=None, rshape=None): @@ -74,42 +74,34 @@ def check_bind_with_uniform(uf, gf, dim, sf=None, lshape=None, rshape=None): @with_seed() def test_bind(): - def check_bind(disable_bulk_exec): - if disable_bulk_exec: - prev_bulk_inf_val = mx.test_utils.set_env_var("MXNET_EXEC_BULK_EXEC_INFERENCE", "0", "1") - prev_bulk_train_val = mx.test_utils.set_env_var("MXNET_EXEC_BULK_EXEC_TRAIN", "0", "1") - - nrepeat = 10 - maxdim = 4 - for repeat in range(nrepeat): - for dim in range(1, maxdim): - check_bind_with_uniform(lambda x, y: x + y, - lambda g, x, y: (g, g), - dim) - check_bind_with_uniform(lambda x, y: x - y, - lambda g, x, y: (g, -g), - dim) - check_bind_with_uniform(lambda x, y: x * y, - lambda g, x, y: (y * g, x * g), - dim) - check_bind_with_uniform(lambda x, y: x / y, - lambda g, x, y: (g / y, -x * g/ (y**2)), - dim) - - check_bind_with_uniform(lambda x, y: np.maximum(x, y), - lambda g, x, y: (g * (x>=y), g * (y>x)), - dim, - sf=mx.symbol.maximum) - check_bind_with_uniform(lambda x, y: np.minimum(x, y), - lambda g, x, y: (g * (x<=y), g * (y=y), g * (y>x)), + dim, + sf=mx.symbol.maximum) + check_bind_with_uniform(lambda x, y: np.minimum(x, y), + lambda g, x, y: (g * (x<=y), g * (y