From a7fc95ac89838697ab827b25658de90a3d0165f7 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Thu, 1 Aug 2019 15:09:04 -0700 Subject: [PATCH] [MXNET-1358] Fit api tutorial (#15353) * Added tutorial for FIT API * Added tests for Fit API tutorial * Updated index.md for the new tutorial to show up * Addressed PR feedback * Addressed PR feedback * Removed spurious comment for Py2 and Py3 compatibility * Address PR feedback * Addressed PR feedback * Fixed typo * Added example to showcase custom event handler * Fixed imports as estimator moved to contrib package * Added a side note to inform about estimator reference being updated by the handlers * Corrected typo * update tutorial * address comments * new line * fix import * fix cached graph * fix import * address comments * fix doc gen * add softmax * add to website index * fix doc string * Fix doc gen (#12) * fix warining * fix test * fix * fix * fix print * fix test (#13) * fix warning (#14) * fix href (#15) --- docs/api/python/gluon/contrib.md | 30 ++ docs/tutorials/gluon/fit_api_tutorial.md | 271 ++++++++++++++++++ docs/tutorials/index.md | 2 + python/mxnet/gluon/contrib/__init__.py | 2 + .../mxnet/gluon/contrib/estimator/__init__.py | 2 + .../gluon/contrib/estimator/estimator.py | 64 +++-- .../gluon/contrib/estimator/event_handler.py | 23 +- tests/python/unittest/test_gluon_estimator.py | 7 +- tests/tutorials/test_tutorials.py | 3 + 9 files changed, 367 insertions(+), 37 deletions(-) create mode 100644 docs/tutorials/gluon/fit_api_tutorial.md diff --git a/docs/api/python/gluon/contrib.md b/docs/api/python/gluon/contrib.md index a940f697de69..22cdebb53b85 100644 --- a/docs/api/python/gluon/contrib.md +++ b/docs/api/python/gluon/contrib.md @@ -114,6 +114,33 @@ In the rest of this document, we list routines provided by the `gluon.contrib` p WikiText103 ``` +### Estimator + +```eval_rst +.. currentmodule:: mxnet.gluon.contrib.estimator + +.. autosummary:: + :nosignatures: + + Estimator +``` + +#### EventHandler + +```eval_rst +.. currentmodule:: mxnet.gluon.contrib.estimator + +.. autosummary:: + :nosignatures: + + StoppingHandler + MetricHandler + ValidationHandler + LoggingHandler + CheckpointHandler + EarlyStoppingHandler +``` + ## API Reference @@ -144,6 +171,9 @@ In the rest of this document, we list routines provided by the `gluon.contrib` p :members: :imported-members: +.. automodule:: mxnet.gluon.contrib.estimator + :members: + :imported-members: ``` diff --git a/docs/tutorials/gluon/fit_api_tutorial.md b/docs/tutorials/gluon/fit_api_tutorial.md new file mode 100644 index 000000000000..bc50690ac1a2 --- /dev/null +++ b/docs/tutorials/gluon/fit_api_tutorial.md @@ -0,0 +1,271 @@ + + + + + + + + + + + + + + + + + + +# MXNet Gluon Fit API + +In this tutorial, you will learn how to use the [Gluon Fit API](https://cwiki.apache.org/confluence/display/MXNET/Gluon+Fit+API+-+Tech+Design) which is the easiest way to train deep learning models using the [Gluon API](http://mxnet.incubator.apache.org/versions/master/gluon/index.html) in Apache MXNet. + +With the Fit API, you can train a deep learning model with a minimal amount of code. Just specify the network, loss function and the data you want to train on. You don't need to worry about the boiler plate code to loop through the dataset in batches (often called as 'training loop'). Advanced users can train with bespoke training loops, and many of these use cases will be covered by the Fit API. + +To demonstrate the Fit API, you will train an image classification model using the [ResNet-18](https://arxiv.org/abs/1512.03385) neural network architecture. The model will be trained using the [Fashion-MNIST dataset](https://research.zalando.com/welcome/mission/research-projects/fashion-mnist/). + +## Prerequisites + +To complete this tutorial, you will need: + +- [MXNet](https://mxnet.incubator.apache.org/install/#overview) (The version of MXNet will be >= 1.5.0, you can use `pip install mxnet` to get 1.5.0 release pip package or build from source with master, refer to [MXNet installation](http://mxnet.incubator.apache.org/versions/master/install/index.html?platform=Linux&language=Python&processor=CPU) +- [Jupyter Notebook](https://jupyter.org/index.html) (For interactively running the provided .ipynb file) + + + + +```python +import mxnet as mx +from mxnet import gluon +from mxnet.gluon.model_zoo import vision +from mxnet.gluon.contrib.estimator import estimator +from mxnet.gluon.contrib.estimator.event_handler import TrainBegin, TrainEnd, EpochEnd, CheckpointHandler + +gpu_count = mx.context.num_gpus() +ctx = [mx.gpu(i) for i in range(gpu_count)] if gpu_count > 0 else mx.cpu() +``` + +## Dataset + +[Fashion-MNIST](https://research.zalando.com/welcome/mission/research-projects/fashion-mnist/) dataset consists of fashion items divided into ten categories: t-shirt/top, trouser, pullover, dress, coat, sandal, shirt, sneaker, bag and ankle boot. + +- It has 60,000 grayscale images of size 28 * 28 for training. +- It has 10,000 grayscale images of size 28 * 28 for testing/validation. + +We will use the ```gluon.data.vision``` package to directly import the Fashion-MNIST dataset and perform pre-processing on it. + + +```python +# Get the training data +fashion_mnist_train = gluon.data.vision.FashionMNIST(train=True) + +# Get the validation data +fashion_mnist_val = gluon.data.vision.FashionMNIST(train=False) +``` + + +```python +transforms = [gluon.data.vision.transforms.Resize(224), # We pick 224 as the model we use takes an input of size 224. + gluon.data.vision.transforms.ToTensor()] + +# Now we will stack all these together. +transforms = gluon.data.vision.transforms.Compose(transforms) +``` + + +```python +# Apply the transformations +fashion_mnist_train = fashion_mnist_train.transform_first(transforms) +fashion_mnist_val = fashion_mnist_val.transform_first(transforms) +``` + + +```python +batch_size = 256 # Batch size of the images +num_workers = 4 # The number of parallel workers for loading the data using Data Loaders. + +train_data_loader = gluon.data.DataLoader(fashion_mnist_train, batch_size=batch_size, + shuffle=True, num_workers=num_workers) +val_data_loader = gluon.data.DataLoader(fashion_mnist_val, batch_size=batch_size, + shuffle=False, num_workers=num_workers) +``` + +## Model and Optimizers + +Let's load the resnet-18 model architecture from [Gluon Model Zoo](http://mxnet.apache.org/api/python/gluon/model_zoo.html) and initialize its parameters. The Gluon Model Zoo contains a repository of pre-trained models as well the model architecture definitions. We are using the model architecture from the model zoo in order to train it from scratch. + + +```python +resnet_18_v1 = vision.resnet18_v1(pretrained=False, classes = 10) +resnet_18_v1.initialize(init = mx.init.Xavier(), ctx=ctx) +``` + +We will be using `SoftmaxCrossEntropyLoss` as the loss function since this is a multi-class classification problem. We will be using `sgd` (Stochastic Gradient Descent) as the optimizer. +You can experiment with a [different loss](http://mxnet.incubator.apache.org/versions/master/api/python/gluon/loss.html) or [optimizer](http://mxnet.incubator.apache.org/versions/master/api/python/optimization/optimization.html) as well. + + +```python +loss_fn = gluon.loss.SoftmaxCrossEntropyLoss() +``` + +Let's define the trainer object for training the model. + + +```python +learning_rate = 0.04 # You can experiment with your own learning rate here +num_epochs = 2 # You can run training for more epochs +trainer = gluon.Trainer(resnet_18_v1.collect_params(), + 'sgd', {'learning_rate': learning_rate}) +``` + +## Train using Fit API + +As stated earlier, the Fit API greatly simplifies the boiler plate code and complexity for training using MXNet Gluon. + +In the basic usage example, with just 2 lines of code, we will set up our model for training. + +### Basic Usage + + +```python +train_acc = mx.metric.Accuracy() # Metric to monitor + +# Define the estimator, by passing to it the model, loss function, metrics, trainer object and context +est = estimator.Estimator(net=resnet_18_v1, + loss=loss_fn, + metrics=train_acc, + trainer=trainer, + context=ctx) + +# ignore warnings for nightly test on CI only +import warnings +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # Magic line + est.fit(train_data=train_data_loader, + epochs=num_epochs) +``` + + Training begin: using optimizer SGD with current learning rate 0.0400 + Train for 2 epochs. + + [Epoch 0] finished in 25.110s: train_accuracy : 0.7877 train_softmaxcrossentropyloss0 : 0.5905 + + [Epoch 1] finished in 23.595s: train_accuracy : 0.8823 train_softmaxcrossentropyloss0 : 0.3197 + Train finished using total 48s at epoch 1. train_accuracy : 0.8823 train_softmaxcrossentropyloss0 : 0.3197 + + +### Advanced Usage + +The Fit API is also customizable with several `Event Handlers` which give a fine grained control over the steps in training and exposes callback methods that provide control over the stages involved in training. Available callback methods are: `train_begin`, `train_end`, `batch_begin`, `batch_end`, `epoch_begin` and `epoch_end`. + +You can use built-in event handlers such as `LoggingHandler`, `CheckpointHandler` or `EarlyStoppingHandler` to log and save the model at certain time-steps during training. You can also stop the training when the model's performance plateaus. +There are also some default utility handlers that will be added to your estimator by default. For example, `StoppingHandler` is used to control when the training ends, based on number of epochs or number of batches trained. +`MetricHandler` is used to calculate training metrics at end of each batch and epoch. +`ValidationHandler` is used to validate your model on test data at each epoch's end and then calculate validation metrics. +You can create these utility handlers with different configurations and pass to estimator. This will override the default handler configuration. +You can create a custom handler by inheriting one or multiple +[base event handlers](https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/gluon/contrib/estimator/event_handler.py#L32) + including: `TrainBegin`, `TrainEnd`, `EpochBegin`, `EpochEnd`, `BatchBegin`, `BatchEnd`. + + +### Custom Event Handler + +Here we will showcase an example custom event handler the inherits features from a few base handler classes. +Our custom event handler is a simple one: record the loss values at the end of every epoch in our training phase. + +Note: For each of the method, the `Estimator` object is passed along, so you can access training metrics. + +```python +class LossRecordHandler(TrainBegin, TrainEnd, EpochEnd): + def __init__(self): + super(LossRecordHandler, self).__init__() + self.loss_history = {} + + def train_begin(self, estimator, *args, **kwargs): + print("Training begin") + + def train_end(self, estimator, *args, **kwargs): + # Print all the losses at the end of training + print("Training ended") + for loss_name in self.loss_history: + for i, loss_val in enumerate(self.loss_history[loss_name]): + print("Epoch: {}, Loss name: {}, Loss value: {}".format(i, loss_name, loss_val)) + + def epoch_end(self, estimator, *args, **kwargs): + for metric in estimator.train_metrics: + # look for train Loss in training metrics + # we wrapped loss value as a metric to record it + if isinstance(metric, mx.metric.Loss): + loss_name, loss_val = metric.get() + # append loss value for this epoch + self.loss_history.setdefault(loss_name, []).append(loss_val) +``` + + +```python +# Let's reset the model, trainer and accuracy objects from above + +resnet_18_v1.initialize(force_reinit=True, init = mx.init.Xavier(), ctx=ctx) +trainer = gluon.Trainer(resnet_18_v1.collect_params(), + 'sgd', {'learning_rate': learning_rate}) +train_acc = mx.metric.Accuracy() +``` + + +```python +# Define the estimator, by passing to it the model, loss function, metrics, trainer object and context +est = estimator.Estimator(net=resnet_18_v1, + loss=loss_fn, + metrics=train_acc, + trainer=trainer, + context=ctx) + +# Define the handlers, let's say in built Checkpointhandler +checkpoint_handler = CheckpointHandler(model_dir='./', + model_prefix='my_model', + monitor=train_acc, # Monitors a metric + save_best=True) # Save the best model in terms of +# Let's instantiate another handler which we defined above +loss_record_handler = LossRecordHandler() +# ignore warnings for nightly test on CI only +import warnings +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # Magic line + est.fit(train_data=train_data_loader, + val_data=val_data_loader, + epochs=num_epochs, + event_handlers=[checkpoint_handler, loss_record_handler]) # Add the event handlers +``` + + Training begin: using optimizer SGD with current learning rate 0.0400 + Train for 2 epochs. + + [Epoch 0] finished in 25.236s: train_accuracy : 0.7917 train_softmaxcrossentropyloss0 : 0.5741 val_accuracy : 0.6612 val_softmaxcrossentropyloss0 : 0.8627 + + [Epoch 1] finished in 24.892s: train_accuracy : 0.8826 train_softmaxcrossentropyloss0 : 0.3229 val_accuracy : 0.8474 val_softmaxcrossentropyloss0 : 0.4262 + + Train finished using total 50s at epoch 1. train_accuracy : 0.8826 train_softmaxcrossentropyloss0 : 0.3229 val_accuracy : 0.8474 val_softmaxcrossentropyloss0 : 0.4262 + + Training begin + Epoch 1, loss 0.5741 + Epoch 2, loss 0.3229 + +You can load the saved model, by using the `load_parameters` API in Gluon. For more details refer to the [Loading model parameters from file tutorial](save_load_params.html#saving-model-parameters-to-file) + + +```python +resnet_18_v1 = vision.resnet18_v1(pretrained=False, classes=10) +resnet_18_v1.load_parameters('./my_model-best.params', ctx=ctx) +``` + +## Summary + +- To learn more about deep learning with MXNeT, see [Dive Into Deep Learning](http://gluon.io) + +## Next Steps + +- For more hands on learning about deep learning, check out [Dive into Deep Learning](https://d2l.ai) + + diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md index e01a30dbe68c..f773a79f63a7 100644 --- a/docs/tutorials/index.md +++ b/docs/tutorials/index.md @@ -139,6 +139,8 @@ Select API:  * [Data Transforms](/tutorials/gluon/transforms.html) * [Applying Data Augmentation](/tutorials/gluon/data_augmentation.html) * [Data Augmentation with Masks (for Object Segmentation)](https://mxnet.incubator.apache.org/tutorials/python/data_augmentation_with_masks.html) + * Fit API + * [Using Fit API](/tutorials/gluon/fit_api_tutorial.html)
diff --git a/python/mxnet/gluon/contrib/__init__.py b/python/mxnet/gluon/contrib/__init__.py index 83be8a39ba32..7590eb740f67 100644 --- a/python/mxnet/gluon/contrib/__init__.py +++ b/python/mxnet/gluon/contrib/__init__.py @@ -25,3 +25,5 @@ from . import cnn from . import data + +from . import estimator diff --git a/python/mxnet/gluon/contrib/estimator/__init__.py b/python/mxnet/gluon/contrib/estimator/__init__.py index 58600dadffb4..bb0a0917c363 100644 --- a/python/mxnet/gluon/contrib/estimator/__init__.py +++ b/python/mxnet/gluon/contrib/estimator/__init__.py @@ -17,5 +17,7 @@ # pylint: disable=wildcard-import """Gluon Estimator Module""" +from . import estimator +from . import event_handler from .estimator import * from .event_handler import * diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index da1a3915caec..b6142e100d96 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -24,9 +24,15 @@ from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd -from .... import gluon, autograd +from ...data import DataLoader +from ...loss import SoftmaxCrossEntropyLoss +from ...loss import Loss as gluon_loss +from ...trainer import Trainer +from ...utils import split_and_load +from .... import autograd from ....context import Context, cpu, gpu, num_gpus -from ....metric import EvalMetric, Loss, Accuracy +from ....metric import EvalMetric, Accuracy +from ....metric import Loss as metric_loss __all__ = ['Estimator'] @@ -69,9 +75,9 @@ def __init__(self, net, self.trainer = self._check_trainer(trainer) def _check_loss(self, loss): - if isinstance(loss, gluon.loss.Loss): + if isinstance(loss, gluon_loss): loss = [loss] - elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]): + elif isinstance(loss, list) and all([isinstance(l, gluon_loss) for l in loss]): loss = loss else: raise ValueError("loss must be a Loss or a list of Loss, " @@ -146,9 +152,9 @@ def _check_trainer(self, trainer): if not trainer: warnings.warn("No trainer specified, default SGD optimizer " "with learning rate 0.001 is used.") - trainer = gluon.Trainer(self.net.collect_params(), - 'sgd', {'learning_rate': 0.001}) - elif not isinstance(trainer, gluon.Trainer): + trainer = Trainer(self.net.collect_params(), + 'sgd', {'learning_rate': 0.001}) + elif not isinstance(trainer, Trainer): raise ValueError("Trainer must be a Gluon Trainer instance, refer to " "gluon.Trainer:{}".format(trainer)) return trainer @@ -165,8 +171,8 @@ def _is_initialized(self): def _get_data_and_label(self, batch, ctx, batch_axis=0): data = batch[0] label = batch[1] - data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=batch_axis) - label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=batch_axis) + data = split_and_load(data, ctx_list=ctx, batch_axis=batch_axis) + label = split_and_load(label, ctx_list=ctx, batch_axis=batch_axis) return data, label def prepare_loss_and_metrics(self): @@ -179,13 +185,13 @@ def prepare_loss_and_metrics(self): """ if any(not hasattr(self, attribute) for attribute in ['train_metrics', 'val_metrics']): - # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss() - if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]): + # Use default mx.metric.Accuracy() for SoftmaxCrossEntropyLoss() + if not self.train_metrics and any([isinstance(l, SoftmaxCrossEntropyLoss) for l in self.loss]): self.train_metrics = [Accuracy()] self.val_metrics = [] for loss in self.loss: # remove trailing numbers from loss name to avoid confusion - self.train_metrics.append(Loss(loss.name.rstrip('1234567890'))) + self.train_metrics.append(metric_loss(loss.name.rstrip('1234567890'))) for metric in self.train_metrics: val_metric = copy.deepcopy(metric) metric.name = "train " + metric.name @@ -208,10 +214,10 @@ def evaluate(self, batch_axis : int, default 0 Batch axis to split the validation data into devices. """ - if not isinstance(val_data, gluon.data.DataLoader): + if not isinstance(val_data, DataLoader): raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " "can transform your DataIter or any NDArray into Gluon DataLoader. " - "Refer to gluon.data.dataloader") + "Refer to gluon.data.DataLoader") for metric in val_metrics: metric.reset() @@ -222,7 +228,7 @@ def evaluate(self, loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)] # update metrics for metric in val_metrics: - if isinstance(metric, Loss): + if isinstance(metric, metric_loss): metric.update(0, loss) else: metric.update(label, pred) @@ -254,7 +260,7 @@ def fit(self, train_data, batch_axis : int, default 0 Batch axis to split the training data into devices. """ - if not isinstance(train_data, gluon.data.DataLoader): + if not isinstance(train_data, DataLoader): raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " "can transform your DataIter or any NDArray into Gluon DataLoader. " "Refer to gluon.data.dataloader") @@ -328,28 +334,36 @@ def fit(self, train_data, def _prepare_default_handlers(self, val_data, event_handlers): event_handlers = event_handlers or [] default_handlers = [] - train_metrics, val_metrics = self.prepare_loss_and_metrics() + self.prepare_loss_and_metrics() # no need to add to default handler check as StoppingHandler does not use metrics event_handlers.append(StoppingHandler(self.max_epoch, self.max_batch)) + default_handlers.append("StoppingHandler") if not any(isinstance(handler, MetricHandler) for handler in event_handlers): - event_handlers.append(MetricHandler(train_metrics=train_metrics)) + event_handlers.append(MetricHandler(train_metrics=self.train_metrics)) default_handlers.append("MetricHandler") - if val_data and not any(isinstance(handler, ValidationHandler) for handler in event_handlers): - event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate, - val_metrics=val_metrics)) - default_handlers.append("ValidationHandler") + if not any(isinstance(handler, ValidationHandler) for handler in event_handlers): + # no validation handler + if val_data: + # add default validation handler if validation data found + event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate, + val_metrics=self.val_metrics)) + default_handlers.append("ValidationHandler") + val_metrics = self.val_metrics + else: + # set validation metrics to None if no validation data and no validation handler + val_metrics = [] if not any(isinstance(handler, LoggingHandler) for handler in event_handlers): - event_handlers.append(LoggingHandler(train_metrics=train_metrics, + event_handlers.append(LoggingHandler(train_metrics=self.train_metrics, val_metrics=val_metrics)) default_handlers.append("LoggingHandler") # if there is a mix of user defined event handlers and default event handlers # they should have the same set of loss and metrics - if default_handlers: + if default_handlers and len(event_handlers) != len(default_handlers): msg = "You are training with the following default event handlers: %s. " \ "They use loss and metrics from estimator.prepare_loss_and_metrics(). " \ "Please use the same set of metrics for all your other handlers." % \ @@ -368,7 +382,7 @@ def _prepare_default_handlers(self, val_data, event_handlers): # remove None metric references references = set([ref for ref in references if ref]) for metric in references: - if metric not in train_metrics + val_metrics: + if metric not in self.train_metrics + self.val_metrics: msg = "We have added following default handlers for you: %s and used " \ "estimator.prepare_loss_and_metrics() to pass metrics to " \ "those handlers. Please use the same set of metrics " \ diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index ed97c7bc3d19..da2c84455e35 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -26,7 +26,12 @@ import numpy as np -from ....metric import EvalMetric, Loss +from ....metric import EvalMetric +from ....metric import Loss as metric_loss + +__all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd', + 'StoppingHandler', 'MetricHandler', 'ValidationHandler', + 'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler'] class TrainBegin(object): @@ -127,7 +132,7 @@ def batch_end(self, estimator, *args, **kwargs): label = kwargs['label'] loss = kwargs['loss'] for metric in self.train_metrics: - if isinstance(metric, Loss): + if isinstance(metric, metric_loss): # metric wrapper for loss values metric.update(0, loss) else: @@ -135,7 +140,7 @@ def batch_end(self, estimator, *args, **kwargs): class ValidationHandler(TrainBegin, BatchEnd, EpochEnd): - """"Validation Handler that evaluate model on validation dataset + """Validation Handler that evaluate model on validation dataset :py:class:`ValidationHandler` takes validation dataset, an evaluation function, metrics to be evaluated, and how often to run the validation. You can provide custom @@ -430,7 +435,7 @@ def train_begin(self, estimator, *args, **kwargs): self.current_epoch = 0 self.current_batch = 0 if self.save_best: - self.best = np.Inf if self.monitor_op == np.less else -np.Inf # pylint: disable=comparison-with-callable + self.best = np.Inf if self.monitor_op == np.less else -np.Inf # pylint: disable=comparison-with-callable if self.resume_from_checkpoint: error_msg = "To use resume from checkpoint, you must only specify " \ "the same type of period you used for training." \ @@ -506,12 +511,12 @@ def _save_checkpoint(self, estimator): def _save_symbol(self, estimator): symbol_file = os.path.join(self.model_dir, self.model_prefix + '-symbol.json') - if hasattr(estimator.net, '_cached_graph'): + if hasattr(estimator.net, '_cached_graph') and estimator.net._cached_graph: sym = estimator.net._cached_graph[1] sym.save(symbol_file) else: - self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock" - "to construct your model, can call net.hybridize() before passing to" + self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock " + "to construct your model, can call net.hybridize() before passing to " "Estimator in order to save model architecture as %s.", symbol_file) def _save_params_and_trainer(self, estimator, file_prefix): @@ -666,7 +671,7 @@ def __init__(self, "if you want otherwise", self.monitor.get()[0]) self.monitor_op = np.less - if self.monitor_op == np.greater: # pylint: disable=comparison-with-callable + if self.monitor_op == np.greater: # pylint: disable=comparison-with-callable self.min_delta *= 1 else: self.min_delta *= -1 @@ -679,7 +684,7 @@ def train_begin(self, estimator, *args, **kwargs): if self.baseline is not None: self.best = self.baseline else: - self.best = np.Inf if self.monitor_op == np.less else -np.Inf # pylint: disable=comparison-with-callable + self.best = np.Inf if self.monitor_op == np.less else -np.Inf # pylint: disable=comparison-with-callable def epoch_end(self, estimator, *args, **kwargs): monitor_name, monitor_value = self.monitor.get() diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index d2e8c082aa08..ae47d925670f 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -19,11 +19,13 @@ import sys import unittest +import warnings import mxnet as mx from mxnet import gluon from mxnet.gluon import nn from mxnet.gluon.contrib.estimator import * +from mxnet.gluon.contrib.estimator.event_handler import * from nose.tools import assert_raises @@ -335,10 +337,9 @@ def test_default_handlers(): metrics=train_acc, trainer=trainer, context=ctx) - # no handler + # no handler(all default handlers), no warning with warnings.catch_warnings(record=True) as w: est.fit(train_data=train_data, epochs=num_epochs) - assert 'You are training with the' in str(w[-1].message) # handler with prepared loss and metrics # use mix of default and user defined handlers @@ -353,7 +354,7 @@ def test_default_handlers(): # handler with all user defined metrics # use mix of default and user defined handlers metric = MetricHandler(train_metrics=[train_acc]) - logging = LoggingHandler(train_metrics=[train_acc], val_metrics=[mx.metric.RMSE("val acc")]) + logging = LoggingHandler(train_metrics=[train_acc]) est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[metric, logging]) # handler with mixed metrics, some handler use metrics prepared by estimator diff --git a/tests/tutorials/test_tutorials.py b/tests/tutorials/test_tutorials.py index 5fe6a03eae7b..c2173a7dc071 100644 --- a/tests/tutorials/test_tutorials.py +++ b/tests/tutorials/test_tutorials.py @@ -133,6 +133,9 @@ def test_gluon_learning_rate_schedules_advanced(): def test_gluon_info_gan(): assert _test_tutorial_nb('gluon/info_gan') +def test_gluon_fit_api_fashion_mnist(): + assert _test_tutorial_nb('gluon/fit_api_tutorial') + def test_nlp_cnn(): assert _test_tutorial_nb('nlp/cnn')