Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Refactor kvstore test #13140

Merged
merged 4 commits into from
Nov 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1998,3 +1998,20 @@ def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='defa
if compare_states:
compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)

class EnvManager(object):
"""Environment variable setter and unsetter via with idiom"""
def __init__(self, key, val):
self._key = key
self._next_val = val
self._prev_val = None

def __enter__(self):
self._prev_val = os.environ.get(self._key)
os.environ[self._key] = self._next_val

def __exit__(self, ptype, value, trace):
if self._prev_val:
os.environ[self._key] = self._prev_val
else:
del os.environ[self._key]
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ object Image {
def imDecode(buf: Array[Byte], flag: Int,
to_rgb: Boolean,
out: Option[NDArray]): NDArray = {
val nd = NDArray.array(buf.map(_.toFloat), Shape(buf.length))
val nd = NDArray.array(buf.map( x => (x & 0xFF).toFloat), Shape(buf.length))
val byteND = NDArray.api.cast(nd, "uint8")
val args : ListBuffer[Any] = ListBuffer()
val map : mutable.Map[String, Any] = mutable.Map()
Expand Down
46 changes: 15 additions & 31 deletions tests/python/gpu/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,23 @@
import numpy as np
import unittest
import os
import logging

from mxnet.test_utils import EnvManager

shapes = [(10), (100), (1000), (10000), (100000), (2,2), (2,3,4,5,6,7,8)]
keys = [1,2,3,4,5,6,7]
num_gpus = len(mx.test_utils.list_gpus())
num_gpus = mx.context.num_gpus()


if num_gpus > 8 :
print("The machine has {} gpus. We will run the test on 8 gpus.".format(num_gpus))
print("There is a limit for all PCI-E hardware on creating number of P2P peers. The limit is 8.")
logging.warn("The machine has {} gpus. We will run the test on 8 gpus.".format(num_gpus))
logging.warn("There is a limit for all PCI-E hardware on creating number of P2P peers. The limit is 8.")
num_gpus = 8;

gpus = range(1, 1+num_gpus)

class EnvManager:
def __init__(self, key, val):
self._key = key
self._next_val = val
self._prev_val = None

def __enter__(self):
try:
self._prev_val = os.environ[self._key]
except KeyError:
self._prev_val = ''
os.environ[self._key] = self._next_val

def __exit__(self, ptype, value, trace):
os.environ[self._key] = self._prev_val

@unittest.skipIf(mx.context.num_gpus() < 1, "test_device_pushpull needs at least 1 GPU")
def test_device_pushpull():
def check_dense_pushpull(kv_type):
for shape, key in zip(shapes, keys):
Expand All @@ -63,20 +51,16 @@ def check_dense_pushpull(kv_type):
for x in range(n_gpus):
assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0)

envs1 = '1'
key1 = 'MXNET_KVSTORE_TREE_ARRAY_BOUND'
envs2 = ['','1']
key2 = 'MXNET_KVSTORE_USETREE'
for i in range(2):
for val2 in envs2:
with EnvManager(key2, val2):
kvstore_tree_array_bound = 'MXNET_KVSTORE_TREE_ARRAY_BOUND'
kvstore_usetree_values = ['','1']
kvstore_usetree = 'MXNET_KVSTORE_USETREE'
for _ in range(2):
for x in kvstore_usetree_values:
with EnvManager(kvstore_usetree, x):
check_dense_pushpull('local')
check_dense_pushpull('device')

os.environ[key1] = envs1
os.environ[key1] = ''

print ("Passed")
os.environ[kvstore_tree_array_bound] = '1'
del os.environ[kvstore_tree_array_bound]

if __name__ == '__main__':
test_device_pushpull()
18 changes: 1 addition & 17 deletions tests/python/gpu/test_kvstore_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import mxnet as mx
import numpy as np
import unittest
from mxnet.test_utils import assert_almost_equal, default_context
from mxnet.test_utils import assert_almost_equal, default_context, EnvManager
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import setup_module, with_seed, teardown
Expand All @@ -30,22 +30,6 @@
keys = [5, 7, 11]
str_keys = ['b', 'c', 'd']

class EnvManager:
def __init__(self, key, val):
self._key = key
self._next_val = val
self._prev_val = None

def __enter__(self):
try:
self._prev_val = os.environ[self._key]
except KeyError:
self._prev_val = ''
os.environ[self._key] = self._next_val

def __exit__(self, ptype, value, trace):
os.environ[self._key] = self._prev_val

def init_kv_with_str(stype='default', kv_type='local'):
"""init kv """
kv = mx.kv.create(kv_type)
Expand Down