From aa92c3b32ed08b2bb7720e2b2db5e34ed8ab5edd Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Thu, 14 Nov 2019 13:27:31 +0000 Subject: [PATCH] Fix and clarify doc --- python/mxnet/gluon/contrib/estimator/estimator.py | 6 +++++- tests/python/unittest/test_gluon_estimator.py | 5 +---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 37e4eaf0fe36..83b954d02e10 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -328,7 +328,11 @@ def fit(self, train_data, Number of epochs to iterate on the training data. You can only specify one and only one type of iteration(epochs or batches). event_handlers : EventHandler or list of EventHandler - List of :py:class:`EventHandlers` to apply during training. + List of :py:class:`EventHandlers` to apply during training. Besides + the event handlers specified here, a StoppingHandler, + LoggingHandler and MetricHandler will be added by default if not + yet specified manually. If validation data is provided, a + ValidationHandler is also added if not already specified. batches : int, default None Number of batches to iterate on the training data. You can only specify one and only one type of iteration(epochs or batches). diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index bae576734a3e..aaf9839b29f3 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -346,10 +346,7 @@ def test_default_handlers(): train_metrics = est.train_metrics val_metrics = est.val_metrics logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics) - with warnings.catch_warnings(record=True) as w: - est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) - # provide metric handler by default - assert 'MetricHandler' in str(w[-1].message) + est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) # handler with all user defined metrics # use mix of default and user defined handlers