Skip to content

Commit

Permalink
Improve environment variable handling in unittests (apache#18424)
Browse files Browse the repository at this point in the history
  • Loading branch information
DickJC123 committed Sep 17, 2020
1 parent ce0a518 commit cd1688f
Show file tree
Hide file tree
Showing 19 changed files with 475 additions and 335 deletions.
16 changes: 16 additions & 0 deletions include/mxnet/c_api_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,22 @@ MXNET_DLL int MXRemoveSubgraphPropertyOpNames(const char* prop_name);
MXNET_DLL int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name);


/*!
* \brief Get the value of an environment variable as seen by the backend.
* \param name The name of the environment variable
* \param value The returned value of the environment variable
*/
MXNET_DLL int MXGetEnv(const char* name,
const char** value);

/*!
* \brief Set the value of an environment variable from the backend.
* \param name The name of the environment variable
* \param value The desired value to set the environment variable `name`
*/
MXNET_DLL int MXSetEnv(const char* name,
const char* value);

#ifdef __cplusplus
}
#endif // __cplusplus
Expand Down
113 changes: 76 additions & 37 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import numbers
import sys
import os
import platform
import errno
import logging
import bz2
Expand All @@ -49,7 +50,7 @@
from .ndarray import array
from .symbol import Symbol
from .symbol.numpy import _Symbol as np_symbol
from .util import use_np # pylint: disable=unused-import
from .util import use_np, getenv, setenv # pylint: disable=unused-import
from .runtime import Features
from .numpy_extension import get_cuda_compute_capability

Expand Down Expand Up @@ -2035,27 +2036,6 @@ def get_bz2_data(data_dir, data_name, url, data_origin_name):
bz_file.close()
os.remove(data_origin_name)

def set_env_var(key, val, default_val=""):
"""Set environment variable
Parameters
----------
key : str
Env var to set
val : str
New value assigned to the env var
default_val : str, optional
Default value returned if the env var doesn't exist
Returns
-------
str
The value of env var before it is set to the new value
"""
prev_val = os.environ.get(key, default_val)
os.environ[key] = val
return prev_val

