diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py
index 05a269d1e2db..21ff30882e0a 100644
--- a/python/mxnet/executor.py
+++ b/python/mxnet/executor.py
@@ -1,5 +1,5 @@
 # coding: utf-8
-# pylint: disable=invalid-name, protected-access, too-many-locals
+# pylint: disable=invalid-name, protected-access, too-many-locals, too-many-arguments
 """Symbolic Executor component of MXNet."""
 from __future__ import absolute_import
 
@@ -8,6 +8,9 @@
 from .base import mx_uint, NDArrayHandle, ExecutorHandle
 from .base import check_call, c_array, py_str
 from .ndarray import NDArray
+from . import ndarray as nd
+from .context import cpu
+import logging
 
 class Executor(object):
     """ Executor is the actual executing object of MXNet."""
@@ -216,3 +219,211 @@ def debug_str(self):
         check_call(_LIB.MXExecutorPrint(
             self.handle, ctypes.byref(debug_str)))
         return py_str(debug_str.value)
+
+def _split_input_slice(batch_size, work_load_list):
+    """Get input slice from the input shape.
+    Parameters
+    ----------
+    batch_size : int
+        The number of samples in a mini-batch.
+    work_load_list : list of float or int, optional
+        The list of work load for different devices,
+        in the same order as ctx
+    Returns
+    -------
+    slices : list of slice
+        The split slices to get a specific slice.
+    Raises
+    ------
+    ValueError
+        If there are too many splits such that some slice can be empty.
+    """
+    total_work_load = sum(work_load_list)
+    batch_num_list = [round(work_load * batch_size / total_work_load)
+                      for work_load in work_load_list]
+    batch_num_sum = sum(batch_num_list)
+    if batch_num_sum < batch_size:
+        batch_num_list[-1] += batch_size - batch_num_sum
+    slices = []
+    end = 0
+    for batch_num in batch_num_list:
+        begin = int(min((end, batch_size)))
+        end = int(min((begin + batch_num, batch_size)))
+        if begin >= end:
+            raise ValueError('Too many slices such that some splits are empty')
+        slices.append(slice(begin, end))
+    return slices
+
+def _check_arguments(symbol):
+    """Check the argument names of symbol.
+    This function checks the duplication of arguments in Symbol.
+    The check is done for feedforward net for now.
+    Parameters
+    ----------
+    symbol : Symbol
+        The network configuration
+    """
+    arg_set = set()
+    arg_names = symbol.list_arguments()
+    for name in arg_names:
+        if name in arg_set:
+            raise ValueError(('Find duplicated argument name \"%s\", ' +
+                              'please make the weight name non-duplicated(using name arguments), ' +
+                              'arguments are %s') % (name, str(arg_names)))
+        arg_set.add(name)
+
+    aux_set = set()
+    aux_names = symbol.list_auxiliary_states()
+    for name in aux_names:
+        if name in aux_set:
+            raise ValueError(
+                ('Find duplicated auxiliary param name \"%s\", ' +
+                 'please make the weight name non-duplicated(using name arguments), ' +
+                 'arguments are %s, auxiliary params are %s'
+                ) % (name, str(arg_names), str(aux_names)))
+        aux_set.add(name)
+
+def _load_general(data, targets):
+    """Load a list of arrays into a list of arrays specified by slices"""
+    for d_src, d_targets in zip(data, targets):
+        if isinstance(d_targets, nd.NDArray):
+            d_src.copyto(d_targets)
+        else:
+            for slice_idx, d_dst in d_targets:
+                d_src[slice_idx].copyto(d_dst)
+
+def _load_data(batch, targets):
+    """Load data into sliced arrays"""
+    _load_general(batch.data, targets)
+
+def _load_label(batch, targets):
+    """Load label into sliced arrays"""
+    _load_general(batch.label, targets)
+
+class ExecutorManager(object):
+    """ Helper class to manage multiple executors.
+    Parameters
+    ----------
+    symbol : Symbol
+        output symbol
+    ctx : list of Context
+        devices to run on
+    param_names: list of str
+        Name of all trainable parameters of the network.
+    arg_names: list of str
+        Name of all arguments of the network.
+    aux_names: list of str
+        Name of all auxiliary states of the network.
+    train_data : DataIter
+        Training data iterator.
+    work_load_list : list of float or int, optional
+        The list of work load for different devices,
+        in the same order as ctx
+    logger : logging logger
+        When not specified, default logger will be used.
+    """
+    def __init__(self, symbol, ctx, train_data,
+                 param_names, arg_names, aux_names,
+                 work_load_list=None, logger=None):
+        if logger is None:
+            logger = logging
+        # preparation
+        num_device = len(ctx)
+        logger.info('Start training with %s', str(ctx))
+
+        # make sure the architecture is valid
+        _check_arguments(symbol)
+
+        if work_load_list is None:
+            work_load_list = [1] * num_device
+        assert isinstance(work_load_list, list) and len(work_load_list) == num_device, \
+            "Invalid settings for work load. "
+
+        slices = _split_input_slice(train_data.batch_size, work_load_list)
+        self.slices = slices
+
+        self.train_execs = []
+        for i in range(len(ctx)):
+            data_shapes = {k: tuple([slices[i].stop-slices[i].start] + list(v[1:]))
+                           for k, v in train_data.provide_data}
+            train_exec = symbol.simple_bind(ctx[i], 'write', **data_shapes)
+            self.train_execs.append(train_exec)
+
+        # data structure
+        self.data_names = [x[0] for x in train_data.provide_data]
+        self.label_names = [x[0] for x in train_data.provide_label]
+        self.aux_names = aux_names
+
+        self.data_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)]
+                            for name in self.data_names]
+        self.label_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)]
+                             for name in self.label_names]
+
+        self.param_idx = [i for i in range(len(arg_names)) if arg_names[i] in param_names]
+        self.param_names = [arg_names[i] for i in self.param_idx]
+        self.param_arrays = [[e.arg_arrays[i] for e in self.train_execs]
+                             for i in self.param_idx]
+        self.grad_arrays = [[e.grad_arrays[i] for e in self.train_execs]
+                            for i in self.param_idx]
+
+        self.aux_arrays = [[e.aux_arrays[i] for e in self.train_execs]
+                           for i in range(len(aux_names))]
+
+        batch_size = train_data.batch_size
+
+        output_shapes = [tuple([batch_size]+list(x.shape[1:])) for x in self.train_execs[0].outputs]
+        self.cpu_output_arrays = [nd.zeros(s) for s in output_shapes]
+
+    def install_monitor(self, monitor):
+        """ Install monitor on all executors """
+        for train_exec in self.train_execs:
+            monitor.install(train_exec)
+
+    def set_params(self, arg_params, aux_params):
+        """ Set parameter and aux values
+        Parameters
+        ----------
+        arg_params : list of NDArray
+            source parameter arrays
+        aux_params : list of NDArray
+            source aux arrays
+        """
+
+        for texec in self.train_execs:
+            texec.copy_params_from(arg_params, aux_params)
+
+    def copy_to(self, arg_params, aux_params):
+        """ Copy data from each executor to `arg_params` and `aux_params`
+        Parameters
+        ----------
+        arg_params : list of NDArray
+            target parameter arrays
+        aux_params : list of NDArray
+            target aux arrays
+        Notes
+        -----
+        - This function updates the NDArrays in arg_params and aux_params in place.
+ """ + for name, block in zip(self.param_names, self.param_arrays): + weight = sum(w.copyto(cpu()) for w in block) / len(block) + weight.copyto(arg_params[name]) + for name, block in zip(self.aux_names, self.aux_arrays): + weight = sum(w.copyto(cpu()) for w in block) / len(block) + weight.copyto(aux_params[name]) + + def load_data_batch(self, data_batch): + """ load data and labels into arrays """ + _load_data(data_batch, self.data_arrays) + _load_label(data_batch, self.label_arrays) + + def forward(self, is_train=False): + """ Perform a forward pass on each executor """ + for texec, islice in zip(self.train_execs, self.slices): + texec.forward(is_train=is_train) + for cpu_out, dev_out in zip(self.cpu_output_arrays, texec.outputs): + dev_out.copyto(cpu_out[islice]) + + def backward(self): + """ Perform a backward pass on each executor """ + for texec in self.train_execs: + texec.backward() diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 7af0d8c10b1d..10681372300b 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -15,7 +15,8 @@ from .context import Context, cpu from .initializer import Uniform from collections import namedtuple -from .optimizer import get_updater +from .optimizer import UpdateManager +from .executor import ExecutorManager, _check_arguments, _load_data BASE_ESTIMATOR = object @@ -31,86 +32,6 @@ 'nbatch', 'eval_metric']) -def _load_general(data, targets): - """Load a list of arrays into a list of arrays specified by slices""" - for d_src, d_targets in zip(data, targets): - if isinstance(d_targets, nd.NDArray): - d_src.copyto(d_targets) - else: - for slice_idx, d_dst in d_targets: - d_src[slice_idx].copyto(d_dst) -def _load_data(batch, targets): - """Load data into sliced arrays""" - _load_general(batch.data, targets) -def _load_label(batch, targets): - """Load label into sliced arrays""" - _load_general(batch.label, targets) - - -def _check_arguments(symbol): - """Check the argument names of symbol. - This function checks the duplication of arguments in Symbol. - The check is done for feedforward net for now. - Parameters - ---------- - symbol : Symbol - The network configuration - """ - arg_set = set() - arg_names = symbol.list_arguments() - for name in arg_names: - if name in arg_set: - raise ValueError(('Find duplicated argument name \"%s\", ' + - 'please make the weight name non-duplicated(using name arguments), ' + - 'arguments are %s') % (name, str(arg_names))) - arg_set.add(name) - - aux_set = set() - aux_names = symbol.list_auxiliary_states() - for name in aux_names: - if name in aux_set: - raise ValueError( - ('Find duplicated auxiliary param name \"%s\", ' + - 'please make the weight name non-duplicated(using name arguments), ' + - 'arguments are %s, auxiliary params are %s' - ) % (name, str(arg_names), str(aux_names))) - aux_set.add(name) - - -def _split_input_slice(batch_size, work_load_list): - """Get input slice from the input shape. - Parameters - ---------- - batch_size : int - The number of samples in a mini-batch. - work_load_list : list of float or int, optional - The list of work load for different devices, - in the same order as ctx - Returns - ------- - slices : list of slice - The split slices to get a specific slice. - Raises - ------ - ValueError - If there are two many splits such that some slice can be empty. 
- """ - total_work_load = sum(work_load_list) - batch_num_list = [round(work_load * batch_size / total_work_load) - for work_load in work_load_list] - batch_num_sum = sum(batch_num_list) - if batch_num_sum < batch_size: - batch_num_list[-1] += batch_size - batch_num_sum - slices = [] - end = 0 - for batch_num in batch_num_list: - begin = int(min((end, batch_size))) - end = int(min((begin + batch_num, batch_size))) - if begin >= end: - raise ValueError('Too many slices such that some splits are empty') - slices.append(slice(begin, end)) - return slices - def _create_kvstore(kvstore, num_device, arg_params): """Create kvstore This function select and create a proper kvstore if given the kvstore type @@ -151,210 +72,6 @@ def _create_kvstore(kvstore, num_device, arg_params): return (kv, update_on_kvstore) -class ExecutorManager(object): - """ Helper class to manage multiple executors. - Parameters - ---------- - symbol : Symbol - output symbol - ctx : list of Context - devices to run on - param_names: list of str - Name of all trainable parameters of the network. - arg_names: list of str - Name of all arguments of the network. - aux_names: list of str - Name of all auxiliary states of the network. - train_data : DataIter - Training data iterator. - work_load_list : list of float or int, optional - The list of work load for different devices, - in the same order as ctx - logger : logging logger - When not specified, default logger will be used. - """ - def __init__(self, symbol, ctx, train_data, - param_names, arg_names, aux_names, - work_load_list=None, logger=None): - if logger is None: - logger = logging - # preparation - num_device = len(ctx) - logger.info('Start training with %s', str(ctx)) - - # make sure the architecture is valid - _check_arguments(symbol) - - if work_load_list is None: - work_load_list = [1] * num_device - assert isinstance(work_load_list, list) and len(work_load_list) == num_device, \ - "Invalid settings for work load. 
" - - slices = _split_input_slice(train_data.batch_size, work_load_list) - self.slices = slices - - self.train_execs = [] - for i in range(len(ctx)): - data_shapes = {k: tuple([slices[i].stop-slices[i].start] + list(v[1:])) - for k, v in train_data.provide_data} - train_exec = symbol.simple_bind(ctx[i], 'write', **data_shapes) - self.train_execs.append(train_exec) - - # data structure - self.data_names = [x[0] for x in train_data.provide_data] - self.label_names = [x[0] for x in train_data.provide_label] - self.aux_names = aux_names - - self.data_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)] - for name in self.data_names] - self.label_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)] - for name in self.label_names] - - self.param_idx = [i for i in range(len(arg_names)) if arg_names[i] in param_names] - self.param_names = [arg_names[i] for i in self.param_idx] - self.param_arrays = [[e.arg_arrays[i] for e in self.train_execs] - for i in self.param_idx] - self.grad_arrays = [[e.grad_arrays[i] for e in self.train_execs] - for i in self.param_idx] - self.aux_arrays = [[e.aux_arrays[i] for e in self.train_execs] - for i in range(len(aux_names))] - - batch_size = train_data.batch_size - - output_shapes = [tuple([batch_size]+list(x.shape[1:])) for x in self.train_execs[0].outputs] - self.cpu_output_arrays = [nd.zeros(s) for s in output_shapes] - - def install_monitor(self, monitor): - """ Install monitor on all executors """ - for train_exec in self.train_execs: - monitor.install(train_exec) - - def set_params(self, arg_params, aux_params): - """ set parameter and aux values - Parameters - ---------- - arg_params : list of NDArray - source parameter arrays - aux_params : list of NDArray - source aux arrays - """ - - for texec in self.train_execs: - texec.copy_params_from(arg_params, aux_params) - - def copy_to(self, arg_params, aux_params): - """ Copy data from each executor to `arg_params` and `aux_params` - Parameters - ---------- - arg_params : list of NDArray - target parameter arrays - aux_params : list of NDArray - target aux arrays - Notes - ----- - - This function will inplace update the NDArrays in arg_params and aux_params. 
- """ - for name, block in zip(self.param_names, self.param_arrays): - weight = sum(w.copyto(cpu()) for w in block) / len(block) - weight.copyto(arg_params[name]) - for name, block in zip(self.aux_names, self.aux_arrays): - weight = sum(w.copyto(cpu()) for w in block) / len(block) - weight.copyto(aux_params[name]) - - def load_data_batch(self, data_batch): - """ load data and labels into arrays """ - _load_data(data_batch, self.data_arrays) - _load_label(data_batch, self.label_arrays) - - def forward(self, is_train=True): - """ Perform a forward pass on each executor """ - for texec, islice in zip(self.train_execs, self.slices): - texec.forward(is_train=is_train) - for cpu_out, dev_out in zip(self.cpu_output_arrays, texec.outputs): - dev_out.copyto(cpu_out[islice]) - - def backward(self): - """ Perform a backward pass on each executor """ - for texec in self.train_execs: - texec.backward() - - -class Updater(object): - """ Helper to manage kvstore and optimizers to do updates of parameters - Parameters - ---------- - kvstore : KVStore - The KVStore - update_on_kvstore : bool - whether or not perform weight updating on kvstore - optimizer : Optimizer - The optimization algorithm - param_args : list of list of NDArray - location of parameters per device - arg_params : list of NDArray - locacation of parameters - param_names : list of str - names of parameters to place in kvstore - ctx : list of Context - The training devices. - - Notes - ----- - - This function will inplace update the NDArrays in arg_params. - """ - def __init__(self, kvstore, update_on_kvstore, optimizer, param_arrays, - arg_params, param_names, ctx): - if not update_on_kvstore: - self.updater = get_updater(optimizer) - - self.num_device = len(ctx) - - # init kvstore - if kvstore: - # init optimizer - if update_on_kvstore: - kvstore.set_optimizer(optimizer) - - # init kv - for idx in range(len(param_arrays)): - param_on_devs = param_arrays[idx] - kvstore.init(idx, arg_params[param_names[idx]]) - - if update_on_kvstore: - kvstore.pull(idx, param_on_devs, priority=-idx) - - self.kvstore = kvstore - self.update_on_kvstore = update_on_kvstore - - def do_update(self, param_arrays, grad_arrays): - """ Update parameters with given gradients - Parameters - ---------- - param_arrays: list of NDArray - grad_arrays: list of NDarray - """ - for index, pair in enumerate(zip(param_arrays, grad_arrays)): - arg_list, grad_list = pair - if grad_list[0] is None: - continue - # Gradient synchronization - if self.kvstore: - # push gradient, priority is negative index - self.kvstore.push(index, grad_list, priority=-index) - if self.update_on_kvstore: - # pull back the weights - self.kvstore.pull(index, arg_list, priority=-index) - else: - # pull back the sum gradients, to the same locations. 
-                    self.kvstore.pull(index, grad_list, priority=-index)
-            if not self.update_on_kvstore:
-                for k, p in enumerate(zip(arg_list, grad_list)):
-                    # faked an index here, to make optimizer create diff
-                    # state for the same index but on diff devs, TODO(mli)
-                    # use a better solution latter
-                    w, g = p
-                    self.updater(index*self.num_device+k, g, w)
-
 
 def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
                         arg_params, aux_params, begin_epoch, end_epoch, epoch_size, optimizer,
@@ -431,13 +148,13 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
     executor_manager.set_params(arg_params, aux_params)
 
-    updater = Updater(kvstore=kvstore,
-                      update_on_kvstore=update_on_kvstore,
-                      optimizer=optimizer,
-                      param_arrays=executor_manager.param_arrays,
-                      arg_params=arg_params,
-                      param_names=executor_manager.param_names,
-                      ctx=ctx)
+    updater = UpdateManager(kvstore=kvstore,
+                            update_on_kvstore=update_on_kvstore,
+                            optimizer=optimizer,
+                            param_arrays=executor_manager.param_arrays,
+                            arg_params=arg_params,
+                            param_names=executor_manager.param_names,
+                            ctx=ctx)
 
     # Now start training
     for epoch in range(begin_epoch, end_epoch):
@@ -456,7 +173,7 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
             if monitor is not None:
                 monitor.tic()
-            executor_manager.forward()
+            executor_manager.forward(is_train=True)
             executor_manager.backward()
 
             updater.do_update(executor_manager.param_arrays,
                               executor_manager.grad_arrays)
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 589158e0b075..18994de3dff6 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -524,3 +524,81 @@ def updater(index, grad, weight):
         states[index] = optimizer.create_state(index, weight)
         optimizer.update(index, weight, grad, states[index])
     return updater
+
+
+class UpdateManager(object):
+    """ Helper to manage kvstore and optimizers to do updates of parameters
+    Parameters
+    ----------
+    kvstore : KVStore
+        The KVStore
+    update_on_kvstore : bool
+        whether or not to perform weight updating on kvstore
+    optimizer : Optimizer
+        The optimization algorithm
+    param_arrays : list of list of NDArray
+        location of parameters per device
+    arg_params : list of NDArray
+        location of parameters
+    param_names : list of str
+        names of parameters to place in kvstore
+    ctx : list of Context
+        The training devices.
+
+    Notes
+    -----
+    - This class updates the NDArrays in arg_params in place.
+ """ + def __init__(self, kvstore, update_on_kvstore, optimizer, param_arrays, + arg_params, param_names, ctx): + if not update_on_kvstore: + self.updater = get_updater(optimizer) + + self.num_device = len(ctx) + + # init kvstore + if kvstore: + # init optimizer + if update_on_kvstore: + kvstore.set_optimizer(optimizer) + + # init kv + for idx in range(len(param_arrays)): + param_on_devs = param_arrays[idx] + kvstore.init(idx, arg_params[param_names[idx]]) + + if update_on_kvstore: + kvstore.pull(idx, param_on_devs, priority=-idx) + + self.kvstore = kvstore + self.update_on_kvstore = update_on_kvstore + + def do_update(self, param_arrays, grad_arrays): + """ Update parameters with given gradients + Parameters + ---------- + param_arrays: list of NDArray + grad_arrays: list of NDarray + """ + + for index, pair in enumerate(zip(param_arrays, grad_arrays)): + arg_list, grad_list = pair + if grad_list[0] is None: + continue + # Gradient synchronization + if self.kvstore: + # push gradient, priority is negative index + self.kvstore.push(index, grad_list, priority=-index) + if self.update_on_kvstore: + # pull back the weights + self.kvstore.pull(index, arg_list, priority=-index) + else: + # pull back the sum gradients, to the same locations. + self.kvstore.pull(index, grad_list, priority=-index) + if not self.update_on_kvstore: + for k, p in enumerate(zip(arg_list, grad_list)): + # faked an index here, to make optimizer create diff + # state for the same index but on diff devs, TODO(mli) + # use a better solution latter + w, g = p + self.updater(index*self.num_device+k, g, w)