diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
index 4f2b8fd99cac..83b954d02e10 100644
--- a/python/mxnet/gluon/contrib/estimator/estimator.py
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -20,6 +20,8 @@
 """Gluon Estimator"""
 
 import copy
+import logging
+import sys
 import warnings
 
 from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
@@ -57,6 +59,25 @@ class Estimator(object):
         Trainer to apply optimizer on network parameters.
     context : Context or list of Context
         Device(s) to run the training on.
+
+    """
+
+    logger = None
+    """logging.Logger object associated with the Estimator.
+
+    The logger is used for all logs generated by this estimator and its
+    handlers. A new logging.Logger is created during Estimator construction and
+    configured to write all logs with level logging.INFO or higher to
+    sys.stdout.
+
+    You can modify the logging settings using the standard Python methods. For
+    example, to save logs to a file in addition to printing them to stdout,
+    you can attach a logging.FileHandler to the logger.
+
+    >>> est = Estimator(net, loss)
+    >>> import logging
+    >>> est.logger.addHandler(logging.FileHandler(filename))
     """
 
     def __init__(self, net,
@@ -65,13 +86,15 @@ def __init__(self, net,
                  initializer=None,
                  trainer=None,
                  context=None):
-
         self.net = net
         self.loss = self._check_loss(loss)
         self._train_metrics = _check_metrics(metrics)
         self._add_default_training_metrics()
         self._add_validation_metrics()
 
+        self.logger = logging.Logger(name='Estimator', level=logging.INFO)
+        self.logger.addHandler(logging.StreamHandler(sys.stdout))
+
         self.context = self._check_context(context)
         self._initialize(initializer)
         self.trainer = self._check_trainer(trainer)
@@ -243,8 +266,7 @@ def evaluate(self,
         for _, batch in enumerate(val_data):
             self.evaluate_batch(batch, val_metrics, batch_axis)
 
-    def fit_batch(self, train_batch,
-                  batch_axis=0):
+    def fit_batch(self, train_batch, batch_axis=0):
         """Trains the model on a batch of training data.
 
         Parameters
         ----------
@@ -257,13 +279,15 @@ def fit_batch(self, train_batch,
         Returns
         -------
         data: List of NDArray
-            Sharded data from the batch.
+            Sharded data from the batch. Data is sharded with
+            `gluon.split_and_load`.
         label: List of NDArray
-            Sharded label from the batch.
+            Sharded label from the batch. Labels are sharded with
+            `gluon.split_and_load`.
         pred: List of NDArray
-            Prediction of each of the shareded batch.
+            Prediction on each of the sharded inputs.
         loss: List of NDArray
-            Loss of each of the shareded batch.
+            Loss on each of the sharded inputs.
         """
         data, label = self._get_data_and_label(train_batch, self.context, batch_axis)
@@ -304,7 +328,11 @@ def fit(self, train_data,
             Number of epochs to iterate on the training data. You can
             only specify one and only one type of iteration(epochs or batches).
         event_handlers : EventHandler or list of EventHandler
-            List of :py:class:`EventHandlers` to apply during training.
+            List of :py:class:`EventHandlers` to apply during training. Besides
+            the event handlers specified here, a StoppingHandler,
+            LoggingHandler and MetricHandler will be added by default if not
+            yet specified manually. If validation data is provided, a
+            ValidationHandler is also added if not already specified.
         batches : int, default None
             Number of batches to iterate on the training data. You can
             only specify one and only one type of iteration(epochs or batches).
@@ -405,11 +433,6 @@ def _prepare_default_handlers(self, val_data, event_handlers):
         event_handlers.extend(added_default_handlers)
 
         if mixing_handlers:
-            msg = "The following default event handlers are added: {}.".format(
-                ", ".join([type(h).__name__ for h in added_default_handlers]))
-            warnings.warn(msg)
-
-
             # check if all handlers have the same set of references to metrics
             known_metrics = set(self.train_metrics + self.val_metrics)
             for handler in event_handlers:
diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py
index 7e143d6f19aa..3cdc407407c1 100644
--- a/python/mxnet/gluon/contrib/estimator/event_handler.py
+++ b/python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -19,14 +19,13 @@
 # pylint: disable=wildcard-import, unused-argument, too-many-ancestors
 """Gluon EventHandlers for Estimators"""
 
-import logging
 import os
 import time
 import warnings
 
 import numpy as np
 
-from ....metric import EvalMetric, CompositeEvalMetric
+from ....metric import CompositeEvalMetric, EvalMetric
 from ....metric import Loss as metric_loss
 from .utils import _check_metrics
 
@@ -34,6 +33,7 @@
            'StoppingHandler', 'MetricHandler', 'ValidationHandler',
            'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler']
 
+
 class EventHandler(object):
     pass
@@ -194,7 +194,6 @@ def __init__(self,
         # order to be called among all callbacks
         # validation metrics need to be calculated before other callbacks can access them
         self.priority = -np.Inf
-        self.logger = logging.getLogger(__name__)
 
     def train_begin(self, estimator, *args, **kwargs):
         # reset epoch and batch counter
@@ -211,7 +210,7 @@ def batch_end(self, estimator, *args, **kwargs):
             for monitor in self.val_metrics:
                 name, value = monitor.get()
                 msg += '%s: %.4f, ' % (name, value)
-            self.logger.info(msg.rstrip(','))
+            estimator.logger.info(msg.rstrip(','))
 
     def epoch_end(self, estimator, *args, **kwargs):
         self.current_epoch += 1
@@ -228,12 +227,6 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
 
     Parameters
     ----------
-    file_name : str
-        File name to save the logs.
-    file_location : str
-        File location to save the logs.
-    filemode : str, default 'a'
-        Logging file mode, default using append mode.
     verbose : int, default LOG_PER_EPOCH
         Limit the granularity of metrics displayed during training process.
         verbose=LOG_PER_EPOCH: display metrics every epoch
@@ -247,25 +240,10 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
     LOG_PER_EPOCH = 1
     LOG_PER_BATCH = 2
 
-    def __init__(self, file_name=None,
-                 file_location=None,
-                 filemode='a',
-                 verbose=LOG_PER_EPOCH,
+    def __init__(self, verbose=LOG_PER_EPOCH,
                  train_metrics=None,
                  val_metrics=None):
         super(LoggingHandler, self).__init__()
-        self.logger = logging.getLogger(__name__)
-        self.logger.setLevel(logging.INFO)
-        self._added_logging_handlers = [logging.StreamHandler()]
-        # save logger to file only if file name or location is specified
-        if file_name or file_location:
-            file_name = file_name or 'estimator_log'
-            file_location = file_location or './'
-            file_handler = logging.FileHandler(os.path.join(file_location, file_name), mode=filemode)
-            self._added_logging_handlers.append(file_handler)
-        for handler in self._added_logging_handlers:
-            self.logger.addHandler(handler)
-
         if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]:
             raise ValueError("verbose level must be either LOG_PER_EPOCH or "
                              "LOG_PER_BATCH, received %s. "
@@ -281,24 +259,18 @@ def __init__(self, file_name=None,
         # it will also shut down logging at train end
         self.priority = np.Inf
 
-    def __del__(self):
-        for handler in self._added_logging_handlers:
-            handler.flush()
-            self.logger.removeHandler(handler)
-            handler.close()
-
     def train_begin(self, estimator, *args, **kwargs):
         self.train_start = time.time()
         trainer = estimator.trainer
         optimizer = trainer.optimizer.__class__.__name__
         lr = trainer.learning_rate
-        self.logger.info("Training begin: using optimizer %s "
-                         "with current learning rate %.4f ",
-                         optimizer, lr)
+        estimator.logger.info("Training begin: using optimizer %s "
+                              "with current learning rate %.4f ",
+                              optimizer, lr)
         if estimator.max_epoch:
-            self.logger.info("Train for %d epochs.", estimator.max_epoch)
+            estimator.logger.info("Train for %d epochs.", estimator.max_epoch)
         else:
-            self.logger.info("Train for %d batches.", estimator.max_batch)
+            estimator.logger.info("Train for %d batches.", estimator.max_batch)
         # reset all counters
         self.current_epoch = 0
         self.batch_index = 0
@@ -311,13 +283,7 @@ def train_end(self, estimator, *args, **kwargs):
         for metric in self.train_metrics + self.val_metrics:
             name, value = metric.get()
             msg += '%s: %.4f, ' % (name, value)
-        self.logger.info(msg.rstrip(', '))
-        # make a copy of handler list and remove one by one
-        # as removing handler will edit the handler list
-        for handler in self.logger.handlers[:]:
-            handler.close()
-            self.logger.removeHandler(handler)
-        logging.shutdown()
+        estimator.logger.info(msg.rstrip(', '))
 
     def batch_begin(self, estimator, *args, **kwargs):
         if self.verbose == self.LOG_PER_BATCH:
@@ -334,14 +300,14 @@ def batch_end(self, estimator, *args, **kwargs):
                 # only log current training loss & metric after each batch
                 name, value = metric.get()
                 msg += '%s: %.4f, ' % (name, value)
-            self.logger.info(msg.rstrip(', '))
+            estimator.logger.info(msg.rstrip(', '))
         self.batch_index += 1
 
     def epoch_begin(self, estimator, *args, **kwargs):
         if self.verbose >= self.LOG_PER_EPOCH:
             self.epoch_start = time.time()
-            self.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
-                             self.current_epoch, estimator.trainer.learning_rate)
+            estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
+                                  self.current_epoch, estimator.trainer.learning_rate)
 
     def epoch_end(self, estimator, *args, **kwargs):
         if self.verbose >= self.LOG_PER_EPOCH:
@@ -350,7 +316,7 @@ def epoch_end(self, estimator, *args, **kwargs):
             for monitor in self.train_metrics + self.val_metrics:
                 name, value = monitor.get()
                 msg += '%s: %.4f, ' % (name, value)
-            self.logger.info(msg.rstrip(', '))
+            estimator.logger.info(msg.rstrip(', '))
         self.current_epoch += 1
         self.batch_index = 0
@@ -424,7 +390,6 @@ def __init__(self,
         self.max_checkpoints = max_checkpoints
         self.resume_from_checkpoint = resume_from_checkpoint
         self.saved_checkpoints = []
-        self.logger = logging.getLogger(__name__)
         if self.save_best:
             if mode not in ['auto', 'min', 'max']:
                 warnings.warn('ModelCheckpoint mode %s is unknown, '
@@ -443,14 +408,16 @@ def __init__(self,
             else:
                 # use greater for accuracy and f1 and less otherwise
                 if 'acc' or 'f1' in self.monitor.get()[0].lower():
-                    self.logger.info("`greater` operator will be used to determine "
-                                     "if %s has improved, please use `min` for mode "
-                                     "if you want otherwise", self.monitor.get()[0])
+                    warnings.warn("`greater` operator will be used to determine if {} has improved. "
+                                  "Please specify `mode='min'` to use the `less` operator. "
+                                  "Specify `mode='max'` to disable this warning."
+                                  .format(self.monitor.get()[0]))
                     self.monitor_op = np.greater
                 else:
-                    self.logger.info("`less` operator will be used to determine "
-                                     "if %s has improved, please use `max` for mode "
-                                     "if you want otherwise", self.monitor.get()[0])
+                    warnings.warn("`less` operator will be used to determine if {} has improved. "
+                                  "Please specify `mode='max'` to use the `greater` operator. "
+                                  "Specify `mode='min'` to disable this warning."
+                                  .format(self.monitor.get()[0]))
                     self.monitor_op = np.less
 
     def train_begin(self, estimator, *args, **kwargs):
@@ -501,9 +468,9 @@ def _save_checkpoint(self, estimator):
             prefix = "%s-epoch%dbatch%d" % (self.model_prefix, save_epoch_number, save_batch_number)
             self._save_params_and_trainer(estimator, prefix)
             if self.verbose > 0:
-                self.logger.info('[Epoch %d] CheckpointHandler: trained total %d batches, '
-                                 'saving model at %s with prefix: %s',
-                                 self.current_epoch, self.current_batch + 1, self.model_dir, prefix)
+                estimator.logger.info('[Epoch %d] CheckpointHandler: trained total %d batches, '
+                                      'saving model at %s with prefix: %s',
+                                      self.current_epoch, self.current_batch + 1, self.model_dir, prefix)
 
         if self.save_best:
             monitor_name, monitor_value = self.monitor.get()
@@ -519,18 +486,18 @@ def _save_checkpoint(self, estimator):
                 self._save_params_and_trainer(estimator, prefix)
                 self.best = monitor_value
                 if self.verbose > 0:
-                    self.logger.info('[Epoch %d] CheckpointHandler: '
-                                     '%s improved from %0.5f to %0.5f, '
-                                     'updating best model at %s with prefix: %s',
-                                     self.current_epoch, monitor_name,
-                                     self.best, monitor_value, self.model_dir, prefix)
+                    estimator.logger.info('[Epoch %d] CheckpointHandler: '
+                                          '%s improved from %0.5f to %0.5f, '
+                                          'updating best model at %s with prefix: %s',
+                                          self.current_epoch, monitor_name,
+                                          self.best, monitor_value, self.model_dir, prefix)
             else:
                 if self.verbose > 0:
-                    self.logger.info('[Epoch %d] CheckpointHandler: '
-                                     '%s did not improve from %0.5f, '
-                                     'skipping updating best model',
-                                     self.current_batch, monitor_name,
-                                     self.best)
+                    estimator.logger.info('[Epoch %d] CheckpointHandler: '
+                                          '%s did not improve from %0.5f, '
+                                          'skipping updating best model',
+                                          self.current_batch, monitor_name,
+                                          self.best)
 
     def _save_symbol(self, estimator):
         symbol_file = os.path.join(self.model_dir, self.model_prefix + '-symbol.json')
@@ -538,9 +505,11 @@ def _save_symbol(self, estimator):
             sym = estimator.net._cached_graph[1]
             sym.save(symbol_file)
         else:
-            self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock "
-                             "to construct your model, and call net.hybridize() before passing to "
-                             "Estimator in order to save model architecture as %s.", symbol_file)
+            estimator.logger.info(
+                "Model architecture(symbol file) is not saved, please use HybridBlock "
+                "to construct your model, and call net.hybridize() before passing to "
+                "Estimator in order to save model architecture as %s.",
+                symbol_file)
 
     def _save_params_and_trainer(self, estimator, file_prefix):
         param_file = os.path.join(self.model_dir, file_prefix + '.params')
@@ -579,7 +548,7 @@ def _resume_from_checkpoint(self, estimator):
                 msg += "%d batches" % estimator.max_batch
             else:
                 msg += "%d epochs" % estimator.max_epoch
-            self.logger.info(msg)
+            estimator.logger.info(msg)
         else:
             msg = "CheckpointHandler: Checkpoint resumed from epoch %d batch %d, " \
                   "continue to train for " % (self.trained_epoch, self.trained_batch)
@@ -607,7 +576,7 @@ def _resume_from_checkpoint(self, estimator):
         assert os.path.exists(trainer_file), "Failed to load checkpoint, %s does not exist" % trainer_file
         estimator.net.load_parameters(param_file, ctx=estimator.context)
         estimator.trainer.load_states(trainer_file)
-        self.logger.warning(msg)
+        estimator.logger.warning(msg)
 
     def _find_max_iteration(self, dir, prefix, start, end, saved_checkpoints=None):
         error_msg = "Error parsing checkpoint file, please check your " \
@@ -672,7 +641,6 @@ def __init__(self,
         self.stopped_epoch = 0
         self.current_epoch = 0
         self.stop_training = False
-        self.logger = logging.getLogger(__name__)
 
         if mode not in ['auto', 'min', 'max']:
             warnings.warn('EarlyStopping mode %s is unknown, '
@@ -688,14 +656,16 @@ def __init__(self,
             self.monitor_op = np.greater
         else:
             if 'acc' or 'f1' in self.monitor.get()[0].lower():
-                self.logger.info("`greater` operator is used to determine "
-                                 "if %s has improved, please use `min` for mode "
-                                 "if you want otherwise", self.monitor.get()[0])
+                warnings.warn("`greater` operator will be used to determine if {} has improved. "
+                              "Please specify `mode='min'` to use the `less` operator. "
+                              "Specify `mode='max'` to disable this warning."
+                              .format(self.monitor.get()[0]))
                 self.monitor_op = np.greater
             else:
-                self.logger.info("`less` operator is used to determine "
-                                 "if %s has improved, please use `max` for mode "
-                                 "if you want otherwise", self.monitor.get()[0])
+                warnings.warn("`less` operator will be used to determine if {} has improved. "
+                              "Please specify `mode='max'` to use the `greater` operator. "
+                              "Specify `mode='min'` to disable this warning."
+                              .format(self.monitor.get()[0]))
                 self.monitor_op = np.less
 
         if self.monitor_op == np.greater:  # pylint: disable=comparison-with-callable
@@ -733,5 +703,6 @@ def epoch_end(self, estimator, *args, **kwargs):
 
     def train_end(self, estimator, *args, **kwargs):
         if self.stopped_epoch > 0:
-            self.logger.info('[Epoch %d] EarlyStoppingHanlder: early stopping due to %s not improving',
-                             self.stopped_epoch, self.monitor.get()[0])
+            estimator.logger.info('[Epoch %d] EarlyStoppingHandler: '
+                                  'early stopping due to %s not improving',
+                                  self.stopped_epoch, self.monitor.get()[0])
diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py
index 06fb16288649..816508fbc3c8 100644
--- a/tests/python/unittest/common.py
+++ b/tests/python/unittest/common.py
@@ -251,7 +251,7 @@ def setup_module():
 
 try:
     from tempfile import TemporaryDirectory
-except:
+except:  # Python 2 support
     # really simple implementation of TemporaryDirectory
     class TemporaryDirectory(object):
         def __init__(self, suffix='', prefix='', dir=''):
diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py
index bae576734a3e..aaf9839b29f3 100644
--- a/tests/python/unittest/test_gluon_estimator.py
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -346,10 +346,7 @@ def test_default_handlers():
     train_metrics = est.train_metrics
     val_metrics = est.val_metrics
     logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics)
-    with warnings.catch_warnings(record=True) as w:
-        est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging])
-        # provide metric handler by default
-        assert 'MetricHandler' in str(w[-1].message)
+    est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging])
 
     # handler with all user defined metrics
     # use mix of default and user defined handlers
diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py
index b29c72a0f908..17c75813d516 100644
--- a/tests/python/unittest/test_gluon_event_handler.py
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -16,6 +16,7 @@
 # under the License.
 
 import os
+import logging
 
 import mxnet as mx
 from common import TemporaryDirectory
@@ -143,16 +144,18 @@ def test_logging():
         ce_loss = loss.SoftmaxCrossEntropyLoss()
         acc = mx.metric.Accuracy()
         est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+
+        est.logger.addHandler(logging.FileHandler(output_dir))
+
         train_metrics = est.train_metrics
         val_metrics = est.val_metrics
-        logging_handler = event_handler.LoggingHandler(file_name=file_name,
-                                                       file_location=tmpdir,
-                                                       train_metrics=train_metrics,
+        logging_handler = event_handler.LoggingHandler(train_metrics=train_metrics,
                                                        val_metrics=val_metrics)
         est.fit(test_data, event_handlers=[logging_handler], epochs=3)
         assert logging_handler.batch_index == 0
         assert logging_handler.current_epoch == 3
         assert os.path.isfile(output_dir)
+        del est  # Clean up estimator and logger before deleting tmpdir
 
 
 def test_custom_handler():
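
For context, the workflow after this change looks roughly like the sketch below: every handler now writes through the single Estimator.logger (stdout at INFO level by default), and file logging is configured on the estimator instead of on LoggingHandler. This is only a hedged illustration; the toy network, random data, and the 'estimator.log' path are placeholders and not part of the patch.

    # Sketch of the new Estimator.logger workflow (illustrative only).
    import logging

    import mxnet as mx
    from mxnet import gluon
    from mxnet.gluon.contrib.estimator import Estimator

    # Placeholder model: a single dense layer with fixed input size so that
    # parameters can be initialized immediately.
    net = gluon.nn.Dense(2, in_units=8)
    net.initialize(mx.init.Xavier())

    est = Estimator(net,
                    loss=gluon.loss.SoftmaxCrossEntropyLoss(),
                    metrics=mx.metric.Accuracy())

    # est.logger already streams to sys.stdout; extra handlers tee the same
    # records elsewhere, replacing the removed file_name/file_location args.
    est.logger.addHandler(logging.FileHandler('estimator.log'))

    # Dummy data so the sketch runs end to end.
    x = mx.nd.random.uniform(shape=(32, 8))
    y = mx.nd.random.randint(0, 2, shape=(32,)).astype('float32')
    train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(x, y), batch_size=8)

    est.fit(train_data=train_data, epochs=1)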