diff --git a/docs/tutorials/gluon/fit_api_tutorial.md b/docs/tutorials/gluon/fit_api_tutorial.md
index 3c541ddc3757..6cd48fbc2391 100644
--- a/docs/tutorials/gluon/fit_api_tutorial.md
+++ b/docs/tutorials/gluon/fit_api_tutorial.md
@@ -24,7 +24,6 @@ With the Fit API, you can train a deep learning model with miminal amount of cod
 To demonstrate the Fit API, this tutorial will train an Image Classification model using the [ResNet-18](https://arxiv.org/abs/1512.03385) architecture for the neural network.
 The model will be trained using the [Fashion-MNIST dataset](https://research.zalando.com/welcome/mission/research-projects/fashion-mnist/).
 
-
 ## Prerequisites
 
 To complete this tutorial, you will need:
@@ -41,7 +40,8 @@ from mxnet import gluon
 from mxnet.gluon.model_zoo import vision
 from mxnet.gluon.estimator import estimator, event_handler
 
-ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()
+gpu_count = mx.context.num_gpus()
+ctx = [mx.gpu(i) for i in range(gpu_count)] if gpu_count > 0 else mx.cpu()
 mx.random.seed(7) # Set a fixed seed
 ```
 
@@ -84,8 +84,10 @@ fashion_mnist_val = fashion_mnist_val.transform_first(transforms)
 batch_size = 256 # Batch size of the images
 num_workers = 4 # The number of parallel workers for loading the data using Data Loaders.
 
-train_data_loader = gluon.data.DataLoader(fashion_mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
-val_data_loader = gluon.data.DataLoader(fashion_mnist_val, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+train_data_loader = gluon.data.DataLoader(fashion_mnist_train, batch_size=batch_size,
+                                          shuffle=True, num_workers=num_workers)
+val_data_loader = gluon.data.DataLoader(fashion_mnist_val, batch_size=batch_size,
+                                        shuffle=False, num_workers=num_workers)
 ```
 
 ## Model and Optimizers
@@ -107,6 +109,7 @@ loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
 
 Let's define the trainer object for training the model.
 
+
 ```python
 learning_rate = 0.04 # You can experiment with your own learning rate here
 num_epochs = 2 # You can run training for more epochs
@@ -128,23 +131,23 @@ train_acc = mx.metric.Accuracy() # Metric to monitor
 
 # Define the estimator, by passing to it the model, loss function, metrics, trainer object and context
 est = estimator.Estimator(net=resnet_18_v1,
-                          loss=loss_fn,
-                          metrics=train_acc,
-                          trainers=trainer,
-                          context=ctx)
+                          loss=loss_fn,
+                          metrics=train_acc,
+                          trainer=trainer,
+                          context=ctx)
 
 # Magic line
 est.fit(train_data=train_data_loader,
-        epochs=num_epochs,
-        batch_size=batch_size)
+        epochs=num_epochs)
 ```
 
-    [Epoch 0] [Step 256/60000] time/step: 1.420s accuracy: 0.0938 softmaxcrossentropyloss0: 2.9419
-    ....
-    [Epoch 0] finished in 51.375s: train_accuracy: 0.7916 train_softmaxcrossentropyloss0: 0.5750
-    [Epoch 1] [Step 256/60000] time/step: 0.414s accuracy: 0.8555 softmaxcrossentropyloss0: 0.3621
-    ....
-    [Epoch 1] finished in 49.889s: train_accuracy: 0.8854 train_softmaxcrossentropyloss0: 0.3157
+    Training begin: using optimizer SGD with current learning rate 0.0400
+    Train for 2 epochs.
+
+    [Epoch 0] finished in 25.110s: train_accuracy : 0.7877 train_softmaxcrossentropyloss0 : 0.5905
+
+    [Epoch 1] finished in 23.595s: train_accuracy : 0.8823 train_softmaxcrossentropyloss0 : 0.3197
+    Train finished using total 48s at epoch 1. train_accuracy : 0.8823 train_softmaxcrossentropyloss0 : 0.3197
 
 ### Advanced Usage
 
@@ -161,41 +164,47 @@ resnet_18_v1.initialize(force_reinit=True, init = mx.init.Xavier(), ctx=ctx)
 trainer = gluon.Trainer(resnet_18_v1.collect_params(), 'sgd', {'learning_rate': learning_rate})
 train_acc = mx.metric.Accuracy()
 
-
 ```
 
 ```python
 # Define the estimator, by passing to it the model, loss function, metrics, trainer object and context
 est = estimator.Estimator(net=resnet_18_v1,
-                          loss=loss_fn,
-                          metrics=train_acc,
-                          trainers=trainer,
-                          context=ctx)
+                          loss=loss_fn,
+                          metrics=train_acc,
+                          trainer=trainer,
+                          context=ctx)
 
 # Define the handlers, let's say Checkpointhandler
-checkpoint_handler = event_handler.CheckpointHandler(estimator=est,
-                                                     filepath='./my_best_model.params',
-                                                     monitor='train_accuracy', # Monitors a metric
+checkpoint_handler = event_handler.CheckpointHandler(filepath='./my_best_model.params',
+                                                     monitor='val_accuracy', # Monitors a metric
                                                      save_best_only=True) # Save the best model in terms of
                                                                           # training accuracy
 
 # Magic line
 est.fit(train_data=train_data_loader,
-        epochs=num_epochs,
-        event_handlers=checkpoint_handler, # Add the event handlers
-        batch_size=batch_size)
+        val_data=val_data_loader,
+        epochs=num_epochs,
+        event_handlers=checkpoint_handler) # Add the event handlers
 ```
 
-    [Epoch 0] [Step 256/60000] time/step: 0.426s accuracy: 0.1211 softmaxcrossentropyloss0: 2.6261
-    ....
-    [Epoch 0] finished in 50.390s: train_accuracy: 0.7936 train_softmaxcrossentropyloss0: 0.5639
-    [Epoch 1] [Step 256/60000] time/step: 0.414s accuracy: 0.8984 softmaxcrossentropyloss0: 0.2958
-    ....
-    [Epoch 1] finished in 50.474s: train_accuracy: 0.8871 train_softmaxcrossentropyloss0: 0.3101
+    Training begin: using optimizer SGD with current learning rate 0.0400
+    Train for 2 epochs.
+
+    [Epoch 0] finished in 25.236s: train_accuracy : 0.7917 train_softmaxcrossentropyloss0 : 0.5741 val_accuracy : 0.6612 val_softmaxcrossentropyloss0 : 0.8627
+
+    [Epoch 1] finished in 24.892s: train_accuracy : 0.8826 train_softmaxcrossentropyloss0 : 0.3229 val_accuracy : 0.8474 val_softmaxcrossentropyloss0 : 0.4262
+
+    Train finished using total 50s at epoch 1. train_accuracy : 0.8826 train_softmaxcrossentropyloss0 : 0.3229 val_accuracy : 0.8474 val_softmaxcrossentropyloss0 : 0.4262
+
+You can load the saved model, by using ```load_parameters``` API in Gluon. For more details refer to the [Loading model parameters from file tutorial](http://mxnet.incubator.apache.org/versions/master/tutorials/gluon/save_load_params.html#saving-model-parameters-to-file)
 
-You can load the saved model, by using ```load_parameters``` API in Gluon. For more details refer to the [Loading model parameters from file tutorial](http://mxnet.incubator.apache.org/versions/master/tutorials/gluon/save_load_params.html#saving-model-parameters-to-file)
+
+```python
+resnet_18_v1 = vision.resnet18_v1(pretrained=False, classes = 10)
+resnet_18_v1.load_parameters('./my_best_model.params', ctx=ctx)
+```
 
 ## Summary