def same_array(array1, array2):
"""Check whether two NDArrays sharing the same memory block
Expand All @@ -2080,9 +2060,11 @@ def same_array(array1, array2):
array1[:] -= 1
return same(array1.asnumpy(), array2.asnumpy())


@contextmanager
def discard_stderr():
"""Discards error output of a routine if invoked as:
"""
Discards error output of a routine if invoked as:
with discard_stderr():
...
Expand Down Expand Up @@ -2471,22 +2453,79 @@ def same_symbol_structure(sym1, sym2):
return True


class EnvManager(object):
"""Environment variable setter and unsetter via with idiom"""
def __init__(self, key, val):
self._key = key
self._next_val = val
self._prev_val = None
@contextmanager
def environment(*args):
"""
Environment variable setter and unsetter via `with` idiom.
def __enter__(self):
self._prev_val = os.environ.get(self._key)
os.environ[self._key] = self._next_val
Takes a specification of env var names and desired values and adds those
settings to the environment in advance of running the body of the `with`
statement. The original environment state is restored afterwards, even
if exceptions are raised in the `with` body.
def __exit__(self, ptype, value, trace):
if self._prev_val:
os.environ[self._key] = self._prev_val
else:
del os.environ[self._key]
Parameters
----------
args:
if 2 args are passed:
name, desired_value strings of the single env var to update, or
if 1 arg is passed:
a dict of name:desired_value for env var's to update
"""

# On Linux, env var changes made through python's os.environ are seen
# by the backend. On Windows though, the C runtime gets a snapshot
# of the environment that cannot be altered by os.environ. Here we
# check, using a wrapped version of the backend's getenv(), that
# the desired env var value is seen by the backend, and otherwise use
# a wrapped setenv() to establish that value in the backend.

# Also on Windows, a set env var can never have the value '', since
# the command 'set FOO= ' is used to unset the variable. Perhaps
# as a result, the wrapped dmlc::GetEnv() routine returns the same
# value for unset variables and those set to ''. As a result, we
# ignore discrepancy.
def validate_backend_setting(name, value, can_use_setenv=True):
backend_value = getenv(name)
if value == backend_value or \
value == '' and backend_value is None and platform.system() == 'Windows':
return
if not can_use_setenv:
raise RuntimeError('Could not set env var {}={} within C Runtime'.format(name, value))
setenv(name, value)
validate_backend_setting(name, value, can_use_setenv=False)

# Core routine to alter environment from a dict of env_var_name, env_var_value pairs
def set_environ(env_var_dict):
for env_var_name, env_var_value in env_var_dict.items():
if env_var_value is None:
os.environ.pop(env_var_name, None)
else:
os.environ[env_var_name] = env_var_value
validate_backend_setting(env_var_name, env_var_value)

# Create env_var name:value dict from the two calling methods of this routine
if len(args) == 1 and isinstance(args[0], dict):
env_vars = args[0]
else:
assert len(args) == 2, 'Expecting one dict arg or two args: env var name and value'
env_vars = {args[0]: args[1]}

# Take a snapshot of the existing environment variable state
# for those variables to be changed. get() return None for unset keys.
snapshot = {x: os.environ.get(x) for x in env_vars.keys()}

# Alter the environment per the env_vars dict
set_environ(env_vars)

# Now run the wrapped code
try:
yield
finally:
# the backend engines may still be referencing the changed env var state
mx.nd.waitall()
# reinstate original env_var state per the snapshot taken earlier
set_environ(snapshot)


def collapse_sum_like(a, shape):
Expand Down
34 changes: 33 additions & 1 deletion python/mxnet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import inspect
import threading

from .base import _LIB, check_call
from .base import _LIB, check_call, c_str, py_str


_np_ufunc_default_kwargs = {
Expand Down Expand Up @@ -816,3 +816,35 @@ def get_cuda_compute_capability(ctx):
raise RuntimeError('cuDeviceComputeCapability failed with error code {}: {}'
.format(ret, error_str.value.decode()))
return cc_major.value * 10 + cc_minor.value


def getenv(name):
"""Get the setting of an environment variable from the C Runtime.
Parameters
----------
name : string type
The environment variable name
Returns
-------
value : string
The value of the environment variable, or None if not set
"""
ret = ctypes.c_char_p()
check_call(_LIB.MXGetEnv(c_str(name), ctypes.byref(ret)))
return None if ret.value is None else py_str(ret.value)


def setenv(name, value):
"""Set an environment variable in the C Runtime.
Parameters
----------
name : string type
The environment variable name
value : string type
The desired value to set the environment value to
"""
passed_value = None if value is None else c_str(value)
check_call(_LIB.MXSetEnv(c_str(name), passed_value))
22 changes: 22 additions & 0 deletions src/c_api/c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,25 @@ int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name) {
}
API_END();
}

int MXGetEnv(const char* name,
const char** value) {
API_BEGIN();
*value = getenv(name);
API_END();
}

int MXSetEnv(const char* name,
const char* value) {
API_BEGIN();
#ifdef _WIN32
auto value_arg = (value == nullptr) ? "" : value;
_putenv_s(name, value_arg);
#else
if (value == nullptr)
unsetenv(name);
else
setenv(name, value, 1);
#endif
API_END();
}
16 changes: 7 additions & 9 deletions tests/python/gpu/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
import unittest
import os
import logging

from mxnet.test_utils import EnvManager
from mxnet.test_utils import environment

shapes = [(10), (100), (1000), (10000), (100000), (2,2), (2,3,4,5,6,7,8)]
keys = [1,2,3,4,5,6,7]
Expand Down Expand Up @@ -51,16 +50,15 @@ def check_dense_pushpull(kv_type):
for x in range(n_gpus):
assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0)

kvstore_tree_array_bound = 'MXNET_KVSTORE_TREE_ARRAY_BOUND'
kvstore_usetree_values = ['','1']
kvstore_usetree = 'MXNET_KVSTORE_USETREE'
for _ in range(2):
kvstore_tree_array_bound_values = [None, '1']
kvstore_usetree_values = [None, '1']
for y in kvstore_tree_array_bound_values:
for x in kvstore_usetree_values:
with EnvManager(kvstore_usetree, x):
with environment({'MXNET_KVSTORE_USETREE': x,
'MXNET_KVSTORE_TREE_ARRAY_BOUND': y}):
check_dense_pushpull('local')
check_dense_pushpull('device')
os.environ[kvstore_tree_array_bound] = '1'
del os.environ[kvstore_tree_array_bound]


if __name__ == '__main__':
test_device_pushpull()
13 changes: 8 additions & 5 deletions tests/python/gpu/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@
# specific language governing permissions and limitations
# under the License.

import sys
import os
import random
import itertools
import mxnet as mx
import numpy as np
from mxnet.test_utils import *

curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import with_seed
from common import setup_module, teardown_module, with_seed

def check_fused_symbol(sym, **kwargs):
inputs = sym.list_inputs()
Expand All @@ -43,10 +45,10 @@ def check_fused_symbol(sym, **kwargs):
data = {inp : kwargs[inp].astype(dtype) for inp in inputs}
for grad_req in ['write', 'add']:
type_dict = {inp : dtype for inp in inputs}
os.environ["MXNET_USE_FUSION"] = "0"
orig_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes)
os.environ["MXNET_USE_FUSION"] = "1"
fused_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes)
with environment('MXNET_USE_FUSION', '0'):
orig_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes)
with environment('MXNET_USE_FUSION', '1'):
fused_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes)
fwd_orig = orig_exec.forward(is_train=True, **data)
out_grads = [mx.nd.ones_like(arr) for arr in fwd_orig]
orig_exec.backward(out_grads=out_grads)
Expand Down Expand Up @@ -227,6 +229,7 @@ def check_other_ops():
arr2 = mx.random.uniform(shape=(2,2,2,3))
check_fused_symbol(mx.sym.broadcast_like(a, b, lhs_axes=[0], rhs_axes=[0]), a=arr1, b=arr2)


def check_leakyrelu_ops():
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
Expand Down
18 changes: 10 additions & 8 deletions tests/python/gpu/test_gluon_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import time
import mxnet as mx
import multiprocessing as mp
from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal, rand_ndarray
from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal, rand_ndarray, environment
import mxnet.ndarray as nd
import numpy as np
import math
Expand Down Expand Up @@ -555,9 +555,9 @@ def _test_bulking(test_bulking_func):
time_per_iteration = mp.Manager().Value('d', 0.0)

if not run_in_spawned_process(test_bulking_func,
{'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD': seg_sizes[0],
'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD': seg_sizes[1],
'MXNET_EXEC_BULK_EXEC_TRAIN': seg_sizes[2]},
{'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD': str(seg_sizes[0]),
'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD': str(seg_sizes[1]),
'MXNET_EXEC_BULK_EXEC_TRAIN': str(seg_sizes[2])},
time_per_iteration):
# skip test since the python version can't run it properly. Warning msg was logged.
return
Expand Down Expand Up @@ -631,15 +631,17 @@ def test_gemms_true_fp16():
net.cast('float16')
net.initialize(ctx=ctx)
net.weight.set_data(weights)
ref_results = net(input)

os.environ["MXNET_FC_TRUE_FP16"] = "1"
results_trueFP16 = net(input)
with environment('MXNET_FC_TRUE_FP16', '0'):
ref_results = net(input)

with environment('MXNET_FC_TRUE_FP16', '1'):
results_trueFP16 = net(input)

atol = 1e-2
rtol = 1e-2
assert_almost_equal(ref_results.asnumpy(), results_trueFP16.asnumpy(),
atol=atol, rtol=rtol)
os.environ["MXNET_FC_TRUE_FP16"] = "0"


if __name__ == '__main__':
Expand Down
10 changes: 5 additions & 5 deletions tests/python/gpu/test_kvstore_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import mxnet as mx
import numpy as np
import unittest
from mxnet.test_utils import assert_almost_equal, default_context, EnvManager
from mxnet.test_utils import assert_almost_equal, default_context, environment
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import setup_module, with_seed, teardown
Expand Down Expand Up @@ -97,11 +97,11 @@ def check_rsp_pull(kv, ctxs, sparse_pull, is_same_rowid=False, use_slice=False):
check_rsp_pull(kv, [mx.gpu(i//2) for i in range(4)], sparse_pull, use_slice=True)
check_rsp_pull(kv, [mx.cpu(i) for i in range(4)], sparse_pull, use_slice=True)

envs = ["","1"]
key = "MXNET_KVSTORE_USETREE"
envs = [None, '1']
key = 'MXNET_KVSTORE_USETREE'
for val in envs:
with EnvManager(key, val):
if val is "1":
with environment(key, val):
if val is '1':
sparse_pull = False
else:
sparse_pull = True
Expand Down
Loading

0 comments on commit cd1688f

Please sign in to comment.