From de254591b33196cc50982aa51e8cd3343d6d0f6c Mon Sep 17 00:00:00 2001 From: luke metz Date: Tue, 15 Dec 2015 11:04:39 -0500 Subject: [PATCH] refactor model.py _train_multi_device --- python/mxnet/executor.py | 213 ++++++++++++++++++++++++++++++- python/mxnet/model.py | 266 ++++++++++++--------------------------- 2 files changed, 294 insertions(+), 185 deletions(-) diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index 05a269d1e2db..3e357fb18974 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -1,5 +1,5 @@ # coding: utf-8 -# pylint: disable=invalid-name, protected-access, too-many-locals +# pylint: disable=invalid-name, protected-access, too-many-locals, too-many-arguments """Symbolic Executor component of MXNet.""" from __future__ import absolute_import @@ -8,6 +8,9 @@ from .base import mx_uint, NDArrayHandle, ExecutorHandle from .base import check_call, c_array, py_str from .ndarray import NDArray +from . import ndarray as nd +from .context import cpu +import logging class Executor(object): """ Executor is the actual executing object of MXNet.""" @@ -216,3 +219,211 @@ def debug_str(self): check_call(_LIB.MXExecutorPrint( self.handle, ctypes.byref(debug_str))) return py_str(debug_str.value) + +def _split_input_slice(batch_size, work_load_list): + """Get input slice from the input shape. + Parameters + ---------- + batch_size : int + The number of samples in a mini-batch. + work_load_list : list of float or int, optional + The list of work load for different devices, + in the same order as ctx + Returns + ------- + slices : list of slice + The split slices to get a specific slice. + Raises + ------ + ValueError + If there are two many splits such that some slice can be empty. + """ + total_work_load = sum(work_load_list) + batch_num_list = [round(work_load * batch_size / total_work_load) + for work_load in work_load_list] + batch_num_sum = sum(batch_num_list) + if batch_num_sum < batch_size: + batch_num_list[-1] += batch_size - batch_num_sum + slices = [] + end = 0 + for batch_num in batch_num_list: + begin = int(min((end, batch_size))) + end = int(min((begin + batch_num, batch_size))) + if begin >= end: + raise ValueError('Too many slices such that some splits are empty') + slices.append(slice(begin, end)) + return slices + +def _check_arguments(symbol): + """Check the argument names of symbol. + This function checks the duplication of arguments in Symbol. + The check is done for feedforward net for now. + Parameters + ---------- + symbol : Symbol + The network configuration + """ + arg_set = set() + arg_names = symbol.list_arguments() + for name in arg_names: + if name in arg_set: + raise ValueError(('Find duplicated argument name \"%s\", ' + + 'please make the weight name non-duplicated(using name arguments), ' + + 'arguments are %s') % (name, str(arg_names))) + arg_set.add(name) + + aux_set = set() + aux_names = symbol.list_auxiliary_states() + for name in aux_names: + if name in aux_set: + raise ValueError( + ('Find duplicated auxiliary param name \"%s\", ' + + 'please make the weight name non-duplicated(using name arguments), ' + + 'arguments are %s, auxiliary params are %s' + ) % (name, str(arg_names), str(aux_names))) + aux_set.add(name) + +def _load_general(data, targets): + """Load a list of arrays into a list of arrays specified by slices""" + for d_src, d_targets in zip(data, targets): + if isinstance(d_targets, nd.NDArray): + d_src.copyto(d_targets) + else: + for slice_idx, d_dst in d_targets: + d_src[slice_idx].copyto(d_dst) + +def _load_data(batch, targets): + """Load data into sliced arrays""" + _load_general(batch.data, targets) + +def _load_label(batch, targets): + """Load label into sliced arrays""" + _load_general(batch.label, targets) + +class DataParallelExecutorManager(object): + """ Helper class to manage multiple executors for data parallelism. + Parameters + ---------- + symbol : Symbol + output symbol + ctx : list of Context + devices to run on + param_names: list of str + Name of all trainable parameters of the network. + arg_names: list of str + Name of all arguments of the network. + aux_names: list of str + Name of all auxiliary states of the network. + train_data : DataIter + Training data iterator. + work_load_list : list of float or int, optional + The list of work load for different devices, + in the same order as ctx + logger : logging logger + When not specified, default logger will be used. + """ + def __init__(self, symbol, ctx, train_data, + param_names, arg_names, aux_names, + work_load_list=None, logger=None): + if logger is None: + logger = logging + # preparation + num_device = len(ctx) + logger.info('Start training with %s', str(ctx)) + + # make sure the architecture is valid + _check_arguments(symbol) + + if work_load_list is None: + work_load_list = [1] * num_device + assert isinstance(work_load_list, list) and len(work_load_list) == num_device, \ + "Invalid settings for work load. " + + slices = _split_input_slice(train_data.batch_size, work_load_list) + self.slices = slices + + self.train_execs = [] + for i in range(len(ctx)): + data_shapes = {k: tuple([slices[i].stop-slices[i].start] + list(v[1:])) + for k, v in train_data.provide_data} + train_exec = symbol.simple_bind(ctx[i], 'write', **data_shapes) + self.train_execs.append(train_exec) + + # data structure + self.data_names = [x[0] for x in train_data.provide_data] + self.label_names = [x[0] for x in train_data.provide_label] + self.aux_names = aux_names + + self.data_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)] + for name in self.data_names] + self.label_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)] + for name in self.label_names] + + self.param_idx = [i for i in range(len(arg_names)) if arg_names[i] in param_names] + self.param_names = [arg_names[i] for i in self.param_idx] + self.param_arrays = [[e.arg_arrays[i] for e in self.train_execs] + for i in self.param_idx] + self.grad_arrays = [[e.grad_arrays[i] for e in self.train_execs] + for i in self.param_idx] + + self.aux_arrays = [[e.aux_arrays[i] for e in self.train_execs] + for i in range(len(aux_names))] + + batch_size = train_data.batch_size + + output_shapes = [tuple([batch_size]+list(x.shape[1:])) for x in self.train_execs[0].outputs] + self.cpu_output_arrays = [nd.zeros(s) for s in output_shapes] + + def install_monitor(self, monitor): + """ Install monitor on all executors """ + for train_exec in self.train_execs: + monitor.install(train_exec) + + def set_params(self, arg_params, aux_params): + """ set parameter and aux values + Parameters + ---------- + arg_params : list of NDArray + source parameter arrays + aux_params : list of NDArray + source aux arrays + """ + + for texec in self.train_execs: + texec.copy_params_from(arg_params, aux_params) + + def copy_to(self, arg_params, aux_params): + """ Copy data from each executor to `arg_params` and `aux_params` + Parameters + ---------- + arg_params : list of NDArray + target parameter arrays + aux_params : list of NDArray + target aux arrays + Notes + ----- + - This function will inplace update the NDArrays in arg_params and aux_params. + """ + for name, block in zip(self.param_names, self.param_arrays): + weight = sum(w.copyto(cpu()) for w in block) / len(block) + weight.copyto(arg_params[name]) + for name, block in zip(self.aux_names, self.aux_arrays): + weight = sum(w.copyto(cpu()) for w in block) / len(block) + weight.copyto(aux_params[name]) + + def load_data_batch(self, data_batch): + """ load data and labels into arrays """ + _load_data(data_batch, self.data_arrays) + _load_label(data_batch, self.label_arrays) + + def forward(self, is_train=False): + """ Perform a forward pass on each executor """ + for texec, islice in zip(self.train_execs, self.slices): + texec.forward(is_train=is_train) + for cpu_out, dev_out in zip(self.cpu_output_arrays, texec.outputs): + dev_out.copyto(cpu_out[islice]) + + def backward(self): + """ Perform a backward pass on each executor """ + for texec in self.train_execs: + texec.backward() diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 70774321fdcc..5d9288406f7a 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -16,6 +16,7 @@ from .initializer import Uniform from collections import namedtuple from .optimizer import get_updater +from .executor import DataParallelExecutorManager, _check_arguments, _load_data BASE_ESTIMATOR = object @@ -31,86 +32,6 @@ 'nbatch', 'eval_metric']) -def _load_general(data, targets): - """Load a list of arrays into a list of arrays specified by slices""" - for d_src, d_targets in zip(data, targets): - if isinstance(d_targets, nd.NDArray): - d_src.copyto(d_targets) - else: - for slice_idx, d_dst in d_targets: - d_src[slice_idx].copyto(d_dst) -def _load_data(batch, targets): - """Load data into sliced arrays""" - _load_general(batch.data, targets) -def _load_label(batch, targets): - """Load label into sliced arrays""" - _load_general(batch.label, targets) - - -def _check_arguments(symbol): - """Check the argument names of symbol. - This function checks the duplication of arguments in Symbol. - The check is done for feedforward net for now. - Parameters - ---------- - symbol : Symbol - The network configuration - """ - arg_set = set() - arg_names = symbol.list_arguments() - for name in arg_names: - if name in arg_set: - raise ValueError(('Find duplicated argument name \"%s\", ' + - 'please make the weight name non-duplicated(using name arguments), ' + - 'arguments are %s') % (name, str(arg_names))) - arg_set.add(name) - - aux_set = set() - aux_names = symbol.list_auxiliary_states() - for name in aux_names: - if name in aux_set: - raise ValueError( - ('Find duplicated auxiliary param name \"%s\", ' + - 'please make the weight name non-duplicated(using name arguments), ' + - 'arguments are %s, auxiliary params are %s' - ) % (name, str(arg_names), str(aux_names))) - aux_set.add(name) - - -def _split_input_slice(batch_size, work_load_list): - """Get input slice from the input shape. - Parameters - ---------- - batch_size : int - The number of samples in a mini-batch. - work_load_list : list of float or int, optional - The list of work load for different devices, - in the same order as ctx - Returns - ------- - slices : list of slice - The split slices to get a specific slice. - Raises - ------ - ValueError - If there are two many splits such that some slice can be empty. - """ - total_work_load = sum(work_load_list) - batch_num_list = [round(work_load * batch_size / total_work_load) - for work_load in work_load_list] - batch_num_sum = sum(batch_num_list) - if batch_num_sum < batch_size: - batch_num_list[-1] += batch_size - batch_num_sum - slices = [] - end = 0 - for batch_num in batch_num_list: - begin = int(min((end, batch_size))) - end = int(min((begin + batch_num, batch_size))) - if begin >= end: - raise ValueError('Too many slices such that some splits are empty') - slices.append(slice(begin, end)) - return slices - def _create_kvstore(kvstore, num_device, arg_params): """Create kvstore This function select and create a proper kvstore if given the kvstore type @@ -153,6 +74,46 @@ def _create_kvstore(kvstore, num_device, arg_params): return (kv, update_on_kvstore) +def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, + update_on_kvstore): + """ Initialize kvstore""" + for idx in range(len(param_arrays)): + param_on_devs = param_arrays[idx] + kvstore.init(idx, arg_params[param_names[idx]]) + + if update_on_kvstore: + kvstore.pull(idx, param_on_devs, priority=-idx) + +def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore): + """ Perform update of param_arrays from grad_arrays on kvstore.""" + for index, pair in enumerate(zip(param_arrays, grad_arrays)): + arg_list, grad_list = pair + if grad_list[0] is None: + continue + # push gradient, priority is negative index + kvstore.push(index, grad_list, priority=-index) + # pull back the weights + kvstore.pull(index, arg_list, priority=-index) + +def _update_params(param_arrays, grad_arrays, updater, num_device, + kvstore=None): + """ Perform update of param_arrays from grad_arrays not on kvstore.""" + for index, pair in enumerate(zip(param_arrays, grad_arrays)): + arg_list, grad_list = pair + if grad_list[0] is None: + continue + if kvstore: + # push gradient, priority is negative index + kvstore.push(index, grad_list, priority=-index) + # pull back the sum gradients, to the same locations. + kvstore.pull(index, grad_list, priority=-index) + for k, p in enumerate(zip(arg_list, grad_list)): + # faked an index here, to make optimizer create diff + # state for the same index but on diff devs, TODO(mli) + # use a better solution latter + w, g = p + updater(index*num_device+k, g, w) + def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names, arg_params, aux_params, begin_epoch, end_epoch, epoch_size, optimizer, @@ -216,66 +177,31 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names, """ if logger is None: logger = logging - # preparation - num_device = len(ctx) - logging.info('Start training with %s', str(ctx)) - - # make sure the architecture is valid - _check_arguments(symbol) - - if work_load_list is None: - work_load_list = [1] * num_device - assert isinstance(work_load_list, list) and len(work_load_list) == num_device, \ - "Invalid settings for work load. " - slices = _split_input_slice(train_data.batch_size, work_load_list) - train_execs = [] - for i in range(len(ctx)): - data_shapes = {k: tuple([slices[i].stop-slices[i].start] + list(v[1:])) - for k, v in train_data.provide_data} - train_exec = symbol.simple_bind(ctx[i], 'write', **data_shapes) - if monitor: - monitor.install(train_exec) - train_execs.append(train_exec) - - # data structure - data_names = [x[0] for x in train_data.provide_data] - label_names = [x[0] for x in train_data.provide_label] - - data_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(train_execs)] - for name in data_names] - label_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(train_execs)] - for name in label_names] - - param_idx = [i for i in range(len(arg_names)) if arg_names[i] in param_names] - param_names = [arg_names[i] for i in param_idx] - param_arrays = [[e.arg_arrays[i] for e in train_execs] for i in param_idx] - grad_arrays = [[e.grad_arrays[i] for e in train_execs] for i in param_idx] - aux_arrays = [[e.aux_arrays[i] for e in train_execs] for i in range(len(aux_names))] - - for texec in train_execs: - texec.copy_params_from(arg_params, aux_params) + executor_manager = DataParallelExecutorManager(symbol=symbol, + ctx=ctx, + train_data=train_data, + param_names=param_names, + arg_names=arg_names, + aux_names=aux_names, + work_load_list=work_load_list, + logger=logger) + if monitor: + executor_manager.install_monitor(monitor) + + executor_manager.set_params(arg_params, aux_params) if not update_on_kvstore: updater = get_updater(optimizer) - # init kvstore if kvstore: - # init optimizer - if update_on_kvstore: - kvstore.set_optimizer(optimizer) - - # init kv - for idx in range(len(param_arrays)): - param_on_devs = param_arrays[idx] - kvstore.init(idx, arg_params[param_names[idx]]) - - if update_on_kvstore: - kvstore.pull(idx, param_on_devs, priority=-idx) + _initialize_kvstore(kvstore=kvstore, + param_arrays=executor_manager.param_arrays, + arg_params=arg_params, + param_names=executor_manager.param_names, + update_on_kvstore=update_on_kvstore) - batch_size = train_data.batch_size - - output_shapes = [tuple([batch_size]+list(x.shape[1:])) for x in train_execs[0].outputs] - cpu_output_arrays = [nd.zeros(s) for s in output_shapes] + if update_on_kvstore: + kvstore.set_optimizer(optimizer) # Now start training for epoch in range(begin_epoch, end_epoch): @@ -288,48 +214,31 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names, while True: do_reset = True for data_batch in train_data: - _load_data(data_batch, data_arrays) - _load_label(data_batch, label_arrays) + + executor_manager.load_data_batch(data_batch) if monitor is not None: monitor.tic() - # forward backward pass - for texec, islice in zip(train_execs, slices): - texec.forward(is_train=True) - for cpu_out, dev_out in zip(cpu_output_arrays, texec.outputs): - dev_out.copyto(cpu_out[islice]) - #texec.outputs[0].copyto(out_cpu_array[islice]) - for texec in train_execs: - texec.backward() - - # update the parameters - for index, pair in enumerate(zip(param_arrays, grad_arrays)): - arg_list, grad_list = pair - if grad_list[0] is None: - continue - # Gradient synchronization - if kvstore: - # push gradient, priority is negative index - kvstore.push(index, grad_list, priority=-index) - if update_on_kvstore: - # pull back the weights - kvstore.pull(index, arg_list, priority=-index) - else: - # pull back the sum gradients, to the same locations. - kvstore.pull(index, grad_list, priority=-index) - if not update_on_kvstore: - for k, p in enumerate(zip(arg_list, grad_list)): - # faked an index here, to make optimizer create diff - # state for the same index but on diff devs, TODO(mli) - # use a better solution latter - w, g = p - updater(index*num_device+k, g, w) + + executor_manager.forward(is_train=True) + executor_manager.backward() + + if update_on_kvstore: + _update_params_on_kvstore(executor_manager.param_arrays, + executor_manager.grad_arrays, + kvstore) + else: + _update_params(executor_manager.param_arrays, + executor_manager.grad_arrays, + updater=updater, + num_device=len(ctx), + kvstore=kvstore) if monitor is not None: monitor.toc_print() # evaluate at end, so out_cpu_array can lazy copy - eval_metric.update(data_batch.label, cpu_output_arrays) + eval_metric.update(data_batch.label, executor_manager.cpu_output_arrays) nbatch += 1 # batch callback (for print purpose) @@ -364,26 +273,15 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names, eval_metric.reset() eval_data.reset() for eval_batch in eval_data: - _load_data(eval_batch, data_arrays) - _load_label(eval_batch, label_arrays) - - # forward pass - for texec, islice in zip(train_execs, slices): - texec.forward(is_train=False) - for cpu_out, dev_out in zip(cpu_output_arrays, texec.outputs): - dev_out.copyto(cpu_out[islice]) - eval_metric.update(eval_batch.label, cpu_output_arrays) + executor_manager.load_data_batch(eval_batch) + executor_manager.forward(is_train=False) + eval_metric.update(eval_batch.label, executor_manager.cpu_output_arrays) + name, value = eval_metric.get() logger.info('Epoch[%d] Validation-%s=%f', epoch, name, value) if epoch_end_callback or epoch + 1 == end_epoch: - # copy data back to cpu - for name, block in zip(param_names, param_arrays): - weight = sum(w.copyto(cpu()) for w in block) / len(block) - weight.copyto(arg_params[name]) - for name, block in zip(aux_names, aux_arrays): - weight = sum(w.copyto(cpu()) for w in block) / len(block) - weight.copyto(aux_params[name]) + executor_manager.copy_to(arg_params, aux_params) if epoch_end_callback != None: if isinstance(epoch_end_callback, list):