This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Merge pull request #6 from dmlc/master
merge dmlc/master
mli committed Sep 22, 2015
2 parents b60b496 + 522aa64 commit a026814
Showing 6 changed files with 186 additions and 131 deletions.
2 changes: 2 additions & 0 deletions doc/python/index.md
@@ -20,4 +20,6 @@ Python API Documents
--------------------
* [NDArray API](ndarray.md)
* [Symbolic API](symbol.md)
* [KVStore API](kvstore.md)
* [Data Loading API](io.md)
* [Model API](model.md)
116 changes: 116 additions & 0 deletions doc/python/model.md
@@ -0,0 +1,116 @@
MXNet Python Model API
======================
The model API in mxnet is not really an API.
It is a thin wrapper built on top of the [ndarray](ndarray.md) and [symbolic](symbol.md)
modules to make neural network training easy.

* [Train a Model](#train-a-model) introduces how to configure and train a model.
* [Save the Model](#save-the-model) introduces how to save and load models.
* [Periodically Checkpoint](#periodically-checkpoint) introduces how to checkpoint a model after each iteration.
* [Use Multiple Devices](#use-multiple-devices) introduces how to train on multiple devices.
* [Model API Reference](#model-api-reference) gives reference to all functions.


Train a Model
-------------
Training a model takes two steps: first configure the network using symbols,
then call ```model.FeedForward.create``` to create and train the model.
The following example creates a two-layer neural network.

```python
import mxnet as mx

batch_size = 100

# Configure a two-layer network using symbols.
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type='relu')
fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64)
softmax = mx.symbol.Softmax(fc2, name='sm')

# data_set (the training data) and num_round (the number of
# training rounds) are assumed to be defined by the caller.
model = mx.model.FeedForward.create(
    softmax,
    X=data_set,
    num_round=num_round,
    learning_rate=0.01)
```

You can also use the scikit-learn style: construct a model first, then call its fit function.
For more information, refer to the [Model API Reference](#model-api-reference).

Save the Model
--------------
It is important to save your work after the job is done.
To save the model, you can directly pickle it if you like the pythonic way.
We also provide save and load functions.

```python
# save a model to mymodel-symbol.json and mymodel-0100.params
prefix = 'mymodel'
model.save(prefix, 100)

# load model back
model_loaded = mx.model.FeedForward.load(prefix, 100)
```
The advantage of these save and load functions is that they are language agnostic,
and you should be able to save to and load from cloud storage such as S3 and HDFS directly.
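
The file names mentioned in the comment above follow a simple pattern. As a minimal sketch, assuming the naming is exactly `prefix-symbol.json` and `prefix-NNNN.params` with a four-digit iteration number, the helper below reproduces it. `checkpoint_paths` is a hypothetical name for illustration and is not part of mxnet.

```python
def checkpoint_paths(prefix, iteration):
    """Return the (symbol_file, params_file) names for a saved model.

    Hypothetical helper: it only mirrors the naming shown in the
    example comment (prefix-symbol.json and prefix-NNNN.params).
    """
    symbol_file = '%s-symbol.json' % prefix
    # The iteration number is zero-padded to four digits.
    params_file = '%s-%04d.params' % (prefix, iteration)
    return symbol_file, params_file


print(checkpoint_paths('mymodel', 100))
# ('mymodel-symbol.json', 'mymodel-0100.params')
```

Because the iteration number is part of the parameter file name, saving at several iterations under the same prefix keeps one parameter file per iteration.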

Periodically Checkpoint
-----------------------
It is also helpful to checkpoint your model after each iteration.
To do so, simply pass a checkpoint callback to the training function.
The training process will then automatically checkpoint to the specified place after
each iteration.

```python
prefix = 'models/chkpt'
model = mx.model.FeedForward.create(
    softmax,
    X=data_set,
    iter_end_callback=mx.model.do_checkpoint(prefix),
    num_round=num_round,
    learning_rate=0.01)
```
You can load the model checkpoint later using ```FeedForward.load```.
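
The docstring of `create` describes the callback as `callable(iteration, symbol, arg_params, aux_states)`. The sketch below shows how such a callback could be built as a closure over the prefix. It is an illustration only, not the implementation of `mx.model.do_checkpoint`: `make_checkpoint_callback` and `save_fn` are hypothetical names, and the 0-based iteration numbers in the loop are an assumption.

```python
def make_checkpoint_callback(prefix, save_fn):
    """Build a callback with the signature expected by iter_end_callback.

    save_fn is a hypothetical stand-in for whatever writes the checkpoint;
    it is called as save_fn(prefix, iteration, symbol, arg_params, aux_states).
    """
    def _callback(iteration, symbol, arg_params, aux_states):
        # Invoked once at the end of every training iteration.
        save_fn(prefix, iteration, symbol, arg_params, aux_states)
    return _callback


# Record the calls instead of touching the file system.
saved = []


def record(prefix, iteration, symbol, arg_params, aux_states):
    saved.append((prefix, iteration))


callback = make_checkpoint_callback('models/chkpt', record)
for iteration in range(3):
    callback(iteration, symbol=None, arg_params={}, aux_states={})

print(saved)
# [('models/chkpt', 0), ('models/chkpt', 1), ('models/chkpt', 2)]
```

Keeping the prefix inside the closure means the training loop only needs to know the four-argument callback signature.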

Use Multiple Devices
--------------------
Simply set ```ctx``` to the list of devices you would like to train on.

```python
devices = [mx.gpu(i) for i in range(num_device)]
model = mx.model.FeedForward.create(
    softmax,
    X=data_set,
    ctx=devices,
    ...)
```
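
This document does not describe how the work is divided among the devices. One common approach is data parallelism, where each device processes a slice of every batch. The sketch below illustrates that idea under this assumption only. `split_batch` is a hypothetical helper and is not part of mxnet.

```python
def split_batch(batch_size, num_device):
    """Split batch_size examples into num_device contiguous slices.

    Hypothetical helper for illustration: when the batch does not divide
    evenly, the first few devices receive one extra example each.
    """
    base, extra = divmod(batch_size, num_device)
    slices = []
    start = 0
    for i in range(num_device):
        size = base + (1 if i < extra else 0)
        slices.append(slice(start, start + size))
        start += size
    return slices


for device_id, part in enumerate(split_batch(100, 4)):
    print(device_id, part.start, part.stop)
```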

Initializer API Reference
-------------------------

```eval_rst
.. automodule:: mxnet.initializer
    :members:
```

Evaluation Metric API Reference
-------------------------------

```eval_rst
.. automodule:: mxnet.metric
    :members:
```

Optimizer API Reference
-----------------------

```eval_rst
.. automodule:: mxnet.optimizer
    :members:
```

Model API Reference
-------------------

```eval_rst
.. automodule:: mxnet.model
    :members:
```
53 changes: 52 additions & 1 deletion python/mxnet/model.py
@@ -529,7 +529,7 @@ def predict(self, X):

     def fit(self, X, y=None, eval_data=None, eval_metric='acc',
             iter_end_callback=None, logger=None):
-        """fit the model
+        """Fit the model.

         Parameters
         ----------
@@ -629,3 +629,54 @@ def load(prefix, iteration, ctx=None):
         return FeedForward(symbol, ctx=ctx,
                            arg_params=arg_params, aux_params=aux_params)

+    @staticmethod
+    def create(symbol, X, y=None, ctx=None,
+               num_round=None, optimizer='sgd', initializer=Xavier(),
+               eval_data=None, eval_metric='acc', iter_end_callback=None,
+               logger=None, **kwargs):
+        """Functional style to create a model.
+
+        This function will be more consistent with functional
+        languages such as R, where mutation is not allowed.
+
+        Parameters
+        ----------
+        symbol : Symbol
+            The symbol configuration of the computation network.
+        X : DataIter
+            Training data.
+        y : numpy.ndarray, optional
+            If X is a numpy.ndarray, y is required.
+        ctx : Context or list of Context, optional
+            The device context of training and prediction.
+            To use multi GPU training, pass in a list of gpu contexts.
+        num_round : int, optional
+            Training parameter, number of training rounds (iterations).
+        optimizer : str or Optimizer, optional
+            Training parameter, name or optimizer object for training.
+        initializer : initializer function, optional
+            Training parameter, the initialization scheme used.
+        eval_data : DataIter or numpy.ndarray pair
+            If eval_data is a numpy.ndarray pair, it should be (valid_data, valid_label).
+        eval_metric : function
+            Evaluation metric function.
+        iter_end_callback : callable(iteration, symbol, arg_params, aux_states)
+            A callback that is invoked at the end of each iteration.
+            This can be used to checkpoint the model each iteration.
+        logger : logging logger, optional
+        """
+        model = FeedForward(symbol, ctx=ctx, num_round=num_round,
+                            optimizer=optimizer, initializer=initializer, **kwargs)
+        model.fit(X, y, eval_data=eval_data, eval_metric=eval_metric,
+                  iter_end_callback=iter_end_callback, logger=logger)
+        return model
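
The body of `create` constructs a model, calls `fit` on it, and returns it. The toy sketch below shows the same pattern in isolation, so the two styles can be compared without mxnet installed. `TinyModel` is a hypothetical stand-in and is not `FeedForward`; its one-weight gradient descent is only there to give `fit` something to do.

```python
class TinyModel(object):
    """Toy stand-in for a trainable model (not the mxnet FeedForward)."""

    def __init__(self, num_round=1, learning_rate=0.1):
        self.num_round = num_round
        self.learning_rate = learning_rate
        self.weight = 0.0

    def fit(self, X, y):
        # Fit y ~ weight * x with plain gradient descent on squared error.
        for _ in range(self.num_round):
            grad = sum(2.0 * (self.weight * x - t) * x for x, t in zip(X, y))
            self.weight -= self.learning_rate * grad / len(X)

    @staticmethod
    def create(X, y, **kwargs):
        # Functional style: construct, fit, and return in a single call.
        model = TinyModel(**kwargs)
        model.fit(X, y)
        return model


X, y = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]

# Style 1: construct the model, then fit it (mutates the object).
a = TinyModel(num_round=50, learning_rate=0.1)
a.fit(X, y)

# Style 2: one call that returns an already fitted model.
b = TinyModel.create(X, y, num_round=50, learning_rate=0.1)

print(a.weight == b.weight)
```

Both styles run the same training steps. The functional form simply never exposes an unfitted object to the caller.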
2 changes: 1 addition & 1 deletion src/kvstore/kvstore_local.h
@@ -119,7 +119,7 @@ class KVStoreLocal : public KVStore {
       } else {
         CHECK_EQ(ctx.dev_mask(), gpu::kDevMask);
         NDArray *copy_buf = buf.AllocCopyBuf(ctx.dev_id, val[0].shape());
-        CopyFromTo(val[0], copy_buf);
+        CopyFromTo(val[i], copy_buf);
         buf.merged += *copy_buf;
       }
     }
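
The one-line change above replaces `val[0]` with `val[i]` in the copy that feeds `buf.merged`. The Python sketch below shows the effect of that change, assuming the surrounding code loops over the per-device values with index `i` (the loop itself is not shown in the hunk). `merge_buggy` and `merge_fixed` are hypothetical names, and plain floats stand in for NDArrays.

```python
def merge_buggy(val):
    """Accumulate per-device values, but always read val[0] (old behaviour)."""
    merged = 0.0
    for i in range(len(val)):
        copy_buf = val[0]  # bug: ignores the loop index
        merged += copy_buf
    return merged


def merge_fixed(val):
    """Accumulate per-device values, reading val[i] (new behaviour)."""
    merged = 0.0
    for i in range(len(val)):
        copy_buf = val[i]
        merged += copy_buf
    return merged


grads = [1.0, 2.0, 3.0]  # one value per device
print(merge_buggy(grads), merge_fixed(grads))
# 3.0 6.0
```

With the old indexing, the first device's value is counted once per device and every other device's contribution is dropped.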
120 changes: 0 additions & 120 deletions tests/python/test_mlp_multi_devices.py.bak

This file was deleted.

24 changes: 15 additions & 9 deletions tests/python/train/test_mlp.py
@@ -18,11 +18,7 @@

 num_round = 4
 prefix = './mlp'
-model = mx.model.FeedForward(softmax,
-                             [mx.cpu(i) for i in range(2)],
-                             num_round=num_round,
-                             learning_rate=0.01, wd=0.0004,
-                             momentum=0.9)
+
 #check data
 get_data.GetMNIST_ubyte()

Expand All @@ -44,10 +40,17 @@ def test_mlp():
console.setLevel(logging.DEBUG)
logging.getLogger('').addHandler(console)

model.fit(X=train_dataiter,
eval_data=val_dataiter,
iter_end_callback=mx.model.do_checkpoint(prefix))
logging.info('Finish fit...')
model = mx.model.FeedForward.create(
softmax,
X=train_dataiter,
eval_data=val_dataiter,
iter_end_callback=mx.model.do_checkpoint(prefix),
ctx=[mx.cpu(i) for i in range(2)],
num_round=num_round,
learning_rate=0.01, wd=0.0004,
momentum=0.9)

logging.info('Finish traning...')
prob = model.predict(val_dataiter)
logging.info('Finish predict...')
val_dataiter.reset()
Expand All @@ -69,6 +72,9 @@ def test_mlp():
assert np.sum(np.abs(prob - prob3)) == 0

# save model explicitly



model.save(prefix, 128)
model4 = mx.model.FeedForward.load(prefix, 128)
prob4 = model4.predict(val_dataiter)
Expand Down
