Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

environment variable handling in unittests #18424

Merged
merged 19 commits into from
Jul 23, 2020
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions include/mxnet/c_api_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,22 @@ MXNET_DLL int MXRemoveSubgraphPropertyOpNames(const char* prop_name);
MXNET_DLL int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name);


/*!
* \brief Get the value of an environment variable as seen by the backend.
* \param name The name of the environment variable
* \param value The returned value of the environment variable
*/
MXNET_DLL int MXGetEnv(const char* name,
const char** value);

/*!
* \brief Set the value of an environment variable from the backend.
* \param name The name of the environment variable
* \param value The desired value to set the environment variable `name`
*/
MXNET_DLL int MXSetEnv(const char* name,
const char* value);

#ifdef __cplusplus
}
#endif // __cplusplus
Expand Down
111 changes: 74 additions & 37 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import numbers
import sys
import os
import platform
import errno
import logging
import bz2
Expand All @@ -49,7 +50,7 @@
from .ndarray import array
from .symbol import Symbol
from .symbol.numpy import _Symbol as np_symbol
from .util import use_np, use_np_default_dtype # pylint: disable=unused-import
from .util import use_np, use_np_default_dtype, getenv, setenv # pylint: disable=unused-import
from .runtime import Features
from .numpy_extension import get_cuda_compute_capability

Expand Down Expand Up @@ -1900,27 +1901,6 @@ def get_bz2_data(data_dir, data_name, url, data_origin_name):
bz_file.close()
os.remove(data_origin_name)

def set_env_var(key, val, default_val=""):
"""Set environment variable

Parameters
----------

key : str
Env var to set
val : str
New value assigned to the env var
default_val : str, optional
Default value returned if the env var doesn't exist

Returns
-------
str
The value of env var before it is set to the new value
"""
prev_val = os.environ.get(key, default_val)
os.environ[key] = val
return prev_val

def same_array(array1, array2):
"""Check whether two NDArrays sharing the same memory block
Expand All @@ -1945,9 +1925,11 @@ def same_array(array1, array2):
array1[:] -= 1
return same(array1.asnumpy(), array2.asnumpy())


@contextmanager
def discard_stderr():
"""Discards error output of a routine if invoked as:
"""
Discards error output of a routine if invoked as:

with discard_stderr():
...
Expand Down Expand Up @@ -2380,22 +2362,77 @@ def same_symbol_structure(sym1, sym2):
return True


class EnvManager(object):
"""Environment variable setter and unsetter via with idiom"""
def __init__(self, key, val):
self._key = key
self._next_val = val
self._prev_val = None
@contextmanager
def environment(*args):
"""
Environment variable setter and unsetter via `with` idiom.

def __enter__(self):
self._prev_val = os.environ.get(self._key)
os.environ[self._key] = self._next_val
Takes a specification of env var names and desired values and adds those
settings to the environment in advance of running the body of the `with`
statement. The original environment state is restored afterwards, even
if exceptions are raised in the `with` body.

def __exit__(self, ptype, value, trace):
if self._prev_val:
os.environ[self._key] = self._prev_val
else:
del os.environ[self._key]
Parameters
----------
args:
if 2 args are passed:
name, desired_value strings of the single env var to update, or
if 1 arg is passed:
a dict of name:desired_value for env var's to update

