Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Address PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
piyushghai committed Apr 5, 2019
1 parent 15de480 commit 90c5aea
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions docs/tutorials/gluon/fit_api_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
<!--- under the License. -->


# Gluon Fit API
# MXNet Gluon Fit API

In this tutorial, we will see how to use the [Gluon Fit API](https://cwiki.apache.org/confluence/display/MXNET/Gluon+Fit+API+-+Tech+Design) which is a simple and flexible way to train deep learning models using the [Gluon APIs](http://mxnet.incubator.apache.org/versions/master/gluon/index.html) in Apache MXNet.
In this tutorial, we will see how to use the [Gluon Fit API](https://cwiki.apache.org/confluence/display/MXNET/Gluon+Fit+API+-+Tech+Design) which is the easiest way to train deep learning models using the [Gluon API](http://mxnet.incubator.apache.org/versions/master/gluon/index.html) in Apache MXNet.

Prior to Fit API, training using Gluon required one to write a custom ["Gluon training loop"](https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/logistic_regression_explained.html#defining-and-training-the-model). Fit API reduces the complexity and amount of boiler plate code required to train a model, provides an easy to use and a powerful API.
With the Fit API, you can train a deep learning model with miminal amount of code. Just specify the network, loss function and the data you want to train on. You don't need to worry about the boiler plate code to loop through the dataset in batches(often called as 'training loop'). Advanced users can still do this for bespolke training loops, but most use cases will be covered by the Fit API.

To demonstrate the Fit API, this tutorial will train an Image Classification model using the [ResNet-18](https://arxiv.org/abs/1512.03385) architecture for the neural network. The model will be trained using the [Fashion-MNIST dataset](https://research.zalando.com/welcome/mission/research-projects/fashion-mnist/).

Expand All @@ -47,12 +47,12 @@ mx.random.seed(7) # Set a fixed seed

## Dataset

[Fashion-MNIST](https://research.zalando.com/welcome/mission/research-projects/fashion-mnist/) dataset consists of fashion items divided into ten categories : t-shirt/top, trouser, pullover, dress, coat, sandal, shirt, sneaker, bag and ankle boot.
[Fashion-MNIST](https://research.zalando.com/welcome/mission/research-projects/fashion-mnist/) dataset consists of fashion items divided into ten categories: t-shirt/top, trouser, pullover, dress, coat, sandal, shirt, sneaker, bag and ankle boot.

- It has 60,000 gray scale images of size 28 * 28 for training.
- It has 10,000 gray scale images os size 28 * 28 for testing/validation.

We will use ```gluon.data.vision``` package to directly import the Fashion-MNIST dataset and perform pre-processing on it.
We will use the ```gluon.data.vision``` package to directly import the Fashion-MNIST dataset and perform pre-processing on it.


```python
Expand Down Expand Up @@ -90,23 +90,25 @@ val_data_loader = gluon.data.DataLoader(fashion_mnist_val, batch_size=batch_size

## Model and Optimizers

Let's load the resnet-18 model architecture from [Gluon Model Zoo](http://mxnet.apache.org/api/python/gluon/model_zoo.html) and initialize it's parameters.
Let's load the resnet-18 model architecture from [Gluon Model Zoo](http://mxnet.apache.org/api/python/gluon/model_zoo.html) and initialize it's parameters. The Gluon Model Zoo contains a repository of pre-trained models as well the model architecture definitions. We are using the model architecture from the model zoo in order to train it from scratch.


```python
resnet_18_v1 = vision.resnet18_v1(pretrained=False, classes = 10, ctx=ctx)
resnet_18_v1.initialize(force_reinit=True, init = mx.init.Xavier(), ctx=ctx)
resnet_18_v1 = vision.resnet18_v1(pretrained=False, classes = 10)
resnet_18_v1.initialize(init = mx.init.Xavier(), ctx=ctx)
```

After defining the model, let's setup the trainer object for training.

We will be using ```SoftmaxCrossEntropyLoss``` as the loss function since this is a multi-class classification problem. We will be using ```sgd``` (Stochastic Gradient Descent) as the optimizer. You can experiment with a different optimizer as well.


```python
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
learning_rate = 0.04 # You can experiment with your own learning rate here
```

Let's define the trainer object for training the model.

```python
learning_rate = 0.04 # You can experiment with your own learning rate here
num_epochs = 2 # You can run training for more epochs
trainer = gluon.Trainer(resnet_18_v1.collect_params(),
'sgd', {'learning_rate': learning_rate})
Expand Down Expand Up @@ -147,9 +149,9 @@ est.fit(train_data=train_data_loader,

### Advanced Usage

Fit API is also customizable with several `Event Handlers` which give a fine grained control over the steps in training and exposes callback methods for : `train_begin`, `train_end`, `batch_begin`, `batch_end`, `epoch_begin` and `epoch_end`.
Fit API is also customizable with several `Event Handlers` which give a fine grained control over the steps in training and exposes callback methods that provide control over the stages involved in training. Available callback methods are: `train_begin`, `train_end`, `batch_begin`, `batch_end`, `epoch_begin` and `epoch_end`.

One can use built-in event handlers such as ```LoggingHandler```, ```CheckpointHandler``` or ```EarlyStoppingHandler``` or to create a custom handler, one can create a new class by inherinting [```EventHandler```](https://github.com/apache/incubator-mxnet/blob/fit-api/python/mxnet/gluon/estimator/event_handler.py#L31).
One can use built-in event handlers such as `LoggingHandler`, `CheckpointHandler` or `EarlyStoppingHandler` to log and save the model at certain timesteps during training and stopping the training when the model's performance plateaus. One can also create a custom handler by inheriting [`EventHandler`](https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/gluon/estimator/event_handler.py#L31).


```python
Expand Down Expand Up @@ -181,7 +183,7 @@ checkpoint_handler = event_handler.CheckpointHandler(estimator=est,
# Magic line
est.fit(train_data=train_data_loader,
epochs=num_epochs,
event_handlers=[checkpoint_handler], # Add the event handlers
event_handlers=checkpoint_handler, # Add the event handlers
batch_size=batch_size)
```

Expand All @@ -198,7 +200,7 @@ You can load the saved model, by using ```load_parameters``` API in Gluon. For m
## Summary

In this tutorial, we learnt how to use ```Gluon Fit APIs``` for training a deep learning model and also saw an option to customize it with the use of Event Handlers.
For more references on the Fit API and advanced usage details, checkout its [documentation](http://mxnet.apache.org/api/python/gluon/gluon.html).
For more references and advanced usage details can be found in the [documentation](http://mxnet.apache.org/api/python/gluon/gluon.html).

## Next Steps

Expand Down

0 comments on commit 90c5aea

Please sign in to comment.