
[MXNET-949] Module API to Gluon API tutorial #12542

Merged · 12 commits · Mar 21, 2019
3 changes: 2 additions & 1 deletion docs/tutorials/index.md
@@ -68,6 +68,7 @@ Select API:
* [Learning Rate Schedules](/tutorials/gluon/learning_rate_schedules.html)
* [Advanced Learning Rate Schedules](/tutorials/gluon/learning_rate_schedules_advanced.html)
* [Profiling MXNet Models](/tutorials/python/profiler.html)
* [Module to Gluon API](/tutorials/python/module_to_gluon.html)<span style="color:red"> (new!)</span>
* API Guides
* Core APIs
* NDArray
@@ -89,6 +90,7 @@ Select API:
* [HybridBlocks](/tutorials/gluon/hybrid.html) ([Alternative](http://gluon.mxnet.io/chapter07_distributed-learning/hybridize.html) <img src="https://upload.wikimedia.org/wikipedia/commons/6/6a/External_link_font_awesome.svg" alt="External link" height="15px" style="margin: 0px 0px 3px 3px;"/>)
* [Block Naming](/tutorials/gluon/naming.html)
* [Custom Operators](/tutorials/gluon/customop.html)
* [Control Flow operators](/tutorials/control_flow/ControlFlowTutorial.html)<span style="color:red"> (new!)</span>
* Autograd
* [AutoGrad API](/tutorials/gluon/autograd.html)
* [AutoGrad API with chain rule](http://gluon.mxnet.io/chapter01_crashcourse/autograd.html) <img src="https://upload.wikimedia.org/wikipedia/commons/6/6a/External_link_font_awesome.svg" alt="External link" height="15px" style="margin: 0px 0px 3px 3px;"/>
@@ -117,7 +119,6 @@ Select API:
* [Fine-Tuning a pre-trained ImageNet model with a new dataset](/faq/finetune.html)
* [Large-Scale Multi-Host Multi-GPU Image Classification](/tutorials/vision/large_scale_classification.html)
* [Importing an ONNX model into MXNet](/tutorials/onnx/super_resolution.html)
* [Hybridize Gluon models with control flows](/tutorials/control_flow/ControlFlowTutorial.html)
* API Guides
* Core APIs
* NDArray
297 changes: 297 additions & 0 deletions docs/tutorials/python/module_to_gluon.md
@@ -0,0 +1,297 @@

# Converting Module API code to the Gluon API

Sometimes the model you want to use has been written with the symbolic Module API rather than the simpler, easier-to-debug, more flexible, imperative Gluon API. This tutorial is a comprehensive guide to converting your Module code to work with the Gluon API.

The different steps to take into consideration are:

I) Data loading

II) Model definition

III) Loss

IV) Training Loop

V) Exporting Models

In the following sections we will look at one-to-one mappings between the Module and Gluon ways of training a neural network.

## I - Data Loading


```python
import logging
logging.basicConfig(level=logging.INFO)

import numpy as np
import mxnet as mx
from mxnet.gluon.data import ArrayDataset, DataLoader
from mxnet.gluon import nn
from mxnet import gluon

batch_size = 5
dataset_length = 200
```

#### Module

When using the Module API we use a [`DataIter`](https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=dataiter#mxnet.io.DataIter). In addition to the data itself, the [`DataIter`](https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=dataiter#mxnet.io.DataIter) contains the names of the input symbols.

Let's create some random data, following the same format as grayscale 28x28 images.


```python
train_data = np.random.rand(dataset_length, 28,28).astype('float32')
train_label = np.random.randint(0, 10, (dataset_length,)).astype('float32')
```


```python
data_iter = mx.io.NDArrayIter(data=train_data, label=train_label, batch_size=batch_size, shuffle=False, data_name='data', label_name='softmax_label')
for batch in data_iter:
    print(batch.data[0].shape, batch.label[0])
    break
```

(5, 28, 28)
[5. 0. 3. 4. 9.]
<NDArray 5 @cpu(0)>


#### Gluon

With Gluon, the preferred method is to use a [`DataLoader`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=dataloader#mxnet.gluon.data.DataLoader) that makes use of a [`Dataset`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=dataset#mxnet.gluon.data.Dataset) to asynchronously prefetch the data.


```python
dataset = ArrayDataset(train_data, train_label)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
for data, label in dataloader:
    print(data.shape, label)
    break
```

(5, 28, 28)
[5. 0. 3. 4. 9.]
<NDArray 5 @cpu(0)>


#### Notable differences

- Gluon keeps a strict separation between data holding and data loading/fetching. The `Dataset`'s role is to hold data, in or out of memory, and the `DataLoader`'s role is to request given indices of the dataset, in the main thread or through multiprocessing workers. This flexible API makes it possible to prefetch data efficiently while keeping the concerns separate (see the sketch below).
- In the Module API, `DataIter`s are responsible both for holding the data and for iterating through it. Some `DataIter`s, like the [`ImageRecordIter`](https://mxnet.incubator.apache.org/api/python/io/io.html#mxnet.io.ImageRecordIter), support multi-threading, while others, like the `NDArrayIter`, do not.
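To make this separation concrete, here is a minimal sketch of a custom `Dataset` that applies a per-sample transform lazily, on access. The `NoisyImageDataset` class and its added-noise transform are hypothetical, purely for illustration:

```python
from mxnet.gluon.data import Dataset

class NoisyImageDataset(Dataset):
    """Hold images in memory and add noise to each sample on access."""
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __getitem__(self, idx):
        # runs per sample, possibly inside a DataLoader worker process
        noisy = self.images[idx] + 0.1 * np.random.rand(28, 28).astype('float32')
        return noisy, self.labels[idx]

    def __len__(self):
        return len(self.images)

# the DataLoader is only concerned with batching, shuffling and workers
noisy_loader = DataLoader(NoisyImageDataset(train_data, train_label),
                          batch_size=batch_size, shuffle=True, num_workers=0)
```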

Check out the [`Dataset` and `DataLoader` tutorial](https://mxnet.incubator.apache.org/tutorials/gluon/datasets.html). You can either rewrite your code to use one of the provided [`Dataset`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=dataset#mxnet.gluon.data.Dataset) classes, like the [`ArrayDataset`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=arraydataset#mxnet.gluon.data.ArrayDataset) or the [`ImageFolderDataset`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=imagefolderdataset#mxnet.gluon.data.vision.datasets.ImageFolderDataset), or you can simply wrap your existing [`DataIter`](https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=dataiter#mxnet.io.DataIter) to get a usage pattern similar to a `DataLoader`'s:


```python
class DataIterLoader():
    def __init__(self, data_iter):
        self.data_iter = data_iter

    def __iter__(self):
        self.data_iter.reset()
        return self

    def __next__(self):
        batch = self.data_iter.__next__()
        assert len(batch.data) == len(batch.label) == 1
        data = batch.data[0]
        label = batch.label[0]
        return data, label

    def next(self):
        return self.__next__()  # for Python 2
```


```python
data_iter = mx.io.NDArrayIter(data=train_data, label=train_label, batch_size=batch_size)
data_iter_loader = DataIterLoader(data_iter)
for data, label in data_iter_loader:
    print(data.shape, label)
    break
```

(5, 28, 28)
[5. 0. 3. 4. 9.]
<NDArray 5 @cpu(0)>


## II - Model definition

Let's look at the model definition from the [MNIST Module Tutorial](https://mxnet.incubator.apache.org/tutorials/python/mnist.html):


```python
ctx = mx.gpu()  # use mx.cpu() if no GPU is available
```

#### Module

For the Module API, you define the data flow by setting the `data` keyword argument of each layer to the output of the previous one.
You then bind the symbolic model to a specific compute context and specify the symbol names for the data and the label.

```python
data = mx.sym.var('data')
data = mx.sym.flatten(data=data)
fc1 = mx.sym.FullyConnected(data=data, num_hidden=128)
act1 = mx.sym.Activation(data=fc1, act_type="relu")
fc2 = mx.sym.FullyConnected(data=act1, num_hidden=64)
act2 = mx.sym.Activation(data=fc2, act_type="relu")
fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10)
mlp = mx.sym.SoftmaxOutput(data=fc3, name='softmax')

# Bind model to Module
mlp_model = mx.mod.Module(symbol=mlp, context=ctx, data_names=['data'], label_names=['softmax_label'])
```

#### Gluon

In Gluon, for a sequential model like this, you would create a `Sequential` block, in this case a `HybridSequential` block, since we are only using hybridizable blocks and want to allow for future hybridization. Learn more [about hybridization](https://mxnet.incubator.apache.org/tutorials/gluon/hybrid.html). The flow of data from one layer to the next is set automatically, since the layers are held in a `Sequential` block.


```python
net = nn.HybridSequential()
with net.name_scope():
    net.add(
        nn.Flatten(),
        nn.Dense(units=128, activation="relu"),
        nn.Dense(units=64, activation="relu"),
        nn.Dense(units=10)
    )
```
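If your network is not purely sequential, the usual Gluon alternative is to subclass `HybridBlock` and express the data flow in `hybrid_forward`. Here is a sketch of the same MLP written that way (the `MLP` class name is ours, for illustration):

```python
class MLP(nn.HybridBlock):
    def __init__(self, **kwargs):
        super(MLP, self).__init__(**kwargs)
        with self.name_scope():
            self.fc1 = nn.Dense(units=128, activation="relu")
            self.fc2 = nn.Dense(units=64, activation="relu")
            self.fc3 = nn.Dense(units=10)

    def hybrid_forward(self, F, x):
        # F is the mx.nd module before hybridization and mx.sym after
        x = F.flatten(x)
        return self.fc3(self.fc2(self.fc1(x)))
```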

## III - Loss

The loss that you minimize with an optimization algorithm like SGD is defined differently in the Module API than in Gluon.

In the Module API, the loss is part of the network. Its forward pass usually produces the inference value, and its backward pass produces the gradient with respect to that particular loss.

For example, [sym.SoftmaxOutput](https://mxnet.incubator.apache.org/api/python/symbol/symbol.html?highlight=softmaxout#mxnet.symbol.SoftmaxOutput) computes the softmax in the forward pass and the gradient of the cross-entropy loss in the backward pass.

In Gluon, this is a lot more transparent. Losses, like the [SoftmaxCrossEntropyLoss](https://mxnet.incubator.apache.org/api/python/gluon/loss.html?highlight=softmaxcross#mxnet.gluon.loss.SoftmaxCrossEntropyLoss), only compute the actual value of the loss. You then call `.backward()` on the loss value to compute the gradients of that loss with respect to the parameters. At inference time, you simply call `.softmax()` on your output to normalize the network output with the softmax function.

#### Module


```python
# Softmax with cross entropy loss, directly part of the network
mlp = mx.sym.SoftmaxOutput(data=fc3, name='softmax')
```

#### Gluon


```python
# We simply create a loss function we will use in our training loop
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
```
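As a quick sketch of this workflow: the loss function returns one value per sample, and `.softmax()` gives normalized probabilities. The random logits below are just placeholders standing in for real network outputs:

```python
logits = mx.nd.random.uniform(shape=(4, 10))  # placeholder network outputs
labels = mx.nd.array([1, 3, 0, 7])

loss = loss_fn(logits, labels)  # per-sample cross-entropy, shape (4,)
probs = logits.softmax()        # each row sums to 1, used at inference time
```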

## IV - Training Loop

The Module API provides a [`.fit()`](https://mxnet.incubator.apache.org/api/python/module/module.html?highlight=.fit#mxnet.module.BaseModule.fit) function that takes care of fitting training data to your symbolic model. With Gluon, your execution flow controls the data flow, so you need to write your own loop. It might seem more verbose, but you get much more control over what happens during training.

With the [`.fit()`](https://mxnet.incubator.apache.org/api/python/module/module.html?highlight=.fit#mxnet.module.BaseModule.fit) function, you control metric reporting, checkpointing, and more through a number of keyword arguments (check the [docs](https://mxnet.incubator.apache.org/api/python/module/module.html?highlight=.fit#mxnet.module.BaseModule.fit)). That is also where you define the optimizer, for example.

With Gluon, you do these operations directly in the training loop, and the optimizer is part of the [`Trainer`](https://mxnet.incubator.apache.org/api/python/gluon/gluon.html?highlight=trainer#mxnet.gluon.Trainer) object that handles the weight updates of your parameters, as shown in the sketch below.
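For instance, optimizer settings live on the `Trainer` rather than in `.fit()` keyword arguments. A minimal sketch, assuming `net` has already been initialized (as it is in the training loop further down):

```python
# the Trainer owns the optimizer and its state
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 0.1})

# optimizer settings are adjusted on the Trainer, e.g. a manual learning rate decay
trainer.set_learning_rate(trainer.learning_rate * 0.1)
```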

#### Module


```python
mlp_model.fit(data_iter,  # train data
              eval_data=data_iter,  # validation data
              optimizer='adam',  # use the Adam optimizer to train
              force_init=True,
              force_rebind=True,
              optimizer_params={'learning_rate': 0.1},  # use a fixed learning rate
              eval_metric='acc',  # report accuracy during training
              num_epoch=5)  # train for 5 full passes over the dataset
```

```INFO:root:Epoch[4] Train-accuracy=0.070000```<!--notebook-skip-line-->

```INFO:root:Epoch[4] Time cost=0.038```<!--notebook-skip-line-->

```INFO:root:Epoch[4] Validation-accuracy=0.125000```<!--notebook-skip-line-->

#### Gluon


```python
# Initialize network and trainer
net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})

# Pick a metric
metric = mx.metric.Accuracy()

for e in range(5):  # start of epoch

    for data, label in dataloader:  # start of mini-batch
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)

        with mx.autograd.record():
            output = net(data)  # forward pass
            loss = loss_fn(output, label)  # get loss
        loss.backward()  # compute gradients

        trainer.step(data.shape[0])  # update weights with SGD
        metric.update(label, output)  # update the metrics
    # end of mini-batch

    name, acc = metric.get()
    print('training metrics at epoch %d: %s=%f' % (e, name, acc))
    metric.reset()
# end of epoch
```

```training metrics at epoch 3: accuracy=0.155000```<!--notebook-skip-line-->

```training metrics at epoch 4: accuracy=0.145000```<!--notebook-skip-line-->
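The flip side of writing the loop yourself is that things `.fit()` did for you, like evaluation on a validation set, are now yours to write as well. A minimal sketch, where the `evaluate_accuracy` helper is ours and the training `dataloader` stands in for a real validation set:

```python
def evaluate_accuracy(net, dataloader, ctx):
    metric = mx.metric.Accuracy()
    for data, label in dataloader:
        output = net(data.as_in_context(ctx))
        metric.update(label.as_in_context(ctx), output)
    return metric.get()

name, val_acc = evaluate_accuracy(net, dataloader, ctx)
print('validation %s=%f' % (name, val_acc))
```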


## V - Exporting Models

The ultimate purpose of training a model is to be able to export and share it, whether for deployment or simply for reproducibility.

With the Module API, you can save the model using [`.save_checkpoint()`](https://mxnet.incubator.apache.org/api/python/module/module.html?highlight=save_chec#mxnet.module.Module.save_checkpoint) and get a `-symbol.json` and a `.params` file that represent your network.

With Gluon, network parameters are associated with a `Block`, but the execution flow is controlled in Python through the code in the `.forward()` function. Hence, only [hybridized networks](https://mxnet.incubator.apache.org/tutorials/gluon/hybrid.html) can be exported to a `-symbol.json` and `.params` file using [`.export()`](https://mxnet.incubator.apache.org/api/python/gluon/gluon.html?highlight=export#mxnet.gluon.HybridBlock.export), while non-hybridized models can only have their parameters exported using [`.save_parameters()`](https://mxnet.incubator.apache.org/api/python/gluon/gluon.html?highlight=save_pa#mxnet.gluon.Block.save_parameters). Check this great tutorial to learn more: [Saving and Loading Gluon Models](https://mxnet.incubator.apache.org/tutorials/gluon/save_load_params.html).

#### Module


```python
mlp_model.save_checkpoint('module-model', epoch=5)
# module-model-0005.params module-model-symbol.json
```

```INFO:root:Saved checkpoint to "module-model-0005.params"```<!--notebook-skip-line-->
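For reference, a checkpoint saved this way can be loaded back with the Module API. A sketch:

```python
# load the symbol and the trained weights back
sym, arg_params, aux_params = mx.model.load_checkpoint('module-model', 5)
mod = mx.mod.Module(symbol=sym, context=ctx,
                    data_names=['data'], label_names=['softmax_label'])
mod.bind(data_shapes=data_iter.provide_data, label_shapes=data_iter.provide_label)
mod.set_params(arg_params, aux_params)
```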

#### Gluon


```python
# save only the parameters
net.save_parameters('gluon-model.params')
# gluon-model.params
```


```python
# save the parameters and the symbolic representation
net.hybridize()
net(mx.nd.ones((1,1,28,28), ctx))

net.export('gluon-model-hybrid', epoch=5)
# gluon-model-hybrid-symbol.json gluon-model-hybrid-0005.params
```
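Loading works symmetrically. A short sketch (see the linked tutorial for full details):

```python
# parameters only: you need the same network definition in code
net.load_parameters('gluon-model.params', ctx=ctx)

# symbol + parameters: no Python definition needed
deserialized_net = gluon.nn.SymbolBlock.imports('gluon-model-hybrid-symbol.json',
                                                ['data'],
                                                'gluon-model-hybrid-0005.params',
                                                ctx=ctx)
```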

## Conclusion

This tutorial led you through the steps necessary to train a deep learning model and showed you the differences between the symbolic approach of the Module API and the imperative one of the Gluon API. If you need more help converting your Module API code to the Gluon API, reach out to the community on the [discuss forum](https://discuss.mxnet.io)!


<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
4 changes: 4 additions & 0 deletions tests/tutorials/test_tutorials.py
@@ -160,6 +160,9 @@ def test_python_data_augmentation_with_masks():
def test_python_kvstore():
    assert _test_tutorial_nb('python/kvstore')

def test_module_to_gluon():
    assert _test_tutorial_nb('python/module_to_gluon')

def test_python_types_of_data_augmentation():
    assert _test_tutorial_nb('python/types_of_data_augmentation')

@@ -189,3 +192,4 @@ def test_vision_cnn_visualization():

def test_control_flow():
    assert _test_tutorial_nb('control_flow/ControlFlowTutorial')