diff --git a/docs/tutorials/gluon/fit_api_tutorial.md b/docs/tutorials/gluon/fit_api_tutorial.md
index 99b9efe62465..bc50690ac1a2 100644
--- a/docs/tutorials/gluon/fit_api_tutorial.md
+++ b/docs/tutorials/gluon/fit_api_tutorial.md
@@ -137,8 +137,12 @@ est = estimator.Estimator(net=resnet_18_v1,
                           trainer=trainer,
                           context=ctx)
 
-# Magic line
-est.fit(train_data=train_data_loader,
+# ignore warnings for nightly test on CI only
+import warnings
+with warnings.catch_warnings():
+    warnings.simplefilter("ignore")
+    # Magic line
+    est.fit(train_data=train_data_loader,
         epochs=num_epochs)
 ```
 
@@ -224,11 +228,15 @@ checkpoint_handler = CheckpointHandler(model_dir='./',
                                        save_best=True)  # Save the best model in terms of
 # Let's instantiate another handler which we defined above
 loss_record_handler = LossRecordHandler()
-# Magic line
-est.fit(train_data=train_data_loader,
-        val_data=val_data_loader,
-        epochs=num_epochs,
-        event_handlers=[checkpoint_handler, loss_record_handler]) # Add the event handlers
+# ignore warnings for nightly test on CI only
+import warnings
+with warnings.catch_warnings():
+    warnings.simplefilter("ignore")
+    # Magic line
+    est.fit(train_data=train_data_loader,
+            val_data=val_data_loader,
+            epochs=num_epochs,
+            event_handlers=[checkpoint_handler, loss_record_handler]) # Add the event handlers
 ```
 
 Training begin: using optimizer SGD with current learning rate 0.0400
diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
index 5e3804784ba8..d077bcd4fdeb 100644
--- a/python/mxnet/gluon/contrib/estimator/estimator.py
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -334,28 +334,36 @@ def fit(self, train_data,
     def _prepare_default_handlers(self, val_data, event_handlers):
         event_handlers = event_handlers or []
         default_handlers = []
-        train_metrics, val_metrics = self.prepare_loss_and_metrics()
+        self.prepare_loss_and_metrics()
 
         # no need to add to default handler check as StoppingHandler does not use metrics
         event_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))
+        default_handlers.append("StoppingHandler")
 
         if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
-            event_handlers.append(MetricHandler(train_metrics=train_metrics))
+            event_handlers.append(MetricHandler(train_metrics=self.train_metrics))
             default_handlers.append("MetricHandler")
 
-        if val_data and not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
-            event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
-                                                    val_metrics=val_metrics))
-            default_handlers.append("ValidationHandler")
+        if not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
+            # no validation handler
+            if val_data:
+                # add default validation handler if validation data found
+                event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
+                                                        val_metrics=self.val_metrics))
+                default_handlers.append("ValidationHandler")
+                val_metrics = self.val_metrics
+            else:
+                # set validation metrics to None if no validation data and no validation handler
+                val_metrics = []
 
         if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
-            event_handlers.append(LoggingHandler(train_metrics=train_metrics,
+            event_handlers.append(LoggingHandler(train_metrics=self.train_metrics,
                                                  val_metrics=val_metrics))
             default_handlers.append("LoggingHandler")
 
         # if there is a mix of user defined event handlers and default event handlers
         # they should have the same set of loss and metrics
-        if default_handlers:
+        if default_handlers and len(event_handlers) > len(default_handlers):
             msg = "You are training with the following default event handlers: %s. " \
                   "They use loss and metrics from estimator.prepare_loss_and_metrics(). " \
                   "Please use the same set of metrics for all your other handlers." % \
@@ -374,7 +382,7 @@ def _prepare_default_handlers(self, val_data, event_handlers):
         # remove None metric references
         references = set([ref for ref in references if ref])
         for metric in references:
-            if metric not in train_metrics + val_metrics:
+            if metric not in self.train_metrics + self.val_metrics:
                 msg = "We have added following default handlers for you: %s and used " \
                       "estimator.prepare_loss_and_metrics() to pass metrics to " \
                       "those handlers. Please use the same set of metrics " \
diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py
index 660bf34122c5..0c8dbd9a4e1d 100644
--- a/python/mxnet/gluon/contrib/estimator/event_handler.py
+++ b/python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -29,7 +29,8 @@ from ....metric import EvalMetric
 from ....metric import Loss as metric_loss
 
-__all__ = ['StoppingHandler', 'MetricHandler', 'ValidationHandler',
+__all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd',
+           'StoppingHandler', 'MetricHandler', 'ValidationHandler',
            'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler']
 
 
 class TrainBegin(object):
@@ -513,8 +514,8 @@ def _save_symbol(self, estimator):
             sym = estimator.net._cached_graph[1]
             sym.save(symbol_file)
         else:
-            self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock"
-                             "to construct your model, can call net.hybridize() before passing to"
+            self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock "
+                             "to construct your model, can call net.hybridize() before passing to "
                              "Estimator in order to save model architecture as %s.", symbol_file)
 
     def _save_params_and_trainer(self, estimator, file_prefix):
diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py
index d2e8c082aa08..5050a0067f6c 100644
--- a/tests/python/unittest/test_gluon_estimator.py
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -19,11 +19,13 @@
 
 import sys
 import unittest
+import warnings
 
 import mxnet as mx
 from mxnet import gluon
 from mxnet.gluon import nn
 from mxnet.gluon.contrib.estimator import *
+from mxnet.gluon.contrib.estimator.event_handler import *
 from nose.tools import assert_raises
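
Note (not part of the patch): the tutorial hunks above wrap `est.fit(...)` in `warnings.catch_warnings()` so that warnings raised during the nightly CI run are suppressed only for that call. Below is a minimal standalone sketch of the same pattern using only the standard library; `noisy_fit` is a hypothetical stand-in for `est.fit` and is not part of the MXNet API.

```python
import warnings

def noisy_fit():
    # hypothetical stand-in for est.fit(); emits a warning the way the nightly run does
    warnings.warn("spurious warning from a dependency")
    return "done"

# suppress warnings only inside this block, the way the tutorial now wraps est.fit()
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    result = noisy_fit()   # nothing is printed to stderr here

print(result)              # outside the block the previous warning filters are restored
```

Because `catch_warnings()` restores the previous filter state on exit, warning behaviour elsewhere in the tutorial is unaffected.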