From 8beea18e3d9835f90b59d3f9de8f9945ac819423 Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Wed, 6 Mar 2019 21:58:52 -0800 Subject: [PATCH] Bulked op segments to allow Variable nodes (#14200) * Bulked op seg size to ignore Variable nodes, limited by MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_{FWD,BWD}. * Document new env variables. Unify operation with imperative. * Add timing-based tests of symbol and gluon op bulking. * Rename test_in_separate_process -> run_in_spawned_process. * Remove redundant util test_operator_gpu.py:_test_in_separate_process(). * Consolidate references to env vars that set op-bulking policy. * Test for effect of MXNET_EXEC_BULK_EXEC_TRAIN=0. * Fix python2 print() issue. * Trigger CI. * Consolidate similar op bulking routines. * Trigger CI. * Trigger CI. * Add instrumentation to debug failing CI. --- docs/faq/env_var.md | 8 +- include/mxnet/imperative.h | 23 +++++- src/executor/graph_executor.cc | 106 ++++++-------------------- src/executor/graph_executor.h | 6 +- src/imperative/cached_op.cc | 11 ++- src/imperative/cached_op.h | 4 +- tests/python/gpu/test_gluon_gpu.py | 78 +++++++++++++++++++ tests/python/gpu/test_operator_gpu.py | 99 +++++++++++++++++++----- tests/python/unittest/common.py | 50 ++++++++++++ 9 files changed, 274 insertions(+), 111 deletions(-) diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md index f49cb1997691..095c214e66b3 100644 --- a/docs/faq/env_var.md +++ b/docs/faq/env_var.md @@ -115,7 +115,13 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 - If set to `1`, during training MXNet executes the computation graph as several subgraphs in bulk mode. * MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN - Values: Int ```(default=15)``` - - The maximum number of nodes in the subgraph executed in bulk during training(not inference). Setting this to a larger number may reduce the degree of parallelism for multi-GPU training. + - The maximum number of nodes in the subgraph executed in bulk during training (not inference). Setting this to a larger number may reduce the degree of parallelism for multi-GPU training. +* MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD + - Values: Int ```(default=)``` + - The maximum number of nodes in the subgraph executed in bulk during training (not inference) in the forward pass. +* MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD + - Values: Int ```(default=)``` + - The maximum number of nodes in the subgraph executed in bulk during training (not inference) in the backward pass. ## Control the Data Communication diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h index 7ea60df33028..52cedb2fadd9 100644 --- a/include/mxnet/imperative.h +++ b/include/mxnet/imperative.h @@ -129,14 +129,31 @@ class Imperative { bool create_graph); /*! \return AutogradRuntime singleton */ static Imperative* Get(); + /*! \brief Should op execution bulking be employed during inference. */ + static bool PreferBulkExecInference() { + return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_INFERENCE", true); + } + /*! \brief Should op execution bulking be employed during training. */ + static bool PreferBulkExecTrain() { + return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", true); + } + /*! \brief The max number of op nodes in a bulk during forward pass of training. */ + static int BulkExecMaxNodeTrainFwd() { + return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD", + dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15)); + } + /*! \brief The max number of op nodes in a bulk during backward pass of training. */ + static int BulkExecMaxNodeTrainBwd() { + return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD", + dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15)); + } private: friend class NDArray; /*! \brief make constructor protected. */ Imperative() { - if (dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1)) { - backward_bulk_size_ = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15); - } + if (PreferBulkExecTrain()) + backward_bulk_size_ = BulkExecMaxNodeTrainBwd(); } /*! \brief find the input/output ndarrays that are needed for backward */ void GetBackwardDependency( diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 436eae37d785..3d74bfb9a663 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1191,105 +1191,49 @@ void GraphExecutor::InitOpSegs() { cached_seg_opr_.resize(total_num_nodes, p); if (monitor_callback_) return; + // Symbolic bulking is set by the same environment variables as Imperative bulking. // Generate segments based on the graph structure - bool prefer_bulk_exec_inference = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_INFERENCE", true); + bool prefer_bulk_exec_inference = Imperative::PreferBulkExecInference(); // Whether to perform bulk exec for training const profiler::Profiler *prof = profiler::Profiler::Get(); - bool prefer_bulk_exec = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1) - && (!prof || !prof->AggregateEnabled()); + bool prefer_bulk_exec_train = Imperative::PreferBulkExecTrain() + && (!prof || !prof->AggregateEnabled()); bool is_training = num_forward_nodes_ != total_num_nodes; - if (prefer_bulk_exec && is_training) { - this->BulkTrainingOpSegs(total_num_nodes); + if (prefer_bulk_exec_train && is_training) { + // Bulk the forward portion of the graph per the bulk segment max size for forward training + this->BulkOpSegs(0, num_forward_nodes_, Imperative::BulkExecMaxNodeTrainFwd()); + // Bulk the backward portion of the graph per the bulk segment max size for backward training + this->BulkOpSegs(num_forward_nodes_, total_num_nodes, Imperative::BulkExecMaxNodeTrainBwd()); } if (prefer_bulk_exec_inference && !is_training) { - this->BulkInferenceOpSegs(); + // Bulk the entire graph as one bulk segment if possible + this->BulkOpSegs(0, total_num_nodes, total_num_nodes); } } -void GraphExecutor::BulkTrainingOpSegs(size_t total_num_nodes) { - // The maximum number of node in a segment executed in bulk - size_t num_nodes_threshold = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15); - - // create forward segments for training - size_t topo_start = 0; - for (size_t nid = 0; nid < num_forward_nodes_; nid++) { - auto &node = graph_.indexed_graph()[nid].source; - auto &op_node = op_nodes_[nid]; - // check if the segment relies on external input, or exceeds maxinum number of node, - // or requires async ops - if (node->is_variable() || nid - topo_start > num_nodes_threshold || - op_node.exec->exec_type() != ExecType::kSync) { - // create a new segment for the previous nodes if the current one cannot be bulked - cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid); - topo_start = nid + 1; - } - } - // the last segment - if (topo_start != num_forward_nodes_) { - cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, num_forward_nodes_); - } - - // create backward segments for training - // get all gradient variables - std::unordered_set grad_vars; - for (auto &kv : grad_store_) { - grad_vars.insert(kv.second.var()); - } - auto &idx = graph_.indexed_graph(); - topo_start = num_forward_nodes_; - for (size_t nid = num_forward_nodes_; nid < total_num_nodes; nid++) { - auto &op_node = op_nodes_[nid]; - if (op_node.skip_exec_node || op_node.exec == nullptr) { - continue; - } - if (idx[nid].source->is_variable() || nid - topo_start > num_nodes_threshold || - op_node.exec->exec_type() != ExecType::kSync) { - cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid); - topo_start = nid + 1; - } else { - // If it produces output gradient, don't include it in the segment - bool output_gradient = false; - for (auto &out_arr : op_node.exec->out_array) { - if (grad_vars.find(out_arr.var()) != grad_vars.end()) { - output_gradient = true; - } - } - if (output_gradient) { - cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid); - topo_start = nid + 1; - } - } - } - // last segment for backward - if (topo_start < total_num_nodes) { - cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, total_num_nodes); - } -} - -void GraphExecutor::BulkInferenceOpSegs() { - // Attempt to bulk the whole graph for inference. We will only create new segments when - // required for non-kSync operations. - size_t topo_start = 0; - for (size_t nid = 0; nid < num_forward_nodes_; nid++) { +void GraphExecutor::BulkOpSegs(size_t from_node, size_t up_to_node, size_t segment_num_nodes_max) { + size_t topo_start = from_node; + size_t segment_node_count = 0; + for (size_t nid = from_node; nid < up_to_node; nid++) { auto &node = graph_.indexed_graph()[nid].source; auto &op_node = op_nodes_[nid]; - - // Variables do not need to be segmented at inference time. - if (node->is_variable()) continue; - - if (op_node.exec->exec_type() != ExecType::kSync) { - cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid); + // Variables, such as learned weights, are ignored in the segment_node_count + bool ignore_node = node->is_variable() || op_node.skip_exec_node || op_node.exec == nullptr; + if (!ignore_node) + segment_node_count++; + bool can_bulk = ignore_node || op_node.exec->exec_type() == ExecType::kSync; + // check if we need to create the segment based on properties of this node + if (!can_bulk || nid == up_to_node - 1 || segment_node_count >= segment_num_nodes_max) { + // Create a new segment for the previous nodes- include also this node if it's bulkable + cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, can_bulk ? nid + 1 : nid); topo_start = nid + 1; + segment_node_count = 0; } } - // The last segment - if (topo_start != num_forward_nodes_) { - cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, num_forward_nodes_); - } } void GraphExecutor::ExecuteMonInputCallback(size_t nid) { diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index ed49e5bc8bc9..b556a2bd0fe9 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -213,10 +213,8 @@ class GraphExecutor : public Executor { void ExecuteMonInputCallback(size_t nid); // run the monitor callback for output of node `nid` void ExecuteMonOutputCallback(size_t nid); - // peform bulking and segmentation on an inference graph - void BulkInferenceOpSegs(); - // perform bulking and segmentation on a training graph - void BulkTrainingOpSegs(size_t total_num_nodes); + // peform bulking and segmentation on the region [from_node, up_to_node) of a graph + void BulkOpSegs(size_t from_node, size_t up_to_node, size_t segment_num_nodes_max); // indicate whether there is a backward graph for gradients. bool need_grad_; // internal graph diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 61dfb9c9423c..c9215c5c8827 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -619,9 +619,18 @@ void CachedOp::StaticInitExec( SetupOpExec(g, i, state.execs[i], state.arrays, state.array_reqs); } + // Init bulk_size for Inference mode with bulking enabled (= entire forward graph). size_t bulk_size = idx.num_nodes(); if (recording || keep_fwd) { - bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size; + // Training mode + if (!Imperative::PreferBulkExecTrain()) + bulk_size = 0; + else + bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size; + } else { + // Inference mode + if (!Imperative::PreferBulkExecInference()) + bulk_size = 0; } CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size, diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 5a0351ad6fa7..b3192dc8281b 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -53,10 +53,10 @@ struct CachedOpConfig : public dmlc::Parameter { .set_default(2) .describe("Maximum number of operators that can be inlined."); DMLC_DECLARE_FIELD(forward_bulk_size) - .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15)) + .set_default(Imperative::BulkExecMaxNodeTrainFwd()) .describe("Segment size of bulk execution during forward pass."); DMLC_DECLARE_FIELD(backward_bulk_size) - .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15)) + .set_default(Imperative::BulkExecMaxNodeTrainBwd()) .describe("Segment size of bulk execution during backward pass."); DMLC_DECLARE_FIELD(data_indices) .set_default(nnvm::Tuple()) diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 54bfcee47347..88b436a0deb2 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -38,6 +38,7 @@ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied +from common import run_in_spawned_process from test_gluon import * from test_loss import * from test_gluon_rnn import * @@ -408,6 +409,83 @@ def tensor_size(big_tensor_bytes): # Evaluate model net(data_in).asnumpy() +# isolated execution bulking test function to be invoked with different env var settings +def _test_bulking_in_process(seed, time_per_iteration): + # Use flip since it's a simple function with same-sized I/O unlikely to ever be fused. + class Flip(gluon.HybridBlock): + def __init__(self, **kwargs): + super(Flip, self).__init__(**kwargs) + + def hybrid_forward(self, F, x): + return F.flip(x, axis=0) + + def get_net(num_ops): + net = nn.HybridSequential() + with net.name_scope(): + for _ in range(num_ops): + net.add(Flip()) + return net + + data_shape = (10,) + num_ops = 1000 + num_iterations = 20 + + # build model + x = mx.ndarray.zeros(data_shape) + x.attach_grad() + dy = mx.ndarray.ones(data_shape) + net = get_net(num_ops) + net.hybridize(static_alloc=True, static_shape=True) + + # time a number of forward() and backward() executions after some warm-up iterations + warmups = 1 + for i in range(num_iterations+warmups): + with autograd.record(): + if i == warmups: + start = time.time() + y = net(x) + y.backward(dy) + x.grad.wait_to_read() + + time_per_iteration.value = (time.time() - start) / num_iterations + +@with_seed() +def test_bulking(): + # test case format: (max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training) + test_cases = [(0,0,True), (1,1,True), (15,15,False), (15,0,True), (0,15,True), (15,15,True)] + times = {} + times_str = '' + for seg_sizes in test_cases: + # Create shared variable to return measured time from test process + time_per_iteration = mp.Manager().Value('d', 0.0) + if not run_in_spawned_process(_test_bulking_in_process, + {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD' : seg_sizes[0], + 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD' : seg_sizes[1], + 'MXNET_EXEC_BULK_EXEC_TRAIN' : seg_sizes[2]}, + time_per_iteration): + # skip test since the python version can't run it properly. Warning msg was logged. + return + times[seg_sizes] = time_per_iteration.value + times_str += \ + '\n runtime of (fwd,bwd,enable) op seg setting ({},{},{}) =\t{:.1f} msec'.format( + seg_sizes[0], seg_sizes[1], seg_sizes[2], 1000.0 * times[seg_sizes]) + + fastest_non_bulked_time = min(times[(0,0,True)], times[(1,1,True)], times[(15,15,False)]) + slowest_half_bulked_time = max(times[(0,15,True)], times[(15,0,True)]) + fastest_half_bulked_time = min(times[(0,15,True)], times[(15,0,True)]) + fully_bulked_time = times[(15,15,True)] + + print(times_str) + # Non-bulked times[0,0,True], times[1,1,True] and times[15,15,False] should be about the same, + # slower than both half-bulked times[0,15,True] and times[15,0,True] + assert slowest_half_bulked_time < fastest_non_bulked_time, \ + 'A half-bulked exec time is slower than the non-bulked time by {} secs! {}' \ + .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str) + # The fully bulked times[15,15,True] should be faster than both half-bulked runs + assert fully_bulked_time < fastest_half_bulked_time, \ + 'The fully-bulked exec time is slower than a half-bulked time by {} secs! {}' \ + .format(fully_bulked_time - fastest_half_bulked_time, times_str) + if __name__ == '__main__': import nose diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index f3299163f323..7d7c2ed71216 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -33,6 +33,7 @@ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied +from common import run_in_spawned_process from test_operator import * from test_optimizer import * from test_random import * @@ -521,24 +522,6 @@ def test_convolution_options(): check_consistency_NxM([sym, sym_no_cudnn], ctx_list) -# Helper function to run tests in a subprocess to avoid save/restore of os.environ. -# Also avoids issues of cached environment variable lookups in the backend. -def _test_in_separate_process(func, env, *args): - try: - mpctx = mp.get_context('spawn') - except: - print('SKIP: python%s.%s lacks the required process fork-exec support ... ' % - sys.version_info[0:2], file=sys.stderr, end='') - else: - seed = np.random.randint(0,1024*1024*1024) - for (key, value) in env.items(): - os.environ[key] = str(value) - # Prepend seed as first arg - p = mpctx.Process(target=func, args=(seed,)+args) - p.start() - p.join() - assert p.exitcode == 0, "Non-zero exit code %d from %s()." % (p.exitcode, func.__name__) - def _conv_with_num_streams(seed): with random_seed(seed): # Try to expose timing-dependent improper workspace sharing by parallel dgrad and wgrad @@ -576,8 +559,10 @@ def test_convolution_multiple_streams(): for num_streams in [1, 2]: for engine in engines: - _test_in_separate_process(_conv_with_num_streams, + print("Starting engine %s with %d streams." % (engine, num_streams), file=sys.stderr) + run_in_spawned_process(_conv_with_num_streams, {'MXNET_GPU_WORKER_NSTREAMS' : num_streams, 'MXNET_ENGINE_TYPE' : engine}) + print("Finished engine %s with %d streams." % (engine, num_streams), file=sys.stderr) # This test is designed to expose an issue with cudnn v7.1.4 algo find() when invoked with large c. @@ -2127,6 +2112,82 @@ def test_bilinear_sampler_versions(): assert_almost_equal(exe.grad_dict['grid'].asnumpy(), exe_list[ref_idx].grad_dict['grid'].asnumpy(), rtol=1e-3, atol=1e-5) +@with_seed() +def test_bulking(): + # Return the execution time of a model with the specified limits to the bulked op segments + def test_bulking_helper(data_shape, num_ops, num_iterations, + max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training): + orig_environ = os.environ.copy() + try: + # Explore different ways of setting the env vars. + # The framework does not cache the bulked seg size env var lookups during symbolic. + os.environ['MXNET_EXEC_BULK_EXEC_TRAIN'] = str(enable_bulking_in_training) + if max_fwd_segment_size == max_bwd_segment_size: + os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN'] = str(max_fwd_segment_size) + os.environ.pop('MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD', None) + os.environ.pop('MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD', None) + else: + os.environ.pop('MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN', None) + os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD'] = str(max_fwd_segment_size) + os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD'] = str(max_bwd_segment_size) + + ctx = default_context() + # build symbol + X = mx.sym.Variable('X') + sym = mx.sym.flip(X, axis=0) + for _ in range(num_ops-1): + sym = mx.sym.flip(sym, axis=0) + x = mx.ndarray.zeros(data_shape) + dx = mx.ndarray.zeros(data_shape) + dy = mx.ndarray.ones(data_shape) + exe = sym.bind(ctx=ctx, args=[x], args_grad = {'X':dx}) + + # time a number of forward() and backward() executions after some warm-up iterations + warmups = 1 + for i in range(num_iterations+warmups): + if i == warmups: + start = time.time() + exe.forward(is_train=True) + exe.backward(dy) + dx.wait_to_read() + time_per_iteration = (time.time() - start) / num_iterations + finally: + os.environ.clear() + os.environ.update(orig_environ) + return time_per_iteration + + data_shape = (10,) + num_ops = 1000 + num_iterations = 20 + + # test case format: (max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training) + test_cases = [(0,0,True), (1,1,True), (15,15,False), (15,0,True), (0,15,True), (15,15,True)] + times = {} + times_str = '' + for seg_sizes in test_cases: + times[seg_sizes] = test_bulking_helper(data_shape, num_ops, num_iterations, + seg_sizes[0], seg_sizes[1], seg_sizes[2]) + times_str +=\ + '\n runtime of (fwd,bwd,enable) op seg setting ({},{},{}) =\t{:.1f} msec'.format( + seg_sizes[0], seg_sizes[1], seg_sizes[2], 1000.0 * times[seg_sizes]) + + fastest_non_bulked_time = min(times[(0,0,True)], times[(1,1,True)], times[(15,15,False)]) + slowest_half_bulked_time = max(times[(0,15,True)], times[(15,0,True)]) + fastest_half_bulked_time = min(times[(0,15,True)], times[(15,0,True)]) + fully_bulked_time = times[(15,15,True)] + + print(times_str) + # Non-bulked times[0,0,True], times[1,1,True] and times[15,15,False] should be about the same, + # slower than both half-bulked times[0,15,True] and times[15,0,True] + assert slowest_half_bulked_time < fastest_non_bulked_time,\ + 'A half-bulked exec time is slower than the non-bulked time by {} secs! {}'\ + .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str) + # The fully bulked times[15,15,True] should be faster than both half-bulked runs + assert fully_bulked_time < fastest_half_bulked_time,\ + 'The fully-bulked exec time is slower than a half-bulked time by {} secs! {}'\ + .format(fully_bulked_time - fastest_half_bulked_time, times_str) + + def test_context_num_gpus(): # Test that num_gpus reports at least one GPU, as the test is run on a GPU host. assert mx.context.num_gpus() > 0 diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py index abfba73ab727..7cd637da3d4f 100644 --- a/tests/python/unittest/common.py +++ b/tests/python/unittest/common.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. +from __future__ import print_function import sys, os, logging +import multiprocessing as mp import mxnet as mx import numpy as np import random @@ -39,6 +41,7 @@ def assertRaises(expected_exception, func, *args, **kwargs): # Did not raise exception assert False, "%s did not raise %s" % (func.__name__, expected_exception.__name__) + def default_logger(): """A logger used to output seed information to nosetests logs.""" logger = logging.getLogger(__name__) @@ -51,6 +54,7 @@ def default_logger(): logger.setLevel(logging.INFO) return logger + @contextmanager def random_seed(seed=None): """ @@ -181,6 +185,7 @@ def test_new(*args, **kwargs): return test_new return test_helper + def setup_module(): """ A function with a 'magic name' executed automatically before each nosetests module @@ -265,3 +270,48 @@ def teardown(): It waits for all operations in one file to finish before carrying on the next. """ mx.nd.waitall() + + +def run_in_spawned_process(func, env, *args): + """ + Helper function to run a test in its own process. + + Avoids issues with Singleton- or otherwise-cached environment variable lookups in the backend. + Adds a seed as first arg to propagate determinism. + + Parameters + ---------- + + func : function to run in a spawned process. + env : dict of additional environment values to set temporarily in the environment before exec. + args : args to pass to the function. + + Returns + ------- + Whether the python version supports running the function as a spawned process. + + This routine calculates a random seed and passes it into the test as a first argument. If the + test uses random values, it should include an outer 'with random_seed(seed):'. If the + test needs to return values to the caller, consider use of shared variable arguments. + """ + try: + mpctx = mp.get_context('spawn') + except: + print('SKIP: python%s.%s lacks the required process fork-exec support ... ' % + sys.version_info[0:2], file=sys.stderr, end='') + return False + else: + seed = np.random.randint(0,1024*1024*1024) + orig_environ = os.environ.copy() + try: + for (key, value) in env.items(): + os.environ[key] = str(value) + # Prepend seed as first arg + p = mpctx.Process(target=func, args=(seed,)+args) + p.start() + p.join() + assert p.exitcode == 0, "Non-zero exit code %d from %s()." % (p.exitcode, func.__name__) + finally: + os.environ.clear() + os.environ.update(orig_environ) + return True \ No newline at end of file