"""

# On Linux, env var changes made through python's os.environ are seen
# by the backend. On Windows though, the C runtime gets a snapshot
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 That difference can be nasty.

# of the environment that cannot be altered by os.environ. Here we
# check, using a wrapped version of the backend's getenv(), that
# the desired env var value is seen by the backend, and otherwise use
# a wrapped setenv() to establish that value in the backend.

# Also on Windows, a set env var can never have the value '', since
# the command 'set FOO= ' is used to unset the variable. Perhaps
# as a result, the wrapped dmlc::GetEnv() routine returns the same
# value for unset variables and those set to ''. As a result, we
# ignore discrepancy.
def validate_backend_setting(name, value, can_use_setenv=True):
backend_value = getenv(name)
if value == backend_value or \
value == '' and backend_value is None and platform.system() == 'Windows':
return
if not can_use_setenv:
raise RuntimeError('Could not set env var {}={} within C Runtime'.format(name, value))
setenv(name, value)
validate_backend_setting(name, value, can_use_setenv=False)

# Core routine to alter environment from a dict of env_var_name, env_var_value pairs
def set_environ(env_var_dict):
for env_var_name, env_var_value in env_var_dict.items():
if env_var_value is None:
os.environ.pop(env_var_name, None)
else:
os.environ[env_var_name] = env_var_value
validate_backend_setting(env_var_name, env_var_value)

# Create env_var name:value dict from the two calling methods of this routine
if len(args) == 1 and isinstance(args[0], dict):
env_vars = args[0]
else:
assert len(args) == 2, 'Expecting one dict arg or two args: env var name and value'
env_vars = {args[0]: args[1]}

# Take a snapshot of the existing environment variable state
# for those variables to be changed. get() return None for unset keys.
snapshot = {x: os.environ.get(x) for x in env_vars.keys()}

# Alter the environment per the env_vars dict
set_environ(env_vars)

# Now run the wrapped code
try:
yield
finally:
# reinstate original env_var state per the snapshot taken earlier
set_environ(snapshot)


def collapse_sum_like(a, shape):
Expand Down
35 changes: 34 additions & 1 deletion python/mxnet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import inspect
import threading

from .base import _LIB, check_call
from .base import _LIB, check_call, c_str, py_str


_np_ufunc_default_kwargs = {
Expand Down Expand Up @@ -913,6 +913,7 @@ def get_cuda_compute_capability(ctx):
.format(ret, error_str.value.decode()))
return cc_major.value * 10 + cc_minor.value


def default_array(source_array, ctx=None, dtype=None):
"""Creates an array from any object exposing the default(nd or np) array interface.

Expand Down Expand Up @@ -1144,3 +1145,35 @@ def set_np_default_dtype(is_np_default_dtype=True): # pylint: disable=redefined
prev = ctypes.c_bool()
check_call(_LIB.MXSetIsNumpyDefaultDtype(ctypes.c_bool(is_np_default_dtype), ctypes.byref(prev)))
return prev.value


def getenv(name):
"""Get the setting of an environment variable from the C Runtime.

Parameters
----------
name : string type
The environment variable name

Returns
-------
value : string
The value of the environment variable, or None if not set
"""
ret = ctypes.c_char_p()
check_call(_LIB.MXGetEnv(c_str(name), ctypes.byref(ret)))
return None if ret.value is None else py_str(ret.value)


def setenv(name, value):
"""Set an environment variable in the C Runtime.

Parameters
----------
name : string type
The environment variable name
value : string type
The desired value to set the environment value to
"""
passed_value = None if value is None else c_str(value)
check_call(_LIB.MXSetEnv(c_str(name), passed_value))
22 changes: 22 additions & 0 deletions src/c_api/c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,25 @@ int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name) {
}
API_END();
}

int MXGetEnv(const char* name,
const char** value) {
API_BEGIN();
*value = getenv(name);
API_END();
}

int MXSetEnv(const char* name,
const char* value) {
API_BEGIN();
#ifdef _WIN32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be a good idea to put a mutex? setenv is not thread safe.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sequence we follow of checkpoint_env_vars, change_env_vars, run_test, reinstate_env_vars, is fundamentally not thread-safe. I'm not making those guarantees, so no mutex is required I feel.

auto value_arg = (value == nullptr) ? "" : value;
_putenv_s(name, value_arg);
#else
if (value == nullptr)
unsetenv(name);
else
setenv(name, value, 1);
#endif
API_END();
}
16 changes: 7 additions & 9 deletions tests/python/gpu/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
import pytest
import os
import logging

from mxnet.test_utils import EnvManager
from mxnet.test_utils import environment

