This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

refactor model to make more modular
lukemetz committed Dec 12, 2015
1 parent f43d8f7 commit d4ce662
Showing 1 changed file with 237 additions and 110 deletions.
347 changes: 237 additions & 110 deletions python/mxnet/model.py
@@ -151,6 +151,210 @@ def _create_kvstore(kvstore, num_device, arg_params):

return (kv, update_on_kvstore)

class ExecutorManager(object):
""" Helper class to manage multiple executors.
Parameters
----------
symbol : Symbol
output symbol
ctx : list of Context
devices to run on
param_names: list of str
Name of all trainable parameters of the network.
arg_names: list of str
Name of all arguments of the network.
aux_names: list of str
Name of all auxiliary states of the network.
train_data : DataIter
Training data iterator.
work_load_list : list of float or int, optional
The list of work load for different devices,
in the same order as ctx
logger : logging logger
When not specified, default logger will be used.
"""
def __init__(self, symbol, ctx, train_data,
param_names, arg_names, aux_names,
work_load_list=None, logger=None):
if logger is None:
logger = logging
# preparation
num_device = len(ctx)
logger.info('Start training with %s', str(ctx))

# make sure the architecture is valid
_check_arguments(symbol)

if work_load_list is None:
work_load_list = [1] * num_device
assert isinstance(work_load_list, list) and len(work_load_list) == num_device, \
"Invalid settings for work load. "

slices = _split_input_slice(train_data.batch_size, work_load_list)
self.slices = slices

self.train_execs = []
for i in range(len(ctx)):
data_shapes = {k: tuple([slices[i].stop-slices[i].start] + list(v[1:]))
for k, v in train_data.provide_data}
train_exec = symbol.simple_bind(ctx[i], 'write', **data_shapes)
self.train_execs.append(train_exec)

# data structure
self.data_names = [x[0] for x in train_data.provide_data]
self.label_names = [x[0] for x in train_data.provide_label]
self.aux_names = aux_names

self.data_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)]
for name in self.data_names]
self.label_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)]
for name in self.label_names]

self.param_idx = [i for i in range(len(arg_names)) if arg_names[i] in param_names]
self.param_names = [arg_names[i] for i in self.param_idx]
self.param_arrays = [[e.arg_arrays[i] for e in self.train_execs]
for i in self.param_idx]
self.grad_arrays = [[e.grad_arrays[i] for e in self.train_execs]
for i in self.param_idx]
self.aux_arrays = [[e.aux_arrays[i] for e in self.train_execs]
for i in range(len(aux_names))]

batch_size = train_data.batch_size

output_shapes = [tuple([batch_size]+list(x.shape[1:])) for x in self.train_execs[0].outputs]
self.cpu_output_arrays = [nd.zeros(s) for s in output_shapes]

def install_monitor(self, monitor):
""" Install monitor on all executors """
for train_exec in self.train_execs:
monitor.install(train_exec)

def set_params(self, arg_params, aux_params):
""" set parameter and aux values
Parameters
----------
arg_params : list of NDArray
source parameter arrays
aux_params : list of NDArray
source aux arrays
"""

for texec in self.train_execs:
texec.copy_params_from(arg_params, aux_params)

def copy_to(self, arg_params, aux_params):
""" Copy data from each executor to `arg_params` and `aux_params`
Parameters
----------
arg_params : list of NDArray
target parameter arrays
aux_params : list of NDArray
target aux arrays
Notes
-----
        - This function will update the NDArrays in arg_params and aux_params in place.
"""
for name, block in zip(self.param_names, self.param_arrays):
weight = sum(w.copyto(cpu()) for w in block) / len(block)
weight.copyto(arg_params[name])
for name, block in zip(self.aux_names, self.aux_arrays):
weight = sum(w.copyto(cpu()) for w in block) / len(block)
weight.copyto(aux_params[name])

def load_data_batch(self, data_batch):
""" load data and labels into arrays """
_load_data(data_batch, self.data_arrays)
_load_label(data_batch, self.label_arrays)

def forward(self, is_train=True):
""" Perform a forward pass on each executor """
for texec, islice in zip(self.train_execs, self.slices):
texec.forward(is_train=is_train)
for cpu_out, dev_out in zip(self.cpu_output_arrays, texec.outputs):
dev_out.copyto(cpu_out[islice])

def backward(self):
""" Perform a backward pass on each executor """
for texec in self.train_execs:
texec.backward()


class Updater(object):
""" Helper to manage kvstore and optimizers to do updates of parameters
Parameters
----------
kvstore : KVStore
The KVStore
update_on_kvstore : bool
        whether or not to perform weight updates on the kvstore
    optimizer : Optimizer
        The optimization algorithm
    param_arrays : list of list of NDArray
        location of parameters per device
    arg_params : list of NDArray
        location of parameters
param_names : list of str
names of parameters to place in kvstore
ctx : list of Context
The training devices.
Notes
-----
    - The updater will modify the NDArrays in arg_params in place.
"""
def __init__(self, kvstore, update_on_kvstore, optimizer, param_arrays,
arg_params, param_names, ctx):
if not update_on_kvstore:
self.updater = get_updater(optimizer)

self.num_device = len(ctx)

# init kvstore
if kvstore:
# init optimizer
if update_on_kvstore:
kvstore.set_optimizer(optimizer)

# init kv
for idx in range(len(param_arrays)):
param_on_devs = param_arrays[idx]
kvstore.init(idx, arg_params[param_names[idx]])

if update_on_kvstore:
kvstore.pull(idx, param_on_devs, priority=-idx)

self.kvstore = kvstore
self.update_on_kvstore = update_on_kvstore

def do_update(self, param_arrays, grad_arrays):
""" Update parameters with given gradients
Parameters
----------
param_arrays: list of NDArray
        grad_arrays: list of NDArray
"""
for index, pair in enumerate(zip(param_arrays, grad_arrays)):
arg_list, grad_list = pair
if grad_list[0] is None:
continue
# Gradient synchronization
if self.kvstore:
# push gradient, priority is negative index
self.kvstore.push(index, grad_list, priority=-index)
if self.update_on_kvstore:
# pull back the weights
self.kvstore.pull(index, arg_list, priority=-index)
else:
                    # pull back the summed gradients into the same locations
self.kvstore.pull(index, grad_list, priority=-index)
if not self.update_on_kvstore:
for k, p in enumerate(zip(arg_list, grad_list)):
                    # fake an index here so the optimizer creates a separate
                    # state for the same parameter on different devices
                    # TODO(mli): use a better solution later
w, g = p
self.updater(index*self.num_device+k, g, w)
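
A matching sketch (also not part of this commit) of how the Updater pairs with the ExecutorManager above; `kv` and `update_on_kv` are assumed to come from _create_kvstore, and `opt` is any mxnet Optimizer instance:

    # hypothetical wiring of the two helpers for one epoch of training
    updater = Updater(kvstore=kv, update_on_kvstore=update_on_kv, optimizer=opt,
                      param_arrays=manager.param_arrays, arg_params=arg_params,
                      param_names=manager.param_names, ctx=devices)
    for batch in train_iter:
        manager.load_data_batch(batch)
        manager.forward(is_train=True)
        manager.backward()
        # per parameter index: push gradients to the kvstore (if any), then either
        # pull updated weights back or apply the local updater to the pulled sums
        updater.do_update(manager.param_arrays, manager.grad_arrays)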

def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
arg_params, aux_params,
begin_epoch, end_epoch, epoch_size, optimizer,
@@ -214,66 +418,26 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
"""
if logger is None:
logger = logging
# preparation
num_device = len(ctx)
logging.info('Start training with %s', str(ctx))

# make sure the architecture is valid
_check_arguments(symbol)

if work_load_list is None:
work_load_list = [1] * num_device
assert isinstance(work_load_list, list) and len(work_load_list) == num_device, \
"Invalid settings for work load. "
slices = _split_input_slice(train_data.batch_size, work_load_list)
train_execs = []
for i in range(len(ctx)):
data_shapes = {k: tuple([slices[i].stop-slices[i].start] + list(v[1:]))
for k, v in train_data.provide_data}
train_exec = symbol.simple_bind(ctx[i], 'write', **data_shapes)
if monitor:
monitor.install(train_exec)
train_execs.append(train_exec)

# data structure
data_names = [x[0] for x in train_data.provide_data]
label_names = [x[0] for x in train_data.provide_label]

data_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(train_execs)]
for name in data_names]
label_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(train_execs)]
for name in label_names]

param_idx = [i for i in range(len(arg_names)) if arg_names[i] in param_names]
param_names = [arg_names[i] for i in param_idx]
param_arrays = [[e.arg_arrays[i] for e in train_execs] for i in param_idx]
grad_arrays = [[e.grad_arrays[i] for e in train_execs] for i in param_idx]
aux_arrays = [[e.aux_arrays[i] for e in train_execs] for i in range(len(aux_names))]

for texec in train_execs:
texec.copy_params_from(arg_params, aux_params)

if not update_on_kvstore:
updater = get_updater(optimizer)

# init kvstore
if kvstore:
# init optimizer
if update_on_kvstore:
kvstore.set_optimizer(optimizer)

# init kv
for idx in range(len(param_arrays)):
param_on_devs = param_arrays[idx]
kvstore.init(idx, arg_params[param_names[idx]])

if update_on_kvstore:
kvstore.pull(idx, param_on_devs, priority=-idx)

batch_size = train_data.batch_size

output_shapes = [tuple([batch_size]+list(x.shape[1:])) for x in train_execs[0].outputs]
cpu_output_arrays = [nd.zeros(s) for s in output_shapes]
executor_manager = ExecutorManager(symbol=symbol,
ctx=ctx,
train_data=train_data,
param_names=param_names,
arg_names=arg_names,
aux_names=aux_names,
work_load_list=work_load_list,
logger=logger)
if monitor:
executor_manager.install_monitor(monitor)

executor_manager.set_params(arg_params, aux_params)

updater = Updater(kvstore=kvstore,
update_on_kvstore=update_on_kvstore,
optimizer=optimizer,
param_arrays=executor_manager.param_arrays,
arg_params=arg_params,
param_names=executor_manager.param_names,
ctx=ctx)

# Now start training
for epoch in range(begin_epoch, end_epoch):
@@ -286,48 +450,22 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
while True:
do_reset = True
for data_batch in train_data:
_load_data(data_batch, data_arrays)
_load_label(data_batch, label_arrays)

executor_manager.load_data_batch(data_batch)

if monitor is not None:
monitor.tic()
# forward backward pass
for texec, islice in zip(train_execs, slices):
texec.forward(is_train=True)
for cpu_out, dev_out in zip(cpu_output_arrays, texec.outputs):
dev_out.copyto(cpu_out[islice])
#texec.outputs[0].copyto(out_cpu_array[islice])
for texec in train_execs:
texec.backward()

# update the parameters
for index, pair in enumerate(zip(param_arrays, grad_arrays)):
arg_list, grad_list = pair
if grad_list[0] is None:
continue
# Gradient synchronization
if kvstore:
# push gradient, priority is negative index
kvstore.push(index, grad_list, priority=-index)
if update_on_kvstore:
# pull back the weights
kvstore.pull(index, arg_list, priority=-index)
else:
# pull back the sum gradients, to the same locations.
kvstore.pull(index, grad_list, priority=-index)
if not update_on_kvstore:
for k, p in enumerate(zip(arg_list, grad_list)):
# faked an index here, to make optimizer create diff
# state for the same index but on diff devs, TODO(mli)
# use a better solution latter
w, g = p
updater(index*num_device+k, g, w)

executor_manager.forward()
executor_manager.backward()

updater.do_update(executor_manager.param_arrays, executor_manager.grad_arrays)

if monitor is not None:
monitor.toc_print()

# evaluate at end, so out_cpu_array can lazy copy
eval_metric.update(data_batch.label, cpu_output_arrays)
eval_metric.update(data_batch.label, executor_manager.cpu_output_arrays)

nbatch += 1
# batch callback (for print purpose)
@@ -362,26 +500,15 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
eval_metric.reset()
eval_data.reset()
for eval_batch in eval_data:
_load_data(eval_batch, data_arrays)
_load_label(eval_batch, label_arrays)

# forward pass
for texec, islice in zip(train_execs, slices):
texec.forward(is_train=False)
for cpu_out, dev_out in zip(cpu_output_arrays, texec.outputs):
dev_out.copyto(cpu_out[islice])
eval_metric.update(eval_batch.label, cpu_output_arrays)
executor_manager.load_data_batch(eval_batch)
executor_manager.forward(is_train=False)
eval_metric.update(eval_batch.label, executor_manager.cpu_output_arrays)

name, value = eval_metric.get()
logger.info('Epoch[%d] Validation-%s=%f', epoch, name, value)

if epoch_end_callback or epoch + 1 == end_epoch:
# copy data back to cpu
for name, block in zip(param_names, param_arrays):
weight = sum(w.copyto(cpu()) for w in block) / len(block)
weight.copyto(arg_params[name])
for name, block in zip(aux_names, aux_arrays):
weight = sum(w.copyto(cpu()) for w in block) / len(block)
weight.copyto(aux_params[name])
executor_manager.copy_to(arg_params, aux_params)

if epoch_end_callback != None:
if isinstance(epoch_end_callback, list):
