diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 37e4eaf0fe36..83b954d02e10 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -328,7 +328,11 @@ def fit(self, train_data, Number of epochs to iterate on the training data. You can only specify one and only one type of iteration(epochs or batches). event_handlers : EventHandler or list of EventHandler - List of :py:class:`EventHandlers` to apply during training. + List of :py:class:`EventHandlers` to apply during training. Besides + the event handlers specified here, a StoppingHandler, + LoggingHandler and MetricHandler will be added by default if not + yet specified manually. If validation data is provided, a + ValidationHandler is also added if not already specified. batches : int, default None Number of batches to iterate on the training data. You can only specify one and only one type of iteration(epochs or batches). diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index a44cda758eac..3cdc407407c1 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -408,13 +408,13 @@ def __init__(self, else: # use greater for accuracy and f1 and less otherwise if 'acc' or 'f1' in self.monitor.get()[0].lower(): - warnings.warn("`greater` operator will be used to determine if %s has improved. " + warnings.warn("`greater` operator will be used to determine if {} has improved. " "Please specify `mode='min'` to use the `less` operator. " "Specify `mode='max' to disable this warning.`" .format(self.monitor.get()[0])) self.monitor_op = np.greater else: - warnings.warn("`less` operator will be used to determine if %s has improved. " + warnings.warn("`less` operator will be used to determine if {} has improved. " "Please specify `mode='max'` to use the `greater` operator. " "Specify `mode='min' to disable this warning.`" .format(self.monitor.get()[0])) @@ -656,13 +656,13 @@ def __init__(self, self.monitor_op = np.greater else: if 'acc' or 'f1' in self.monitor.get()[0].lower(): - warnings.warn("`greater` operator will be used to determine if %s has improved. " + warnings.warn("`greater` operator will be used to determine if {} has improved. " "Please specify `mode='min'` to use the `less` operator. " "Specify `mode='max' to disable this warning.`" .format(self.monitor.get()[0])) self.monitor_op = np.greater else: - warnings.warn("`less` operator will be used to determine if %s has improved. " + warnings.warn("`less` operator will be used to determine if {} has improved. " "Please specify `mode='max'` to use the `greater` operator. " "Specify `mode='min' to disable this warning.`" .format(self.monitor.get()[0])) diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index bae576734a3e..aaf9839b29f3 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -346,10 +346,7 @@ def test_default_handlers(): train_metrics = est.train_metrics val_metrics = est.val_metrics logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics) - with warnings.catch_warnings(record=True) as w: - est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) - # provide metric handler by default - assert 'MetricHandler' in str(w[-1].message) + est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) # handler with all user defined metrics # use mix of default and user defined handlers