From 02ae456ef0e4eef86455b0a39d5ccabfd5b29668 Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Thu, 23 Jul 2020 11:17:10 -0700 Subject: [PATCH] Improve environment variable handling in unittests (#18424) This PR makes it easy to create unittests that require specific settings of environment variables, while avoiding the pitfalls (discussed in comments section). This PR can be considered a recasting and expansion of the great vision of @larroy in creating the EnvManager class in #13140. In its base form, the facility is a drop-in replacement for EnvManager, and is called 'environment': with environment('MXNET_MY_NEW_FEATURE', '1'): with environment('MXNET_MY_NEW_FEATURE', '0'): Like EnvManager, this facility takes care of the save/restore of the previous environment variable state, including when exceptions are raised. In addition though, this PR introduces the features: A similarly-named unittest decorator: @with_environment(key, value) The ability to pass in multiple env vars as a dict (as is needed for some tests) in both forms, so for example: with environment({'MXNET_FEATURE_A': '1', 'MXNET_FEATURE_B': '1'}): Works on Windows! This PR includes a wrapping of the backend's setenv() and getenv() functions, and uses this direct access to the backend environment to keep it in sync with the python environment. This works around the problem that the C Runtime on Windows gets a snapshot of the Python environment at startup that is immutable from Python. with environment() has a simple implementation using the @contextmanager decorator Tests are included that validate the facility works with all combinations of before_val/set_val, namely unset/unset, unset/set, set/unset, set/set. There were 5 unittests previously using EnvManager, and this PR shifts those uses to with environment():, while converting over 20 other ad-hoc uses of os.environ[] within the unittests. This PR also enables those unittests that were bypassed on Windows (due to the inability to set environment variables) to run on all platforms. Further Comments Environment variables are a two-edged sword- they enable useful operating modes for testing, debugging or niche applications, but like all features they must be tested. The correct approach for testing with a particular env var setting is: def set_env_var(key, value): if value is None: os.environ.pop(key, None) else: os.environ[key] = value old_env_var_value = os.environ.get(env_var_name) try: set_env_var(env_var_name, test_env_var_value) finally: set_env_var(env_var_name, old_env_var_value ) The above code makes no assumption about whether the before-test and within-test state of the env var is set or unset, and restores the prior environment even if the test raises an exception. This represents a lot of boiler-plate code that could be potentially mishandled. The with environment() context makes it simple to handle all this properly. If an entire unittest wants a forced env var setting, then using the @with_environment() decorator avoids the code indent of the with environment() approach if used otherwise within the test. --- include/mxnet/c_api_test.h | 16 ++ python/mxnet/test_utils.py | 113 ++++++++---- python/mxnet/util.py | 35 +++- src/c_api/c_api_test.cc | 22 +++ tests/python/gpu/test_device.py | 16 +- tests/python/gpu/test_fusion.py | 39 ++-- tests/python/gpu/test_gluon_gpu.py | 18 +- tests/python/gpu/test_kvstore_gpu.py | 10 +- tests/python/gpu/test_operator_gpu.py | 42 ++--- tests/python/unittest/common.py | 35 ++-- tests/python/unittest/test_autograd.py | 7 +- tests/python/unittest/test_base.py | 103 ++++++++--- tests/python/unittest/test_engine.py | 4 +- tests/python/unittest/test_engine_import.py | 14 +- tests/python/unittest/test_executor.py | 66 +++---- tests/python/unittest/test_gluon.py | 23 +-- .../unittest/test_gluon_probability_v1.py | 2 +- .../unittest/test_gluon_probability_v2.py | 2 +- tests/python/unittest/test_memory_opt.py | 38 +--- tests/python/unittest/test_operator.py | 170 +++++++++--------- tests/python/unittest/test_subgraph_op.py | 28 ++- tests/python/unittest/test_symbol.py | 68 +++---- 22 files changed, 491 insertions(+), 380 deletions(-) diff --git a/include/mxnet/c_api_test.h b/include/mxnet/c_api_test.h index b7ba0cef04a3..df7079842657 100644 --- a/include/mxnet/c_api_test.h +++ b/include/mxnet/c_api_test.h @@ -75,6 +75,22 @@ MXNET_DLL int MXRemoveSubgraphPropertyOpNames(const char* prop_name); MXNET_DLL int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name); +/*! + * \brief Get the value of an environment variable as seen by the backend. + * \param name The name of the environment variable + * \param value The returned value of the environment variable + */ +MXNET_DLL int MXGetEnv(const char* name, + const char** value); + +/*! + * \brief Set the value of an environment variable from the backend. + * \param name The name of the environment variable + * \param value The desired value to set the environment variable `name` + */ +MXNET_DLL int MXSetEnv(const char* name, + const char* value); + #ifdef __cplusplus } #endif // __cplusplus diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 9ec0c6cd1219..cb71c7dd46b9 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -24,6 +24,7 @@ import numbers import sys import os +import platform import errno import logging import bz2 @@ -48,7 +49,7 @@ from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID from .symbol import Symbol from .symbol.numpy import _Symbol as np_symbol -from .util import use_np, use_np_default_dtype # pylint: disable=unused-import +from .util import use_np, use_np_default_dtype, getenv, setenv # pylint: disable=unused-import from .runtime import Features from .numpy_extension import get_cuda_compute_capability @@ -1920,27 +1921,6 @@ def get_bz2_data(data_dir, data_name, url, data_origin_name): bz_file.close() os.remove(data_origin_name) -def set_env_var(key, val, default_val=""): - """Set environment variable - - Parameters - ---------- - - key : str - Env var to set - val : str - New value assigned to the env var - default_val : str, optional - Default value returned if the env var doesn't exist - - Returns - ------- - str - The value of env var before it is set to the new value - """ - prev_val = os.environ.get(key, default_val) - os.environ[key] = val - return prev_val def same_array(array1, array2): """Check whether two NDArrays sharing the same memory block @@ -1965,9 +1945,11 @@ def same_array(array1, array2): array1[:] -= 1 return same(array1.asnumpy(), array2.asnumpy()) + @contextmanager def discard_stderr(): - """Discards error output of a routine if invoked as: + """ + Discards error output of a routine if invoked as: with discard_stderr(): ... @@ -2400,22 +2382,79 @@ def same_symbol_structure(sym1, sym2): return True -class EnvManager(object): - """Environment variable setter and unsetter via with idiom""" - def __init__(self, key, val): - self._key = key - self._next_val = val - self._prev_val = None +@contextmanager +def environment(*args): + """ + Environment variable setter and unsetter via `with` idiom. - def __enter__(self): - self._prev_val = os.environ.get(self._key) - os.environ[self._key] = self._next_val + Takes a specification of env var names and desired values and adds those + settings to the environment in advance of running the body of the `with` + statement. The original environment state is restored afterwards, even + if exceptions are raised in the `with` body. - def __exit__(self, ptype, value, trace): - if self._prev_val: - os.environ[self._key] = self._prev_val - else: - del os.environ[self._key] + Parameters + ---------- + args: + if 2 args are passed: + name, desired_value strings of the single env var to update, or + if 1 arg is passed: + a dict of name:desired_value for env var's to update + + """ + + # On Linux, env var changes made through python's os.environ are seen + # by the backend. On Windows though, the C runtime gets a snapshot + # of the environment that cannot be altered by os.environ. Here we + # check, using a wrapped version of the backend's getenv(), that + # the desired env var value is seen by the backend, and otherwise use + # a wrapped setenv() to establish that value in the backend. + + # Also on Windows, a set env var can never have the value '', since + # the command 'set FOO= ' is used to unset the variable. Perhaps + # as a result, the wrapped dmlc::GetEnv() routine returns the same + # value for unset variables and those set to ''. As a result, we + # ignore discrepancy. + def validate_backend_setting(name, value, can_use_setenv=True): + backend_value = getenv(name) + if value == backend_value or \ + value == '' and backend_value is None and platform.system() == 'Windows': + return + if not can_use_setenv: + raise RuntimeError('Could not set env var {}={} within C Runtime'.format(name, value)) + setenv(name, value) + validate_backend_setting(name, value, can_use_setenv=False) + + # Core routine to alter environment from a dict of env_var_name, env_var_value pairs + def set_environ(env_var_dict): + for env_var_name, env_var_value in env_var_dict.items(): + if env_var_value is None: + os.environ.pop(env_var_name, None) + else: + os.environ[env_var_name] = env_var_value + validate_backend_setting(env_var_name, env_var_value) + + # Create env_var name:value dict from the two calling methods of this routine + if len(args) == 1 and isinstance(args[0], dict): + env_vars = args[0] + else: + assert len(args) == 2, 'Expecting one dict arg or two args: env var name and value' + env_vars = {args[0]: args[1]} + + # Take a snapshot of the existing environment variable state + # for those variables to be changed. get() return None for unset keys. + snapshot = {x: os.environ.get(x) for x in env_vars.keys()} + + # Alter the environment per the env_vars dict + set_environ(env_vars) + + # Now run the wrapped code + try: + yield + finally: + # the backend engines may still be referencing the changed env var state + mx.nd.waitall() + # reinstate original env_var state per the snapshot taken earlier + set_environ(snapshot) def collapse_sum_like(a, shape): diff --git a/python/mxnet/util.py b/python/mxnet/util.py index b35d3f38aa75..05a2bdad3b77 100644 --- a/python/mxnet/util.py +++ b/python/mxnet/util.py @@ -21,7 +21,7 @@ import inspect import threading -from .base import _LIB, check_call +from .base import _LIB, check_call, c_str, py_str _np_ufunc_default_kwargs = { @@ -913,6 +913,7 @@ def get_cuda_compute_capability(ctx): .format(ret, error_str.value.decode())) return cc_major.value * 10 + cc_minor.value + def default_array(source_array, ctx=None, dtype=None): """Creates an array from any object exposing the default(nd or np) array interface. @@ -1144,3 +1145,35 @@ def set_np_default_dtype(is_np_default_dtype=True): # pylint: disable=redefined prev = ctypes.c_bool() check_call(_LIB.MXSetIsNumpyDefaultDtype(ctypes.c_bool(is_np_default_dtype), ctypes.byref(prev))) return prev.value + + +def getenv(name): + """Get the setting of an environment variable from the C Runtime. + + Parameters + ---------- + name : string type + The environment variable name + + Returns + ------- + value : string + The value of the environment variable, or None if not set + """ + ret = ctypes.c_char_p() + check_call(_LIB.MXGetEnv(c_str(name), ctypes.byref(ret))) + return None if ret.value is None else py_str(ret.value) + + +def setenv(name, value): + """Set an environment variable in the C Runtime. + + Parameters + ---------- + name : string type + The environment variable name + value : string type + The desired value to set the environment value to + """ + passed_value = None if value is None else c_str(value) + check_call(_LIB.MXSetEnv(c_str(name), passed_value)) diff --git a/src/c_api/c_api_test.cc b/src/c_api/c_api_test.cc index de4fb7dca18e..e84b0c0b1395 100644 --- a/src/c_api/c_api_test.cc +++ b/src/c_api/c_api_test.cc @@ -106,3 +106,25 @@ int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name) { } API_END(); } + +int MXGetEnv(const char* name, + const char** value) { + API_BEGIN(); + *value = getenv(name); + API_END(); +} + +int MXSetEnv(const char* name, + const char* value) { + API_BEGIN(); +#ifdef _WIN32 + auto value_arg = (value == nullptr) ? "" : value; + _putenv_s(name, value_arg); +#else + if (value == nullptr) + unsetenv(name); + else + setenv(name, value, 1); +#endif + API_END(); +} diff --git a/tests/python/gpu/test_device.py b/tests/python/gpu/test_device.py index 52e09c029b49..76a32def33f5 100644 --- a/tests/python/gpu/test_device.py +++ b/tests/python/gpu/test_device.py @@ -20,8 +20,7 @@ import pytest import os import logging - -from mxnet.test_utils import EnvManager +from mxnet.test_utils import environment shapes = [(10), (100), (1000), (10000), (100000), (2,2), (2,3,4,5,6,7,8)] keys = [1,2,3,4,5,6,7] @@ -51,16 +50,15 @@ def check_dense_pushpull(kv_type): for x in range(n_gpus): assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0) - kvstore_tree_array_bound = 'MXNET_KVSTORE_TREE_ARRAY_BOUND' - kvstore_usetree_values = ['','1'] - kvstore_usetree = 'MXNET_KVSTORE_USETREE' - for _ in range(2): + kvstore_tree_array_bound_values = [None, '1'] + kvstore_usetree_values = [None, '1'] + for y in kvstore_tree_array_bound_values: for x in kvstore_usetree_values: - with EnvManager(kvstore_usetree, x): + with environment({'MXNET_KVSTORE_USETREE': x, + 'MXNET_KVSTORE_TREE_ARRAY_BOUND': y}): check_dense_pushpull('local') check_dense_pushpull('device') - os.environ[kvstore_tree_array_bound] = '1' - del os.environ[kvstore_tree_array_bound] + if __name__ == '__main__': test_device_pushpull() diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py index 8d3ce47c18e8..1f261adcebac 100644 --- a/tests/python/gpu/test_fusion.py +++ b/tests/python/gpu/test_fusion.py @@ -15,8 +15,10 @@ # specific language governing permissions and limitations # under the License. +import sys import os import random +import itertools import mxnet as mx import numpy as np from mxnet import autograd, gluon @@ -24,7 +26,7 @@ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) -from common import with_seed +from common import setup_module, teardown_module, with_seed def check_fused_symbol(sym, **kwargs): inputs = sym.list_inputs() @@ -44,10 +46,10 @@ def check_fused_symbol(sym, **kwargs): data = {inp : kwargs[inp].astype(dtype) for inp in inputs} for grad_req in ['write', 'add']: type_dict = {inp : dtype for inp in inputs} - os.environ["MXNET_USE_FUSION"] = "0" - orig_exec = test_sym._simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes) - os.environ["MXNET_USE_FUSION"] = "1" - fused_exec = test_sym._simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes) + with environment('MXNET_USE_FUSION', '0'): + orig_exec = test_sym._simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes) + with environment('MXNET_USE_FUSION', '1'): + fused_exec = test_sym._simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes) fwd_orig = orig_exec.forward(is_train=True, **data) out_grads = [mx.nd.ones_like(arr) for arr in fwd_orig] orig_exec.backward(out_grads=out_grads) @@ -231,6 +233,7 @@ def check_other_ops(): arr2 = mx.random.uniform(shape=(2,2,2,3)) check_fused_symbol(mx.sym.broadcast_like(a, b, lhs_axes=[0], rhs_axes=[0]), a=arr1, b=arr2) + def check_leakyrelu_ops(): a = mx.sym.Variable('a') b = mx.sym.Variable('b') @@ -331,18 +334,18 @@ def hybrid_forward(self, F, x, y, z): arrays = {} for use_fusion in ('0', '1'): - os.environ['MXNET_USE_FUSION'] = use_fusion - arrays[use_fusion] = {} - n = Block() - n.hybridize(static_alloc=static_alloc) - args = [arg.copyto(mx.gpu()) for arg in arg_data] - for arg in args: - arg.attach_grad() - with autograd.record(): - r = n(*args) - arrays[use_fusion]['result'] = r - r.backward() - for i, arg in enumerate(args): - arrays[use_fusion][i] = arg.grad + with environment('MXNET_USE_FUSION', use_fusion): + arrays[use_fusion] = {} + n = Block() + n.hybridize(static_alloc=static_alloc) + args = [arg.copyto(mx.gpu()) for arg in arg_data] + for arg in args: + arg.attach_grad() + with autograd.record(): + r = n(*args) + arrays[use_fusion]['result'] = r + r.backward() + for i, arg in enumerate(args): + arrays[use_fusion][i] = arg.grad for key in ['result'] + list(range(len(arg_data))): assert_allclose(arrays['0'][key].asnumpy(), arrays['1'][key].asnumpy()) diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index e259b74b9fad..278ea02fce98 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -21,7 +21,7 @@ import time import mxnet as mx import multiprocessing as mp -from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal, rand_ndarray +from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal, rand_ndarray, environment import mxnet.ndarray as nd import numpy as np import math @@ -558,9 +558,9 @@ def _test_bulking(test_bulking_func): time_per_iteration = mp.Manager().Value('d', 0.0) if not run_in_spawned_process(test_bulking_func, - {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD': seg_sizes[0], - 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD': seg_sizes[1], - 'MXNET_EXEC_BULK_EXEC_TRAIN': seg_sizes[2]}, + {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD': str(seg_sizes[0]), + 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD': str(seg_sizes[1]), + 'MXNET_EXEC_BULK_EXEC_TRAIN': str(seg_sizes[2])}, time_per_iteration): # skip test since the python version can't run it properly. Warning msg was logged. return @@ -634,13 +634,15 @@ def test_gemms_true_fp16(): net.cast('float16') net.initialize(ctx=ctx) net.weight.set_data(weights) - ref_results = net(input) - os.environ["MXNET_FC_TRUE_FP16"] = "1" - results_trueFP16 = net(input) + with environment('MXNET_FC_TRUE_FP16', '0'): + ref_results = net(input) + + with environment('MXNET_FC_TRUE_FP16', '1'): + results_trueFP16 = net(input) + atol = 1e-2 rtol = 1e-2 assert_almost_equal(ref_results.asnumpy(), results_trueFP16.asnumpy(), atol=atol, rtol=rtol) - os.environ["MXNET_FC_TRUE_FP16"] = "0" diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py index 5167970bb8f8..d836e970181d 100644 --- a/tests/python/gpu/test_kvstore_gpu.py +++ b/tests/python/gpu/test_kvstore_gpu.py @@ -21,7 +21,7 @@ import mxnet as mx import numpy as np import pytest -from mxnet.test_utils import assert_almost_equal, default_context, EnvManager +from mxnet.test_utils import assert_almost_equal, default_context, environment curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) from common import setup_module, with_seed, teardown_module @@ -99,11 +99,11 @@ def check_rsp_pull(kv, ctxs, sparse_pull, is_same_rowid=False, use_slice=False): check_rsp_pull(kv, [mx.gpu(i//2) for i in range(4)], sparse_pull, use_slice=True) check_rsp_pull(kv, [mx.cpu(i) for i in range(4)], sparse_pull, use_slice=True) - envs = ["","1"] - key = "MXNET_KVSTORE_USETREE" + envs = [None, '1'] + key = 'MXNET_KVSTORE_USETREE' for val in envs: - with EnvManager(key, val): - if val is "1": + with environment(key, val): + if val is '1': sparse_pull = False else: sparse_pull = True diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index d088b44f85b7..84cfe9cfa35d 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -26,7 +26,7 @@ import itertools from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal, assert_allclose from mxnet.test_utils import check_symbolic_forward, check_symbolic_backward, discard_stderr -from mxnet.test_utils import default_context, rand_shape_2d, rand_ndarray, same +from mxnet.test_utils import default_context, rand_shape_2d, rand_ndarray, same, environment from mxnet.base import MXNetError from mxnet import autograd @@ -654,12 +654,12 @@ def _conv_with_num_streams(seed): @pytest.mark.skip(reason="skipping for now due to severe flakiness") @with_seed() def test_convolution_multiple_streams(): - for num_streams in [1, 2]: + for num_streams in ['1', '2']: for engine in ['NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice']: - print("Starting engine %s with %d streams." % (engine, num_streams), file=sys.stderr) + print('Starting engine {} with {} streams.'.format(engine, num_streams), file=sys.stderr) run_in_spawned_process(_conv_with_num_streams, {'MXNET_GPU_WORKER_NSTREAMS' : num_streams, 'MXNET_ENGINE_TYPE' : engine}) - print("Finished engine %s with %d streams." % (engine, num_streams), file=sys.stderr) + print('Finished engine {} with {} streams.'.format(engine, num_streams), file=sys.stderr) # This test is designed to expose an issue with cudnn v7.1.4 algo find() when invoked with large c. @@ -2009,22 +2009,22 @@ def check_proposal_consistency(op, batch_size, with_nms=False): # The following 2 functions launch 0-thread kernels, an error that should be caught and signaled. def kernel_error_check_imperative(): - os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine' - with mx.np_shape(active=True): - a = mx.nd.array([1,2,3],ctx=mx.gpu(0)) - b = mx.nd.array([],ctx=mx.gpu(0)) - c = (a / b).asnumpy() + with environment('MXNET_ENGINE_TYPE', 'NaiveEngine'): + with mx.np_shape(active=True): + a = mx.nd.array([1,2,3],ctx=mx.gpu(0)) + b = mx.nd.array([],ctx=mx.gpu(0)) + c = (a / b).asnumpy() def kernel_error_check_symbolic(): - os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine' - with mx.np_shape(active=True): - a = mx.sym.Variable('a') - b = mx.sym.Variable('b') - c = a / b - f = c._bind(mx.gpu(0), { 'a':mx.nd.array([1,2,3],ctx=mx.gpu(0)), - 'b':mx.nd.array([],ctx=mx.gpu(0))}) - f.forward() - g = f.outputs[0].asnumpy() + with environment('MXNET_ENGINE_TYPE', 'NaiveEngine'): + with mx.np_shape(active=True): + a = mx.sym.Variable('a') + b = mx.sym.Variable('b') + c = a / b + f = c.bind(mx.gpu(0), {'a':mx.nd.array([1,2,3],ctx=mx.gpu(0)), + 'b':mx.nd.array([],ctx=mx.gpu(0))}) + f.forward() + g = f.outputs[0].asnumpy() @pytest.mark.serial def test_kernel_error_checking(): @@ -2223,9 +2223,9 @@ def test_bulking(): # Create shared variable to return measured time from test process time_per_iteration = mp.Manager().Value('d', 0.0) if not run_in_spawned_process(_test_bulking_in_process, - {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD' : seg_sizes[0], - 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD' : seg_sizes[1], - 'MXNET_EXEC_BULK_EXEC_TRAIN' : seg_sizes[2]}, + {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD' : str(seg_sizes[0]), + 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD' : str(seg_sizes[1]), + 'MXNET_EXEC_BULK_EXEC_TRAIN' : str(seg_sizes[2])}, time_per_iteration): # skip test since the python version can't run it properly. Warning msg was logged. return diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py index 331f32d789b7..2f30096c10b4 100644 --- a/tests/python/unittest/common.py +++ b/tests/python/unittest/common.py @@ -23,6 +23,7 @@ import random import shutil from mxnet.base import MXNetError +from mxnet.test_utils import environment curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.append(os.path.join(curr_path, '../common/')) sys.path.insert(0, os.path.join(curr_path, '../../../python')) @@ -216,15 +217,17 @@ def test_new(*args, **kwargs): logger = default_logger() # 'pytest --logging-level=DEBUG' shows this msg even with an ensuing core dump. test_count_msg = '{} of {}: '.format(i+1,test_count) if test_count > 1 else '' - test_msg = ('{}Setting test np/mx/python random seeds, use MXNET_TEST_SEED={}' - ' to reproduce.').format(test_count_msg, this_test_seed) - logger.log(log_level, test_msg) + pre_test_msg = ('{}Setting test np/mx/python random seeds, use MXNET_TEST_SEED={}' + ' to reproduce.').format(test_count_msg, this_test_seed) + on_err_test_msg = ('{}Error seen with seeded test, use MXNET_TEST_SEED={}' + ' to reproduce.').format(test_count_msg, this_test_seed) + logger.log(log_level, pre_test_msg) try: orig_test(*args, **kwargs) except: # With exceptions, repeat test_msg at WARNING level to be sure it's seen. if log_level < logging.WARNING: - logger.warning(test_msg) + logger.warning(on_err_test_msg) raise finally: # Provide test-isolation for any test having this decorator @@ -307,6 +310,22 @@ def teardown_module(): mx.nd.waitall() +def with_environment(*args_): + """ + Helper function that takes a dictionary of environment variables and their + desired settings and changes the environment in advance of running the + decorated code. The original environment state is reinstated afterwards, + even if exceptions are raised. + """ + def test_helper(orig_test): + @functools.wraps(orig_test) + def test_new(*args, **kwargs): + with environment(*args_): + orig_test(*args, **kwargs) + return test_new + return test_helper + + def run_in_spawned_process(func, env, *args): """ Helper function to run a test in its own process. @@ -337,18 +356,12 @@ def run_in_spawned_process(func, env, *args): return False else: seed = np.random.randint(0,1024*1024*1024) - orig_environ = os.environ.copy() - try: - for (key, value) in env.items(): - os.environ[key] = str(value) + with environment(env): # Prepend seed as first arg p = mpctx.Process(target=func, args=(seed,)+args) p.start() p.join() assert p.exitcode == 0, "Non-zero exit code %d from %s()." % (p.exitcode, func.__name__) - finally: - os.environ.clear() - os.environ.update(orig_environ) return True diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py index ad0601cdbb0f..6a75eed7d0bb 100644 --- a/tests/python/unittest/test_autograd.py +++ b/tests/python/unittest/test_autograd.py @@ -20,8 +20,9 @@ from mxnet.ndarray import zeros_like from mxnet.autograd import * from mxnet.test_utils import * + from common import setup_module, with_seed, teardown_module, xfail_when_nonstandard_decimal_separator -from mxnet.test_utils import EnvManager +from mxnet.test_utils import environment import pytest @@ -124,7 +125,7 @@ def check_unary_func(x): autograd_assert(x, func=f_square, grad_func=f_square_grad) uniform = nd.uniform(shape=(4, 5)) stypes = ['default', 'row_sparse', 'csr'] - with EnvManager('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'): + with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'): for stype in stypes: check_unary_func(uniform.tostype(stype)) @@ -143,7 +144,7 @@ def check_binary_func(x, y): uniform_x = nd.uniform(shape=(4, 5)) uniform_y = nd.uniform(shape=(4, 5)) stypes = ['default', 'row_sparse', 'csr'] - with EnvManager('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'): + with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'): for stype_x in stypes: for stype_y in stypes: x = uniform_x.tostype(stype_x) diff --git a/tests/python/unittest/test_base.py b/tests/python/unittest/test_base.py index 07d429589ba2..74d3f17a645e 100644 --- a/tests/python/unittest/test_base.py +++ b/tests/python/unittest/test_base.py @@ -16,33 +16,92 @@ # under the License. import mxnet as mx +from numpy.testing import assert_equal from mxnet.base import data_dir +from mxnet.test_utils import environment +from mxnet.util import getenv +from common import setup_module, teardown_module, with_environment import os -import unittest import logging import os.path as op import platform -class MXNetDataDirTest(unittest.TestCase): - def setUp(self): - self.mxnet_data_dir = os.environ.get('MXNET_HOME') - if 'MXNET_HOME' in os.environ: - del os.environ['MXNET_HOME'] +def test_environment(): + name1 = 'MXNET_TEST_ENV_VAR_1' + name2 = 'MXNET_TEST_ENV_VAR_2' - def tearDown(self): - if self.mxnet_data_dir: - os.environ['MXNET_HOME'] = self.mxnet_data_dir - else: - if 'MXNET_HOME' in os.environ: - del os.environ['MXNET_HOME'] - - def test_data_dir(self,): - prev_data_dir = data_dir() - system = platform.system() - if system != 'Windows': - self.assertEqual(data_dir(), op.join(op.expanduser('~'), '.mxnet')) - os.environ['MXNET_HOME'] = '/tmp/mxnet_data' - self.assertEqual(data_dir(), '/tmp/mxnet_data') - del os.environ['MXNET_HOME'] - self.assertEqual(data_dir(), prev_data_dir) + # Test that a variable can be set in the python and backend environment + with environment(name1, '42'): + assert_equal(os.environ.get(name1), '42') + assert_equal(getenv(name1), '42') + + # Test dict form of invocation + env_var_dict = {name1: '1', name2: '2'} + with environment(env_var_dict): + for key, value in env_var_dict.items(): + assert_equal(os.environ.get(key), value) + assert_equal(getenv(key), value) + + # Further testing in 'test_with_environment()' + +@with_environment({'MXNET_TEST_ENV_VAR_1': '10', 'MXNET_TEST_ENV_VAR_2': None}) +def test_with_environment(): + name1 = 'MXNET_TEST_ENV_VAR_1' + name2 = 'MXNET_TEST_ENV_VAR_2' + def check_background_values(): + assert_equal(os.environ.get(name1), '10') + assert_equal(getenv(name1), '10') + assert_equal(os.environ.get(name2), None) + assert_equal(getenv(name2), None) + + check_background_values() + + # This completes the testing of with_environment(), but since we have + # an environment with a couple of known settings, lets use it to test if + # 'with environment()' properly restores to these settings in all cases. + class OnPurposeError(Exception): + """A class for exceptions thrown by this test""" + pass + + # Enter an environment with one variable set and check it appears + # to both python and the backend. Then, outside the 'with' block, + # make sure the background environment is seen, regardless of whether + # the 'with' block raised an exception. + def test_one_var(name, value, raise_exception=False): + try: + with environment(name, value): + assert_equal(os.environ.get(name), value) + assert_equal(getenv(name), value) + if raise_exception: + raise OnPurposeError + except OnPurposeError: + pass + finally: + check_background_values() + + # Test various combinations of set and unset env vars. + # Test that the background setting is restored in the presense of exceptions. + for raise_exception in [False, True]: + # name1 is initially set in the environment + test_one_var(name1, '42', raise_exception) + test_one_var(name1, None, raise_exception) + # name2 is initially not set in the environment + test_one_var(name2, '42', raise_exception) + test_one_var(name2, None, raise_exception) + + +def test_data_dir(): + prev_data_dir = data_dir() + system = platform.system() + # Test that data_dir() returns the proper default value when MXNET_HOME is not set + with environment('MXNET_HOME', None): + if system == 'Windows': + assert_equal(data_dir(), op.join(os.environ.get('APPDATA'), 'mxnet')) + else: + assert_equal(data_dir(), op.join(op.expanduser('~'), '.mxnet')) + # Test that data_dir() responds to an explicit setting of MXNET_HOME + with environment('MXNET_HOME', '/tmp/mxnet_data'): + assert_equal(data_dir(), '/tmp/mxnet_data') + # Test that this test has not disturbed the MXNET_HOME value existing before the test + assert_equal(data_dir(), prev_data_dir) diff --git a/tests/python/unittest/test_engine.py b/tests/python/unittest/test_engine.py index 538e4b57ead8..642d9e1f169e 100644 --- a/tests/python/unittest/test_engine.py +++ b/tests/python/unittest/test_engine.py @@ -17,7 +17,7 @@ import mxnet as mx import os -from mxnet.test_utils import EnvManager +from mxnet.test_utils import environment import pytest def test_bulk(): @@ -41,7 +41,7 @@ def test_engine_openmp_after_fork(): With GOMP the child always has the same number when calling omp_get_max_threads, with LLVM OMP the child respects the number of max threads set in the parent. """ - with EnvManager('OMP_NUM_THREADS', '42'): + with environment('OMP_NUM_THREADS', '42'): r, w = os.pipe() pid = os.fork() if pid: diff --git a/tests/python/unittest/test_engine_import.py b/tests/python/unittest/test_engine_import.py index 7675cf836999..322528c11aed 100644 --- a/tests/python/unittest/test_engine_import.py +++ b/tests/python/unittest/test_engine_import.py @@ -16,6 +16,8 @@ # under the License. import os +from mxnet.test_utils import environment +import pytest try: reload # Python 2 @@ -23,15 +25,13 @@ from importlib import reload +@pytest.mark.skip(reason='test needs improving, current use of reload(mxnet) is ineffective') def test_engine_import(): import mxnet - - engine_types = ['', 'NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice'] + # Temporarily add an illegal entry (that is not caught) to show how the test needs improving + engine_types = [None, 'NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice', 'BogusEngine'] for type in engine_types: - if type: - os.environ['MXNET_ENGINE_TYPE'] = type - else: - os.environ.pop('MXNET_ENGINE_TYPE', None) - reload(mxnet) + with environment('MXNET_ENGINE_TYPE', type): + reload(mxnet) diff --git a/tests/python/unittest/test_executor.py b/tests/python/unittest/test_executor.py index 0e142bf5b05a..27a9e030c171 100644 --- a/tests/python/unittest/test_executor.py +++ b/tests/python/unittest/test_executor.py @@ -18,7 +18,7 @@ import numpy as np import mxnet as mx from common import setup_module, with_seed, teardown_module -from mxnet.test_utils import assert_almost_equal +from mxnet.test_utils import assert_almost_equal, environment def check_bind_with_uniform(uf, gf, dim, sf=None, lshape=None, rshape=None): @@ -74,42 +74,34 @@ def check_bind_with_uniform(uf, gf, dim, sf=None, lshape=None, rshape=None): @with_seed() def test_bind(): - def check_bind(disable_bulk_exec): - if disable_bulk_exec: - prev_bulk_inf_val = mx.test_utils.set_env_var("MXNET_EXEC_BULK_EXEC_INFERENCE", "0", "1") - prev_bulk_train_val = mx.test_utils.set_env_var("MXNET_EXEC_BULK_EXEC_TRAIN", "0", "1") - - nrepeat = 10 - maxdim = 4 - for repeat in range(nrepeat): - for dim in range(1, maxdim): - check_bind_with_uniform(lambda x, y: x + y, - lambda g, x, y: (g, g), - dim) - check_bind_with_uniform(lambda x, y: x - y, - lambda g, x, y: (g, -g), - dim) - check_bind_with_uniform(lambda x, y: x * y, - lambda g, x, y: (y * g, x * g), - dim) - check_bind_with_uniform(lambda x, y: x / y, - lambda g, x, y: (g / y, -x * g/ (y**2)), - dim) - - check_bind_with_uniform(lambda x, y: np.maximum(x, y), - lambda g, x, y: (g * (x>=y), g * (y>x)), - dim, - sf=mx.symbol.maximum) - check_bind_with_uniform(lambda x, y: np.minimum(x, y), - lambda g, x, y: (g * (x<=y), g * (y=y), g * (y>x)), + dim, + sf=mx.symbol.maximum) + check_bind_with_uniform(lambda x, y: np.minimum(x, y), + lambda g, x, y: (g * (x<=y), g * (y