Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix and clarify doc
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu committed Nov 14, 2019
1 parent 8354b1b commit c47217e
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
6 changes: 5 additions & 1 deletion python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,11 @@ def fit(self, train_data,
Number of epochs to iterate on the training data.
You can only specify one and only one type of iteration(epochs or batches).
event_handlers : EventHandler or list of EventHandler
List of :py:class:`EventHandlers` to apply during training.
List of :py:class:`EventHandlers` to apply during training. Besides
the event handlers specified here, a StoppingHandler,
LoggingHandler and MetricHandler will be added by default if not
yet specified manually. If validation data is provided, a
ValidationHandler is also added if not already specified.
batches : int, default None
Number of batches to iterate on the training data.
You can only specify one and only one type of iteration(epochs or batches).
Expand Down
8 changes: 4 additions & 4 deletions python/mxnet/gluon/contrib/estimator/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,13 +408,13 @@ def __init__(self,
else:
# use greater for accuracy and f1 and less otherwise
if 'acc' or 'f1' in self.monitor.get()[0].lower():
warnings.warn("`greater` operator will be used to determine if %s has improved. "
warnings.warn("`greater` operator will be used to determine if {} has improved. "
"Please specify `mode='min'` to use the `less` operator. "
"Specify `mode='max' to disable this warning.`"
.format(self.monitor.get()[0]))
self.monitor_op = np.greater
else:
warnings.warn("`less` operator will be used to determine if %s has improved. "
warnings.warn("`less` operator will be used to determine if {} has improved. "
"Please specify `mode='max'` to use the `greater` operator. "
"Specify `mode='min' to disable this warning.`"
.format(self.monitor.get()[0]))
Expand Down Expand Up @@ -656,13 +656,13 @@ def __init__(self,
self.monitor_op = np.greater
else:
if 'acc' or 'f1' in self.monitor.get()[0].lower():
warnings.warn("`greater` operator will be used to determine if %s has improved. "
warnings.warn("`greater` operator will be used to determine if {} has improved. "
"Please specify `mode='min'` to use the `less` operator. "
"Specify `mode='max' to disable this warning.`"
.format(self.monitor.get()[0]))
self.monitor_op = np.greater
else:
warnings.warn("`less` operator will be used to determine if %s has improved. "
warnings.warn("`less` operator will be used to determine if {} has improved. "
"Please specify `mode='max'` to use the `greater` operator. "
"Specify `mode='min' to disable this warning.`"
.format(self.monitor.get()[0]))
Expand Down
5 changes: 1 addition & 4 deletions tests/python/unittest/test_gluon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,7 @@ def test_default_handlers():
train_metrics = est.train_metrics
val_metrics = est.val_metrics
logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics)
with warnings.catch_warnings(record=True) as w:
est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging])
# provide metric handler by default
assert 'MetricHandler' in str(w[-1].message)
est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging])

# handler with all user defined metrics
# use mix of default and user defined handlers
Expand Down

0 comments on commit c47217e

Please sign in to comment.