From 799f04bbf4d487d7fff532365cd0b27491a339a0 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Fri, 8 Nov 2019 02:35:54 +0000 Subject: [PATCH 1/6] Fix doc --- python/mxnet/gluon/contrib/estimator/estimator.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 4f2b8fd99cac..fdd9af11f6e1 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -243,8 +243,7 @@ def evaluate(self, for _, batch in enumerate(val_data): self.evaluate_batch(batch, val_metrics, batch_axis) - def fit_batch(self, train_batch, - batch_axis=0): + def fit_batch(self, train_batch, batch_axis=0): """Trains the model on a batch of training data. Parameters @@ -257,13 +256,15 @@ def fit_batch(self, train_batch, Returns ------- data: List of NDArray - Sharded data from the batch. + Sharded data from the batch. Data is sharded with + `gluon.split_and_load`. label: List of NDArray - Sharded label from the batch. + Sharded label from the batch. Labels are sharded with + `gluon.split_and_load`. pred: List of NDArray - Prediction of each of the shareded batch. + Prediction on each of the sharded inputs. loss: List of NDArray - Loss of each of the shareded batch. + Loss on each of the sharded inputs. """ data, label = self._get_data_and_label(train_batch, self.context, batch_axis) From ae9647bc3cdb50ac7f90b55be57ae860c78993c5 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Fri, 8 Nov 2019 02:36:08 +0000 Subject: [PATCH 2/6] Remove warning --- python/mxnet/gluon/contrib/estimator/estimator.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index fdd9af11f6e1..15f7e6107183 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -406,11 +406,6 @@ def _prepare_default_handlers(self, val_data, event_handlers): event_handlers.extend(added_default_handlers) if mixing_handlers: - msg = "The following default event handlers are added: {}.".format( - ", ".join([type(h).__name__ for h in added_default_handlers])) - warnings.warn(msg) - - # check if all handlers have the same set of references to metrics known_metrics = set(self.train_metrics + self.val_metrics) for handler in event_handlers: From 2bcdedc059d4949a1a929a6eba459c73e72b63ef Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Fri, 8 Nov 2019 02:49:54 +0000 Subject: [PATCH 3/6] Don't use global logger. Make it specific to each estimator. gluon.contrib.estimator used a global Logger obtained via `logging.getLogger('gluon.contrib.estimator.event_handlers')`. This logger used to be configured every time a gluon.contrib.estimator.LoggingHandler was created, which is a bug. We can't modify a global Logger instance whenever the user creates an Estimator and a LoggingHandler. Instead, this commit separates the LoggingHandler (responsible for logging metadata during estimator.fit) from the configuration of the Logger. We expose the Logger as attribute of the Estimator, and configure it to output to stdout by default. Instructions are given how users can configure the Estimator.logger to log to a file instead. --- .../gluon/contrib/estimator/estimator.py | 25 +++- .../gluon/contrib/estimator/event_handler.py | 131 +++++++----------- .../unittest/test_gluon_event_handler.py | 8 +- 3 files changed, 79 insertions(+), 85 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 15f7e6107183..37e4eaf0fe36 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -20,6 +20,8 @@ """Gluon Estimator""" import copy +import logging +import sys import warnings from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler @@ -57,6 +59,25 @@ class Estimator(object): Trainer to apply optimizer on network parameters. context : Context or list of Context Device(s) to run the training on. + + """ + + logger = None + """logging.Logger object associated with the Estimator. + + The logger is used for all logs generated by this estimator and its + handlers. A new logging.Logger is created during Estimator construction and + configured to write all logs with level logging.INFO or higher to + sys.stdout. + + You can modify the logging settings using the standard Python methods. For + example, to save logs to a file in addition to printing them to stdout + output, you can attach a logging.FileHandler to the logger. + + >>> est = Estimator(net, loss) + >>> import logging + >>> est.logger.addHandler(logging.FileHandler(filename)) + """ def __init__(self, net, @@ -65,13 +86,15 @@ def __init__(self, net, initializer=None, trainer=None, context=None): - self.net = net self.loss = self._check_loss(loss) self._train_metrics = _check_metrics(metrics) self._add_default_training_metrics() self._add_validation_metrics() + self.logger = logging.Logger(name='Estimator', level=logging.INFO) + self.logger.addHandler(logging.StreamHandler(sys.stdout)) + self.context = self._check_context(context) self._initialize(initializer) self.trainer = self._check_trainer(trainer) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 7e143d6f19aa..1cbe2455a4d4 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -21,12 +21,13 @@ import logging import os +import sys import time import warnings import numpy as np -from ....metric import EvalMetric, CompositeEvalMetric +from ....metric import CompositeEvalMetric, EvalMetric from ....metric import Loss as metric_loss from .utils import _check_metrics @@ -34,6 +35,7 @@ 'StoppingHandler', 'MetricHandler', 'ValidationHandler', 'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler'] + class EventHandler(object): pass @@ -194,7 +196,6 @@ def __init__(self, # order to be called among all callbacks # validation metrics need to be calculated before other callbacks can access them self.priority = -np.Inf - self.logger = logging.getLogger(__name__) def train_begin(self, estimator, *args, **kwargs): # reset epoch and batch counter @@ -211,7 +212,7 @@ def batch_end(self, estimator, *args, **kwargs): for monitor in self.val_metrics: name, value = monitor.get() msg += '%s: %.4f, ' % (name, value) - self.logger.info(msg.rstrip(',')) + estimator.logger.info(msg.rstrip(',')) def epoch_end(self, estimator, *args, **kwargs): self.current_epoch += 1 @@ -228,12 +229,6 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat Parameters ---------- - file_name : str - File name to save the logs. - file_location : str - File location to save the logs. - filemode : str, default 'a' - Logging file mode, default using append mode. verbose : int, default LOG_PER_EPOCH Limit the granularity of metrics displayed during training process. verbose=LOG_PER_EPOCH: display metrics every epoch @@ -247,25 +242,10 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat LOG_PER_EPOCH = 1 LOG_PER_BATCH = 2 - def __init__(self, file_name=None, - file_location=None, - filemode='a', - verbose=LOG_PER_EPOCH, + def __init__(self, verbose=LOG_PER_EPOCH, train_metrics=None, val_metrics=None): super(LoggingHandler, self).__init__() - self.logger = logging.getLogger(__name__) - self.logger.setLevel(logging.INFO) - self._added_logging_handlers = [logging.StreamHandler()] - # save logger to file only if file name or location is specified - if file_name or file_location: - file_name = file_name or 'estimator_log' - file_location = file_location or './' - file_handler = logging.FileHandler(os.path.join(file_location, file_name), mode=filemode) - self._added_logging_handlers.append(file_handler) - for handler in self._added_logging_handlers: - self.logger.addHandler(handler) - if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]: raise ValueError("verbose level must be either LOG_PER_EPOCH or " "LOG_PER_BATCH, received %s. " @@ -281,24 +261,18 @@ def __init__(self, file_name=None, # it will also shut down logging at train end self.priority = np.Inf - def __del__(self): - for handler in self._added_logging_handlers: - handler.flush() - self.logger.removeHandler(handler) - handler.close() - def train_begin(self, estimator, *args, **kwargs): self.train_start = time.time() trainer = estimator.trainer optimizer = trainer.optimizer.__class__.__name__ lr = trainer.learning_rate - self.logger.info("Training begin: using optimizer %s " - "with current learning rate %.4f ", - optimizer, lr) + estimator.logger.info("Training begin: using optimizer %s " + "with current learning rate %.4f ", + optimizer, lr) if estimator.max_epoch: - self.logger.info("Train for %d epochs.", estimator.max_epoch) + estimator.logger.info("Train for %d epochs.", estimator.max_epoch) else: - self.logger.info("Train for %d batches.", estimator.max_batch) + estimator.logger.info("Train for %d batches.", estimator.max_batch) # reset all counters self.current_epoch = 0 self.batch_index = 0 @@ -311,13 +285,7 @@ def train_end(self, estimator, *args, **kwargs): for metric in self.train_metrics + self.val_metrics: name, value = metric.get() msg += '%s: %.4f, ' % (name, value) - self.logger.info(msg.rstrip(', ')) - # make a copy of handler list and remove one by one - # as removing handler will edit the handler list - for handler in self.logger.handlers[:]: - handler.close() - self.logger.removeHandler(handler) - logging.shutdown() + estimator.logger.info(msg.rstrip(', ')) def batch_begin(self, estimator, *args, **kwargs): if self.verbose == self.LOG_PER_BATCH: @@ -334,14 +302,14 @@ def batch_end(self, estimator, *args, **kwargs): # only log current training loss & metric after each batch name, value = metric.get() msg += '%s: %.4f, ' % (name, value) - self.logger.info(msg.rstrip(', ')) + estimator.logger.info(msg.rstrip(', ')) self.batch_index += 1 def epoch_begin(self, estimator, *args, **kwargs): if self.verbose >= self.LOG_PER_EPOCH: self.epoch_start = time.time() - self.logger.info("[Epoch %d] Begin, current learning rate: %.4f", - self.current_epoch, estimator.trainer.learning_rate) + estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f", + self.current_epoch, estimator.trainer.learning_rate) def epoch_end(self, estimator, *args, **kwargs): if self.verbose >= self.LOG_PER_EPOCH: @@ -350,7 +318,7 @@ def epoch_end(self, estimator, *args, **kwargs): for monitor in self.train_metrics + self.val_metrics: name, value = monitor.get() msg += '%s: %.4f, ' % (name, value) - self.logger.info(msg.rstrip(', ')) + estimator.logger.info(msg.rstrip(', ')) self.current_epoch += 1 self.batch_index = 0 @@ -424,7 +392,6 @@ def __init__(self, self.max_checkpoints = max_checkpoints self.resume_from_checkpoint = resume_from_checkpoint self.saved_checkpoints = [] - self.logger = logging.getLogger(__name__) if self.save_best: if mode not in ['auto', 'min', 'max']: warnings.warn('ModelCheckpoint mode %s is unknown, ' @@ -443,14 +410,14 @@ def __init__(self, else: # use greater for accuracy and f1 and less otherwise if 'acc' or 'f1' in self.monitor.get()[0].lower(): - self.logger.info("`greater` operator will be used to determine " - "if %s has improved, please use `min` for mode " - "if you want otherwise", self.monitor.get()[0]) + warnings.warn("`greater` operator will be used to determine if %s has improved. " + "Please specify `mode='min'` to use the `less` operator. " + "Specify `mode='max' to disable this warning.`", self.monitor.get()[0]) self.monitor_op = np.greater else: - self.logger.info("`less` operator will be used to determine " - "if %s has improved, please use `max` for mode " - "if you want otherwise", self.monitor.get()[0]) + warnings.warn("`less` operator will be used to determine if %s has improved. " + "Please specify `mode='max'` to use the `greater` operator. " + "Specify `mode='min' to disable this warning.`", self.monitor.get()[0]) self.monitor_op = np.less def train_begin(self, estimator, *args, **kwargs): @@ -501,9 +468,9 @@ def _save_checkpoint(self, estimator): prefix = "%s-epoch%dbatch%d" % (self.model_prefix, save_epoch_number, save_batch_number) self._save_params_and_trainer(estimator, prefix) if self.verbose > 0: - self.logger.info('[Epoch %d] CheckpointHandler: trained total %d batches, ' - 'saving model at %s with prefix: %s', - self.current_epoch, self.current_batch + 1, self.model_dir, prefix) + estimator.logger.info('[Epoch %d] CheckpointHandler: trained total %d batches, ' + 'saving model at %s with prefix: %s', + self.current_epoch, self.current_batch + 1, self.model_dir, prefix) if self.save_best: monitor_name, monitor_value = self.monitor.get() @@ -519,18 +486,18 @@ def _save_checkpoint(self, estimator): self._save_params_and_trainer(estimator, prefix) self.best = monitor_value if self.verbose > 0: - self.logger.info('[Epoch %d] CheckpointHandler: ' - '%s improved from %0.5f to %0.5f, ' - 'updating best model at %s with prefix: %s', - self.current_epoch, monitor_name, - self.best, monitor_value, self.model_dir, prefix) + estimator.logger.info('[Epoch %d] CheckpointHandler: ' + '%s improved from %0.5f to %0.5f, ' + 'updating best model at %s with prefix: %s', + self.current_epoch, monitor_name, + self.best, monitor_value, self.model_dir, prefix) else: if self.verbose > 0: - self.logger.info('[Epoch %d] CheckpointHandler: ' - '%s did not improve from %0.5f, ' - 'skipping updating best model', - self.current_batch, monitor_name, - self.best) + estimator.logger.info('[Epoch %d] CheckpointHandler: ' + '%s did not improve from %0.5f, ' + 'skipping updating best model', + self.current_batch, monitor_name, + self.best) def _save_symbol(self, estimator): symbol_file = os.path.join(self.model_dir, self.model_prefix + '-symbol.json') @@ -538,9 +505,11 @@ def _save_symbol(self, estimator): sym = estimator.net._cached_graph[1] sym.save(symbol_file) else: - self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock " - "to construct your model, and call net.hybridize() before passing to " - "Estimator in order to save model architecture as %s.", symbol_file) + estimator.logger.info( + "Model architecture(symbol file) is not saved, please use HybridBlock " + "to construct your model, and call net.hybridize() before passing to " + "Estimator in order to save model architecture as %s.", + symbol_file) def _save_params_and_trainer(self, estimator, file_prefix): param_file = os.path.join(self.model_dir, file_prefix + '.params') @@ -579,7 +548,7 @@ def _resume_from_checkpoint(self, estimator): msg += "%d batches" % estimator.max_batch else: msg += "%d epochs" % estimator.max_epoch - self.logger.info(msg) + estimator.logger.info(msg) else: msg = "CheckpointHandler: Checkpoint resumed from epoch %d batch %d, " \ "continue to train for " % (self.trained_epoch, self.trained_batch) @@ -607,7 +576,7 @@ def _resume_from_checkpoint(self, estimator): assert os.path.exists(trainer_file), "Failed to load checkpoint, %s does not exist" % trainer_file estimator.net.load_parameters(param_file, ctx=estimator.context) estimator.trainer.load_states(trainer_file) - self.logger.warning(msg) + estimator.logger.warning(msg) def _find_max_iteration(self, dir, prefix, start, end, saved_checkpoints=None): error_msg = "Error parsing checkpoint file, please check your " \ @@ -672,7 +641,6 @@ def __init__(self, self.stopped_epoch = 0 self.current_epoch = 0 self.stop_training = False - self.logger = logging.getLogger(__name__) if mode not in ['auto', 'min', 'max']: warnings.warn('EarlyStopping mode %s is unknown, ' @@ -688,14 +656,14 @@ def __init__(self, self.monitor_op = np.greater else: if 'acc' or 'f1' in self.monitor.get()[0].lower(): - self.logger.info("`greater` operator is used to determine " - "if %s has improved, please use `min` for mode " - "if you want otherwise", self.monitor.get()[0]) + warnings.warn("`greater` operator will be used to determine if %s has improved. " + "Please specify `mode='min'` to use the `less` operator. " + "Specify `mode='max' to disable this warning.`", self.monitor.get()[0]) self.monitor_op = np.greater else: - self.logger.info("`less` operator is used to determine " - "if %s has improved, please use `max` for mode " - "if you want otherwise", self.monitor.get()[0]) + warnings.warn("`less` operator will be used to determine if %s has improved. " + "Please specify `mode='max'` to use the `greater` operator. " + "Specify `mode='min' to disable this warning.`", self.monitor.get()[0]) self.monitor_op = np.less if self.monitor_op == np.greater: # pylint: disable=comparison-with-callable @@ -733,5 +701,6 @@ def epoch_end(self, estimator, *args, **kwargs): def train_end(self, estimator, *args, **kwargs): if self.stopped_epoch > 0: - self.logger.info('[Epoch %d] EarlyStoppingHanlder: early stopping due to %s not improving', - self.stopped_epoch, self.monitor.get()[0]) + estimator.logger.info('[Epoch %d] EarlyStoppingHanlder: ' + 'early stopping due to %s not improving', + self.stopped_epoch, self.monitor.get()[0]) diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index b29c72a0f908..cc6cef004792 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -16,6 +16,7 @@ # under the License. import os +import logging import mxnet as mx from common import TemporaryDirectory @@ -143,11 +144,12 @@ def test_logging(): ce_loss = loss.SoftmaxCrossEntropyLoss() acc = mx.metric.Accuracy() est = estimator.Estimator(net, loss=ce_loss, metrics=acc) + + est.logger.addHandler(logging.FileHandler(output_dir)) + train_metrics = est.train_metrics val_metrics = est.val_metrics - logging_handler = event_handler.LoggingHandler(file_name=file_name, - file_location=tmpdir, - train_metrics=train_metrics, + logging_handler = event_handler.LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics) est.fit(test_data, event_handlers=[logging_handler], epochs=3) assert logging_handler.batch_index == 0 From 8354b1b61599e52bf10058283fd8a7d2f9c27728 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Thu, 14 Nov 2019 07:28:51 +0000 Subject: [PATCH 4/6] Fix --- .../mxnet/gluon/contrib/estimator/event_handler.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 1cbe2455a4d4..a44cda758eac 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -19,9 +19,7 @@ # pylint: disable=wildcard-import, unused-argument, too-many-ancestors """Gluon EventHandlers for Estimators""" -import logging import os -import sys import time import warnings @@ -412,12 +410,14 @@ def __init__(self, if 'acc' or 'f1' in self.monitor.get()[0].lower(): warnings.warn("`greater` operator will be used to determine if %s has improved. " "Please specify `mode='min'` to use the `less` operator. " - "Specify `mode='max' to disable this warning.`", self.monitor.get()[0]) + "Specify `mode='max' to disable this warning.`" + .format(self.monitor.get()[0])) self.monitor_op = np.greater else: warnings.warn("`less` operator will be used to determine if %s has improved. " "Please specify `mode='max'` to use the `greater` operator. " - "Specify `mode='min' to disable this warning.`", self.monitor.get()[0]) + "Specify `mode='min' to disable this warning.`" + .format(self.monitor.get()[0])) self.monitor_op = np.less def train_begin(self, estimator, *args, **kwargs): @@ -658,12 +658,14 @@ def __init__(self, if 'acc' or 'f1' in self.monitor.get()[0].lower(): warnings.warn("`greater` operator will be used to determine if %s has improved. " "Please specify `mode='min'` to use the `less` operator. " - "Specify `mode='max' to disable this warning.`", self.monitor.get()[0]) + "Specify `mode='max' to disable this warning.`" + .format(self.monitor.get()[0])) self.monitor_op = np.greater else: warnings.warn("`less` operator will be used to determine if %s has improved. " "Please specify `mode='max'` to use the `greater` operator. " - "Specify `mode='min' to disable this warning.`", self.monitor.get()[0]) + "Specify `mode='min' to disable this warning.`" + .format(self.monitor.get()[0])) self.monitor_op = np.less if self.monitor_op == np.greater: # pylint: disable=comparison-with-callable From c47217eb6da3f8b5510e142e27aa26130a3e9862 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Thu, 14 Nov 2019 13:27:31 +0000 Subject: [PATCH 5/6] Fix and clarify doc --- python/mxnet/gluon/contrib/estimator/estimator.py | 6 +++++- python/mxnet/gluon/contrib/estimator/event_handler.py | 8 ++++---- tests/python/unittest/test_gluon_estimator.py | 5 +---- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 37e4eaf0fe36..83b954d02e10 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -328,7 +328,11 @@ def fit(self, train_data, Number of epochs to iterate on the training data. You can only specify one and only one type of iteration(epochs or batches). event_handlers : EventHandler or list of EventHandler - List of :py:class:`EventHandlers` to apply during training. + List of :py:class:`EventHandlers` to apply during training. Besides + the event handlers specified here, a StoppingHandler, + LoggingHandler and MetricHandler will be added by default if not + yet specified manually. If validation data is provided, a + ValidationHandler is also added if not already specified. batches : int, default None Number of batches to iterate on the training data. You can only specify one and only one type of iteration(epochs or batches). diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index a44cda758eac..3cdc407407c1 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -408,13 +408,13 @@ def __init__(self, else: # use greater for accuracy and f1 and less otherwise if 'acc' or 'f1' in self.monitor.get()[0].lower(): - warnings.warn("`greater` operator will be used to determine if %s has improved. " + warnings.warn("`greater` operator will be used to determine if {} has improved. " "Please specify `mode='min'` to use the `less` operator. " "Specify `mode='max' to disable this warning.`" .format(self.monitor.get()[0])) self.monitor_op = np.greater else: - warnings.warn("`less` operator will be used to determine if %s has improved. " + warnings.warn("`less` operator will be used to determine if {} has improved. " "Please specify `mode='max'` to use the `greater` operator. " "Specify `mode='min' to disable this warning.`" .format(self.monitor.get()[0])) @@ -656,13 +656,13 @@ def __init__(self, self.monitor_op = np.greater else: if 'acc' or 'f1' in self.monitor.get()[0].lower(): - warnings.warn("`greater` operator will be used to determine if %s has improved. " + warnings.warn("`greater` operator will be used to determine if {} has improved. " "Please specify `mode='min'` to use the `less` operator. " "Specify `mode='max' to disable this warning.`" .format(self.monitor.get()[0])) self.monitor_op = np.greater else: - warnings.warn("`less` operator will be used to determine if %s has improved. " + warnings.warn("`less` operator will be used to determine if {} has improved. " "Please specify `mode='max'` to use the `greater` operator. " "Specify `mode='min' to disable this warning.`" .format(self.monitor.get()[0])) diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index bae576734a3e..aaf9839b29f3 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -346,10 +346,7 @@ def test_default_handlers(): train_metrics = est.train_metrics val_metrics = est.val_metrics logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics) - with warnings.catch_warnings(record=True) as w: - est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) - # provide metric handler by default - assert 'MetricHandler' in str(w[-1].message) + est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) # handler with all user defined metrics # use mix of default and user defined handlers From 1b43c5ad67d9de1cc6ff40d5a5bfd0105702c8e6 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Fri, 15 Nov 2019 02:33:11 +0000 Subject: [PATCH 6/6] Workaround problems on Windows --- tests/python/unittest/common.py | 2 +- tests/python/unittest/test_gluon_event_handler.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py index 06fb16288649..816508fbc3c8 100644 --- a/tests/python/unittest/common.py +++ b/tests/python/unittest/common.py @@ -251,7 +251,7 @@ def setup_module(): try: from tempfile import TemporaryDirectory -except: +except: # Python 2 support # really simple implementation of TemporaryDirectory class TemporaryDirectory(object): def __init__(self, suffix='', prefix='', dir=''): diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index cc6cef004792..17c75813d516 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -155,6 +155,7 @@ def test_logging(): assert logging_handler.batch_index == 0 assert logging_handler.current_epoch == 3 assert os.path.isfile(output_dir) + del est # Clean up estimator and logger before deleting tmpdir def test_custom_handler():