diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 6089edae5a56..e839c4b722be 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -1020,3 +1020,26 @@ def set_env_var(key, val, default_val=""):
     prev_val = os.environ.get(key, default_val)
     os.environ[key] = val
     return prev_val
+
+def same_array(array1, array2):
+    """Check whether two NDArrays share the same memory block.
+
+    Parameters
+    ----------
+
+    array1 : NDArray
+        First NDArray to be checked
+    array2 : NDArray
+        Second NDArray to be checked
+
+    Returns
+    -------
+    bool
+        Whether the two NDArrays share the same memory
+    """
+    array1[:] += 1
+    if not same(array1.asnumpy(), array2.asnumpy()):
+        array1[:] -= 1
+        return False
+    array1[:] -= 1
+    return same(array1.asnumpy(), array2.asnumpy())
\ No newline at end of file
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 5508a37c9567..4dde5a60b8e3 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -2,6 +2,7 @@
 import mxnet.ndarray as nd
 import numpy as np
 from functools import reduce
+from mxnet.module.executor_group import DataParallelExecutorGroup
 
 def test_module_dtype():
     dtype = np.float16
@@ -254,6 +255,88 @@ def mean_abs(x):
             break
     assert(mon_result_counts == [2, 2, 1, 6, 6, 4])
 
+def test_executor_group():
+    def test_create_exec_group(exec_grp_shared, exec_grp_created,
+                               shared_arg_names, extra_input=(), extra_arg=()):
+        # Test shared data arrays
+        for i in range(len(exec_grp_shared.execs)):
+            for data_name, array in exec_grp_shared.shared_data_arrays[i].items():
+                assert data_name in exec_grp_created.shared_data_arrays[i], \
+                    "Shared input data '%s' is not in " \
+                    "shared_data_arrays of created executor group." % (data_name)
+                assert mx.test_utils.same_array(array, exec_grp_created.shared_data_arrays[i][data_name]), \
+                    "Shared input data '%s' does not share memory." % (data_name)
+            for input_name in extra_input:
+                assert input_name in exec_grp_created.execs[i].arg_dict, \
+                    "Extra input data '%s' is not in arg_dict of created executor group." % (input_name)
+
+        # Test shared argument arrays and gradient arrays
+        for i in range(len(exec_grp_shared.execs)):
+            exec1 = exec_grp_shared.execs[i]
+            exec2 = exec_grp_created.execs[i]
+            for arg_name in shared_arg_names:
+                assert arg_name in exec2.arg_dict, \
+                    "Shared argument '%s' is not in arg_dict of created executor group." % (arg_name)
+                assert mx.test_utils.same_array(exec1.arg_dict[arg_name], exec2.arg_dict[arg_name]), \
+                    "Shared argument '%s' does not share memory." % (arg_name)
+            for arg_name in extra_arg:
+                assert arg_name in exec2.arg_dict, \
+                    "Extra argument '%s' is not in arg_dict of created executor group." % (arg_name)
+            for arg_name, grad in exec_grp_shared.grad_req.items():
+                assert grad == exec_grp_created.grad_req[arg_name], \
+                    "Gradient requirements for shared argument '%s' are inconsistent. " \
+                    "Shared executor group requires '%s' while created executor group requires '%s'." \
+                    % (arg_name, grad, exec_grp_created.grad_req[arg_name])
+            for arg_name in shared_arg_names:
+                assert arg_name in exec2.grad_dict, \
+                    "Shared argument gradient '%s' is not in " \
+                    "grad_dict of created executor group." % (arg_name)
+                assert mx.test_utils.same_array(exec1.grad_dict[arg_name], exec2.grad_dict[arg_name]), \
+                    "Shared argument gradient '%s' does not share memory." % (arg_name)
+
+    contexts = [mx.cpu(0), mx.cpu(1)]
+    workload = [1] * len(contexts)
+    batch_size = 16
+    num_hidden = 4
+    data_shapes1 = [('data1', (batch_size, 10))]
+    data_shapes2 = [('data1', (batch_size, 10)), ('data2', (batch_size, 10))]
+    label_shapes = [('softmax_label', (batch_size,))]
+
+    data1 = mx.sym.Variable('data1')
+    data2 = mx.sym.Variable('data2')
+    fc1 = mx.sym.FullyConnected(data=data1, name='fc1', num_hidden=num_hidden)
+    mlp1 = mx.sym.SoftmaxOutput(data=fc1, name='softmax')
+    fc1 = mx.sym.FullyConnected(data=data1 + data2, name='fc1', num_hidden=num_hidden)
+    fc2 = mx.sym.FullyConnected(data=fc1, name='fc2', num_hidden=num_hidden)
+    mlp2 = mx.sym.SoftmaxOutput(data=fc2, name='softmax')
+
+    arg_names = mlp1.list_arguments()
+    input_names = [name[0] for name in data_shapes1] + [name[0] for name in label_shapes]
+    shared_arg_names = [name for name in arg_names if name not in input_names]
+
+    exec_group1 = DataParallelExecutorGroup(symbol=mlp1, contexts=contexts,
+                                            workload=workload, data_shapes=data_shapes1,
+                                            label_shapes=label_shapes, param_names=shared_arg_names,
+                                            for_training=True, inputs_need_grad=False)
+
+    # Test two executor groups with the same symbol sharing memory
+    exec_group2 = DataParallelExecutorGroup(symbol=mlp1, contexts=contexts,
+                                            workload=workload, data_shapes=data_shapes1,
+                                            label_shapes=label_shapes, param_names=shared_arg_names,
+                                            for_training=True, inputs_need_grad=False,
+                                            shared_group=exec_group1)
+    test_create_exec_group(exec_group1, exec_group2, shared_arg_names)
+
+    # Test two executor groups with different symbols sharing memory
+    exec_group3 = DataParallelExecutorGroup(symbol=mlp2, contexts=contexts,
+                                            workload=workload, data_shapes=data_shapes2,
+                                            label_shapes=label_shapes, param_names=shared_arg_names,
+                                            for_training=True, inputs_need_grad=False,
+                                            shared_group=exec_group1)
+    extra_input = ['data2']
+    extra_arg = ['fc2_weight', 'fc2_bias']
+    test_create_exec_group(exec_group1, exec_group3, shared_arg_names, extra_input, extra_arg)
+
 if __name__ == '__main__':
     test_module_dtype()
     test_module_input_grads()
@@ -263,3 +346,4 @@ def mean_abs(x):
     test_module_layout()
     test_module_switch_bucket()
     test_monitor()
+    test_executor_group()
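
Note for reviewers: same_array detects sharing by probing, not by comparing pointers. It adds 1 in place to array1, checks whether the change is visible through array2, then subtracts 1 to restore the original contents. A minimal usage sketch (the names a, b, and c below are hypothetical and not part of this patch):

    import mxnet as mx
    from mxnet.test_utils import same_array

    a = mx.nd.zeros((2, 3))
    b = a          # alias: b refers to the same underlying memory block
    c = a.copy()   # deep copy: c owns a separate buffer

    assert same_array(a, b)          # shared memory: the +1 probe on a is visible through b
    assert not same_array(a, c)      # separate memory: the probe never reaches c
    assert (a.asnumpy() == 0).all()  # the helper restores the original values

Because the helper temporarily mutates array1, it should only be used in tests where no other reader observes the array concurrently.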