diff --git a/doc/python/index.md b/doc/python/index.md
index 3a1713c81afa..82e1eaaa9ed1 100644
--- a/doc/python/index.md
+++ b/doc/python/index.md
@@ -20,4 +20,6 @@ Python API Documents
 --------------------
 * [NDArray API](ndarray.md)
 * [Symbolic API](symbol.md)
+* [KVStore API](kvstore.md)
 * [Data Loading API](io.md)
+* [Model API](model.md)
\ No newline at end of file
diff --git a/doc/python/model.md b/doc/python/model.md
new file mode 100644
index 000000000000..bd15379eeeee
--- /dev/null
+++ b/doc/python/model.md
@@ -0,0 +1,116 @@
+MXNet Python Model API
+======================
+The model API in mxnet is not really an API.
+It is a thin wrapper built on top of the [ndarray](ndarray.md) and [symbolic](symbol.md)
+modules to make neural network training easy.
+
+* [Train a Model](#train-a-model) introduces the basic training workflow.
+* [Save the Model](#save-the-model) introduces how to save and load models.
+* [Periodic Checkpointing](#periodic-checkpointing) introduces how to checkpoint during training.
+* [Use Multiple Devices](#use-multiple-devices) introduces how to train on multiple devices.
+* [API Reference](#initializer-api-reference) gives reference to the initializer, evaluation metric, optimizer, and model APIs.
+
+
+Train a Model
+-------------
+To train a model, follow two steps: first configure the network using symbols,
+then call ```model.FeedForward.create``` to create the model for you.
+The following example creates a two-layer neural network.
+
+```python
+batch_size = 100
+data = mx.symbol.Variable('data')
+fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
+act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
+fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64)
+softmax = mx.symbol.Softmax(fc2, name='sm')
+
+model = mx.model.FeedForward.create(
+    softmax,
+    X=data_set,
+    num_round=num_round,
+    learning_rate=0.01)
+```
+
+You can also use the scikit-learn style construct-then-fit pattern to create a model,
+as sketched below.
+For more information, refer to the [Model API Reference](#model-api-reference).
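+
+A minimal sketch of that style, reusing the network configured above; the
+```FeedForward``` constructor takes the same training parameters that
+```create``` forwards to it, and the parameter values are illustrative only:
+
+```python
+# construct the model first, scikit-learn style ...
+model = mx.model.FeedForward(
+    softmax,
+    num_round=num_round,
+    learning_rate=0.01)
+# ... then fit it to the training data
+model.fit(X=data_set)
+```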
+
+Save the Model
+--------------
+It is important to save your work after the job is done.
+To save the model, you can directly pickle it if you prefer the pythonic way,
+as sketched at the end of this section.
+We also provide save and load functions.
+
+```python
+# save a model to mymodel-symbol.json and mymodel-0100.params
+prefix = 'mymodel'
+model.save(prefix, 100)
+
+# load model back
+model_loaded = mx.model.FeedForward.load(prefix, 100)
+```
+The advantage of these save and load functions is that they are language agnostic,
+so you should be able to save to and load from cloud storage such as S3 and HDFS directly.
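+
+For the pickle route mentioned above, a minimal sketch, assuming the model
+object is picklable as noted; the file name is illustrative only:
+
+```python
+import pickle
+
+# serialize the whole model object with pickle
+with open('mymodel.pkl', 'wb') as f:
+    pickle.dump(model, f)
+
+# load it back later
+with open('mymodel.pkl', 'rb') as f:
+    model_loaded = pickle.load(f)
+```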
+
+Periodic Checkpointing
+----------------------
+It is also helpful to checkpoint your model after each iteration.
+To do so, simply add a checkpoint callback to the function.
+The training process will then automatically checkpoint to the specified place
+after each iteration.
+
+```python
+prefix = 'models/chkpt'
+model = mx.model.FeedForward.create(
+    softmax,
+    X=data_set,
+    iter_end_callback=mx.model.do_checkpoint(prefix),
+    num_round=num_round,
+    learning_rate=0.01)
+```
+You can load the model checkpoint later using ```FeedForward.load```.
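+
+```iter_end_callback``` accepts any callable with the signature
+```(iteration, symbol, arg_params, aux_states)``` documented in the API
+reference below, so you can also plug in your own callback. A minimal sketch,
+with an illustrative function name:
+
+```python
+# a custom end-of-iteration callback; the signature follows the
+# iter_end_callback contract: (iteration, symbol, arg_params, aux_states)
+def log_iteration(iteration, symbol, arg_params, aux_states):
+    print("finished iteration %d" % iteration)
+
+model = mx.model.FeedForward.create(
+    softmax,
+    X=data_set,
+    iter_end_callback=log_iteration,
+    num_round=num_round,
+    learning_rate=0.01)
+```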
+
+Use Multiple Devices
+--------------------
+Simply set ```ctx``` to the list of devices you would like to train on.
+
+```python
+devices = [mx.gpu(i) for i in range(num_device)]
+model = mx.model.FeedForward.create(
+    softmax,
+    X=dataset,
+    ctx=devices,
+    ...)
+```
+
+Initializer API Reference
+-------------------------
+
+```eval_rst
+.. automodule:: mxnet.initializer
+    :members:
+```
+
+Evaluation Metric API Reference
+-------------------------------
+
+```eval_rst
+.. automodule:: mxnet.metric
+    :members:
+```
+
+Optimizer API Reference
+-----------------------
+
+```eval_rst
+.. automodule:: mxnet.optimizer
+    :members:
+```
+
+Model API Reference
+-------------------
+
+```eval_rst
+.. automodule:: mxnet.model
+    :members:
+```
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index a84244a2a777..df07512af64d 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -529,7 +529,7 @@ def predict(self, X):
 
     def fit(self, X, y=None, eval_data=None, eval_metric='acc',
             iter_end_callback=None, logger=None):
-        """fit the model
+        """Fit the model.
 
         Parameters
         ----------
@@ -629,3 +629,54 @@ def load(prefix, iteration, ctx=None):
         return FeedForward(symbol, ctx=ctx,
                            arg_params=arg_params, aux_params=aux_params)
+
+    @staticmethod
+    def create(symbol, X, y=None, ctx=None,
+               num_round=None, optimizer='sgd', initializer=Xavier(),
+               eval_data=None, eval_metric='acc', iter_end_callback=None,
+               logger=None, **kwargs):
+        """Functional style to create a model.
+
+        This function is more consistent with functional
+        languages such as R, where mutation is not allowed.
+
+        Parameters
+        ----------
+        symbol : Symbol
+            The symbol configuration of the computation network.
+
+        X : DataIter
+            Training data.
+
+        y : numpy.ndarray, optional
+            If X is a numpy.ndarray, y is required.
+
+        ctx : Context or list of Context, optional
+            The device context of training and prediction.
+            To use multi-GPU training, pass in a list of GPU contexts.
+
+        num_round : int, optional
+            Training parameter, number of training rounds (iterations).
+
+        optimizer : str or Optimizer, optional
+            Training parameter, name of an optimizer or an optimizer object.
+
+        initializer : initializer function, optional
+            Training parameter, the initialization scheme used.
+
+        eval_data : DataIter or numpy.ndarray pair
+            If eval_data is a numpy.ndarray pair, it should be (valid_data, valid_label).
+
+        eval_metric : function
+            Evaluation metric function.
+
+        iter_end_callback : callable(iteration, symbol, arg_params, aux_states)
+            A callback that is invoked at the end of each iteration.
+            This can be used to checkpoint the model after each iteration.
+
+        logger : logging logger, optional
+        """
+        model = FeedForward(symbol, ctx=ctx, num_round=num_round,
+                            optimizer=optimizer, initializer=initializer, **kwargs)
+        model.fit(X, y, eval_data=eval_data, eval_metric=eval_metric,
+                  iter_end_callback=iter_end_callback, logger=logger)
+        return model
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index b08ea9997369..f8cbf53d27b3 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -119,7 +119,7 @@ class KVStoreLocal : public KVStore {
       } else {
         CHECK_EQ(ctx.dev_mask(), gpu::kDevMask);
         NDArray *copy_buf = buf.AllocCopyBuf(ctx.dev_id, val[0].shape());
-        CopyFromTo(val[0], copy_buf);
+        CopyFromTo(val[i], copy_buf);
         buf.merged += *copy_buf;
       }
     }
diff --git a/tests/python/test_mlp_multi_devices.py.bak b/tests/python/test_mlp_multi_devices.py.bak
deleted file mode 100644
index 7a2e7ce1938a..000000000000
--- a/tests/python/test_mlp_multi_devices.py.bak
+++ /dev/null
@@ -1,120 +0,0 @@
-# pylint: skip-file
-import sys
-sys.path.append('../../python/')
-
-import mxnet as mx
-import numpy as np
-import os, gzip
-import pickle as pickle
-import get_data
-
-# symbol net
-data = mx.symbol.Variable('data')
-fc1 = mx.symbol.FullyConnected(data = data, name='fc1', nb_hidden=128)
-act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
-fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', nb_hidden = 64)
-act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
-fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', nb_hidden=10)
-mlp = mx.symbol.Softmax(data = fc3, name = 'mlp')
-
-# use multiple devices
-num_devs = 2
-devs = [mx.Context('cpu', i) for i in range(num_devs)]
-
-# infer shape
-batch_size = 100
-input_shape = (batch_size / num_devs, 784)
-param_shapes, out_shapes, aux_shapes = mlp.infer_shape(data=input_shape)
-param_names = mlp.list_arguments()
-
-# allocate memory
-params = [[mx.narray.create(s, d) for s in param_shapes] for d in devs];
-grads = [[mx.narray.create(s, d) for s in param_shapes] for d in devs];
-
-# only need to init param on device 0
-mx.kvstore.init_devices(devs)
-sync_keys = [i for i,m in enumerate(param_names) if "weight" in m or "bias" in m]
-np.random.seed(0)
-for k in sync_keys:
-    if "weight" in param_names[k]:
-        params[0][k].numpy[:, :] = np.random.uniform(-0.07, 0.07, v.numpy.shape)
-    else:
-        params[0][k].numpy[:] = 0
-mx.kvstore.init((k,params[0][k]) for k in sync_keys)
-
-# register param updater
-def make_updater(env):
-    def updater(grad, weight):
-        eta = env['lr'] / sqrt(env['iter']) / env['batch_size']
-        env['iter'] += 1
-        weight[:] -= eta * grad
-    return updater
-
-mx.kvstore.register(make_updater(
-    {'lr' : 0.1, 'batch_size' : batch_size, 'wd' : .00004}))
-
-# create executor for each device
-
-req = ['write_to' for i in range(len(param_names))]
-executors = [mlp.bind(devs[i], params[i], grads[i], req) for i in range(num_devs)]
-forward_out = [mx.narray.create(e.heads()[0].shape) for e in executors]
-
-# data reader
-get_data.GetMNIST_ubyte()
-train_dataiter = mx.io.MNISTIter(
-        image="data/train-images-idx3-ubyte",
-        label="data/train-labels-idx1-ubyte",
-        batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)
-val_dataiter = mx.io.MNISTIter(
-        image="data/t10k-images-idx3-ubyte",
-        label="data/t10k-labels-idx1-ubyte",
-        batch_size=batch_size, shuffle=True, flat=True, silent=False)
-
-def cal_acc(out, label):
-    pred = np.argmax(out, axis=1)
-    return np.sum(pred == label) * 1.0 / out.shape[0]
-
-def test_mlp():
-    epoch = 9
-    acc_train = 0.
-    acc_val = 0.
-    for i in range(epoch):
-        # train
-        print("Epoch %d" % i)
-        train_acc = 0.0
-        for data, label in train_dataiter:
-            data = data.numpy
-            label = label.numpy.flatten()
-            k = batch_size / num_devs
-
-            for d in range(num_devs):
-                # feed input
-                idx = range(d*k, (d+1)*k)
-                params[d][param_names.index('data')].numpy[:] = data[idx,:]
-                params[d][param_names.index('mlp_label')].numpy[:] = label[idx]
-
-                # pull weight
-                mx.kvstore.pull((k,params[d][k]) for k in sync_keys)
-
-                # forward and backward
-                executors[d].forward()
-                executors[d].heads()[0].copyto(forward_out[d])
-                executors[d].backward([forward_out[d]])
-
-                # push gradient
-                mx.kvstore.push((k, grads[d][k]) for k in sync_keys)
-
-            # evaluate. cannot put into the above for loop since it is blocked
-            # until all forwards are finished
-            for d in range(num_devs):
-                train_acc += cal_acc(forward_out[d].numpy, label[range(d*k, (d+1)*k)])
-
-            train_acc /= train_nbatch
-            train_nbatch += 1
-        print("Train Acc: ", train_acc)
-        train_dataiter.reset()
-
-    assert(acc_train > 0.98)
-
-if __name__ == "__main__":
-    test_mlp()
diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py
index bd635e980297..dad0ef0f1db5 100644
--- a/tests/python/train/test_mlp.py
+++ b/tests/python/train/test_mlp.py
@@ -18,11 +18,7 @@
 num_round = 4
 prefix = './mlp'
 
-model = mx.model.FeedForward(softmax,
-                             [mx.cpu(i) for i in range(2)],
-                             num_round=num_round,
-                             learning_rate=0.01, wd=0.0004,
-                             momentum=0.9)
+
 #check data
 get_data.GetMNIST_ubyte()
 
@@ -44,10 +40,17 @@ def test_mlp():
     console.setLevel(logging.DEBUG)
     logging.getLogger('').addHandler(console)
 
-    model.fit(X=train_dataiter,
-              eval_data=val_dataiter,
-              iter_end_callback=mx.model.do_checkpoint(prefix))
-    logging.info('Finish fit...')
+    model = mx.model.FeedForward.create(
+        softmax,
+        X=train_dataiter,
+        eval_data=val_dataiter,
+        iter_end_callback=mx.model.do_checkpoint(prefix),
+        ctx=[mx.cpu(i) for i in range(2)],
+        num_round=num_round,
+        learning_rate=0.01, wd=0.0004,
+        momentum=0.9)
+
+    logging.info('Finish training...')
     prob = model.predict(val_dataiter)
     logging.info('Finish predict...')
     val_dataiter.reset()
@@ -69,6 +72,7 @@ def test_mlp():
     assert np.sum(np.abs(prob - prob3)) == 0
 
     # save model explicitly
+    model.save(prefix, 128)
     model4 = mx.model.FeedForward.load(prefix, 128)
     prob4 = model4.predict(val_dataiter)