Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Gluon] Improve estimator usability and fix logging logic #16810

Merged
merged 6 commits into from
Nov 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 36 additions & 13 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
"""Gluon Estimator"""

import copy
import logging
import sys
import warnings

from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
Expand Down Expand Up @@ -57,6 +59,25 @@ class Estimator(object):
Trainer to apply optimizer on network parameters.
context : Context or list of Context
Device(s) to run the training on.

"""

logger = None
"""logging.Logger object associated with the Estimator.

The logger is used for all logs generated by this estimator and its
handlers. A new logging.Logger is created during Estimator construction and
configured to write all logs with level logging.INFO or higher to
sys.stdout.

You can modify the logging settings using the standard Python methods. For
example, to save logs to a file in addition to printing them to stdout,
you can attach a logging.FileHandler to the logger.

>>> est = Estimator(net, loss)
>>> import logging
>>> est.logger.addHandler(logging.FileHandler(filename))

"""

def __init__(self, net,
Expand All @@ -65,13 +86,15 @@ def __init__(self, net,
initializer=None,
trainer=None,
context=None):

self.net = net
self.loss = self._check_loss(loss)
self._train_metrics = _check_metrics(metrics)
self._add_default_training_metrics()
self._add_validation_metrics()

self.logger = logging.Logger(name='Estimator', level=logging.INFO)
self.logger.addHandler(logging.StreamHandler(sys.stdout))

self.context = self._check_context(context)
self._initialize(initializer)
self.trainer = self._check_trainer(trainer)
Expand Down Expand Up @@ -243,8 +266,7 @@ def evaluate(self,
for _, batch in enumerate(val_data):
self.evaluate_batch(batch, val_metrics, batch_axis)

def fit_batch(self, train_batch,
batch_axis=0):
def fit_batch(self, train_batch, batch_axis=0):
"""Trains the model on a batch of training data.

Parameters
Expand All @@ -257,13 +279,15 @@ def fit_batch(self, train_batch,
Returns
-------
data: List of NDArray
Sharded data from the batch.
Sharded data from the batch. Data is sharded with
`gluon.split_and_load`.
label: List of NDArray
Sharded label from the batch.
Sharded label from the batch. Labels are sharded with
`gluon.split_and_load`.
pred: List of NDArray
Prediction of each of the shareded batch.
Prediction on each of the sharded inputs.
loss: List of NDArray
Loss of each of the shareded batch.
Loss on each of the sharded inputs.
"""
data, label = self._get_data_and_label(train_batch, self.context, batch_axis)

Expand Down Expand Up @@ -304,7 +328,11 @@ def fit(self, train_data,
Number of epochs to iterate on the training data.
You can specify one and only one type of iteration (epochs or batches).
event_handlers : EventHandler or list of EventHandler
List of :py:class:`EventHandlers` to apply during training.
List of :py:class:`EventHandlers` to apply during training. Besides
the event handlers specified here, a StoppingHandler,
LoggingHandler and MetricHandler will be added by default if not
yet specified manually. If validation data is provided, a
ValidationHandler is also added if not already specified.
batches : int, default None
Number of batches to iterate on the training data.
You can specify one and only one type of iteration (epochs or batches).
Expand Down Expand Up @@ -405,11 +433,6 @@ def _prepare_default_handlers(self, val_data, event_handlers):
event_handlers.extend(added_default_handlers)

if mixing_handlers:
msg = "The following default event handlers are added: {}.".format(
", ".join([type(h).__name__ for h in added_default_handlers]))
warnings.warn(msg)


# check if all handlers have the same set of references to metrics
known_metrics = set(self.train_metrics + self.val_metrics)
for handler in event_handlers:
Expand Down
Loading