shapes = [(10), (100), (1000), (10000), (100000), (2,2), (2,3,4,5,6,7,8)]
keys = [1,2,3,4,5,6,7]
Expand Down Expand Up @@ -51,16 +50,15 @@ def check_dense_pushpull(kv_type):
for x in range(n_gpus):
assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0)

kvstore_tree_array_bound = 'MXNET_KVSTORE_TREE_ARRAY_BOUND'
kvstore_usetree_values = ['','1']
kvstore_usetree = 'MXNET_KVSTORE_USETREE'
for _ in range(2):
kvstore_tree_array_bound_values = [None, '1']
kvstore_usetree_values = [None, '1']
for y in kvstore_tree_array_bound_values:
for x in kvstore_usetree_values:
with EnvManager(kvstore_usetree, x):
with environment({'MXNET_KVSTORE_USETREE': x,
'MXNET_KVSTORE_TREE_ARRAY_BOUND': y}):
check_dense_pushpull('local')
check_dense_pushpull('device')
os.environ[kvstore_tree_array_bound] = '1'
del os.environ[kvstore_tree_array_bound]


if __name__ == '__main__':
test_device_pushpull()
17 changes: 12 additions & 5 deletions tests/python/gpu/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@
# specific language governing permissions and limitations
# under the License.

import sys
import os
import random
import itertools
import mxnet as mx
import numpy as np
from mxnet.test_utils import *

curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import with_seed
from common import setup_module, teardown, with_seed, with_environment

def check_fused_symbol(sym, **kwargs):
inputs = sym.list_inputs()
Expand All @@ -43,10 +45,10 @@ def check_fused_symbol(sym, **kwargs):
data = {inp : kwargs[inp].astype(dtype) for inp in inputs}
for grad_req in ['write', 'add']:
type_dict = {inp : dtype for inp in inputs}
os.environ["MXNET_USE_FUSION"] = "0"
orig_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes)
os.environ["MXNET_USE_FUSION"] = "1"
fused_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes)
with environment('MXNET_USE_FUSION', '0'):
orig_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes)
with environment('MXNET_USE_FUSION', '1'):
fused_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes)
fwd_orig = orig_exec.forward(is_train=True, **data)
out_grads = [mx.nd.ones_like(arr) for arr in fwd_orig]
orig_exec.backward(out_grads=out_grads)
Expand Down Expand Up @@ -230,6 +232,7 @@ def check_other_ops():
arr2 = mx.random.uniform(shape=(2,2,2,3))
check_fused_symbol(mx.sym.broadcast_like(a, b, lhs_axes=[0], rhs_axes=[0]), a=arr1, b=arr2)


def check_leakyrelu_ops():
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
Expand Down Expand Up @@ -337,3 +340,7 @@ def test_fusion_reshape_executor():
out = f.forward(is_train=False, data1=data, data2=data)
assert out[0].sum().asscalar() == 150


if __name__ == '__main__':
import nose
nose.runmodule()
12 changes: 6 additions & 6 deletions tests/python/gpu/test_kvstore_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import os
import mxnet as mx
import numpy as np
import pytest
from mxnet.test_utils import assert_almost_equal, default_context, EnvManager
import unittest
DickJC123 marked this conversation as resolved.
Show resolved Hide resolved
from mxnet.test_utils import assert_almost_equal, default_context, environment
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import setup_module, with_seed, teardown_module
Expand Down Expand Up @@ -99,11 +99,11 @@ def check_rsp_pull(kv, ctxs, sparse_pull, is_same_rowid=False, use_slice=False):
check_rsp_pull(kv, [mx.gpu(i//2) for i in range(4)], sparse_pull, use_slice=True)
check_rsp_pull(kv, [mx.cpu(i) for i in range(4)], sparse_pull, use_slice=True)

envs = ["","1"]
key = "MXNET_KVSTORE_USETREE"
envs = [None, '1']
key = 'MXNET_KVSTORE_USETREE'
for val in envs:
with EnvManager(key, val):
if val is "1":
with environment(key, val):
if val is '1':
sparse_pull = False
else:
sparse_pull = True
Expand Down
Loading