From d51cb0f71aea099c27909cf525ce30bd09d518b2 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Fri, 1 Nov 2019 08:57:32 -0700 Subject: [PATCH] [Estimator] refactor estimator and clarify docs (#16694) * refactor estimator and clarify docs * fix info message and test * clean up after releasing logging handler --- .../gluon/contrib/estimator/estimator.py | 134 ++++++++---------- .../gluon/contrib/estimator/event_handler.py | 67 ++++++--- python/mxnet/gluon/contrib/estimator/utils.py | 31 +++- tests/python/unittest/test_gluon_estimator.py | 11 +- .../unittest/test_gluon_event_handler.py | 3 +- 5 files changed, 139 insertions(+), 107 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index d3eded0cc8cd..4f2b8fd99cac 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -24,15 +24,14 @@ from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd -from .utils import _check_metrics +from .event_handler import _check_event_handlers +from .utils import _check_metrics, _suggest_metric_for_loss, _check_handler_metric_ref from ...data import DataLoader -from ...loss import SoftmaxCrossEntropyLoss from ...loss import Loss as gluon_loss from ...trainer import Trainer from ...utils import split_and_load from .... import autograd from ....context import Context, cpu, gpu, num_gpus -from ....metric import Accuracy from ....metric import Loss as metric_loss __all__ = ['Estimator'] @@ -48,8 +47,8 @@ class Estimator(object): ---------- net : gluon.Block The model used for training. - loss : gluon.loss.Loss or list of gluon.loss.Loss - Loss(objective functions) to calculate during training. + loss : gluon.loss.Loss + Loss (objective) function to calculate during training. metrics : EvalMetric or list of EvalMetric Metrics for evaluating models. initializer : Initializer @@ -69,19 +68,17 @@ def __init__(self, net, self.net = net self.loss = self._check_loss(loss) - self.train_metrics = _check_metrics(metrics) + self._train_metrics = _check_metrics(metrics) + self._add_default_training_metrics() + self._add_validation_metrics() self.context = self._check_context(context) self._initialize(initializer) self.trainer = self._check_trainer(trainer) def _check_loss(self, loss): - if isinstance(loss, gluon_loss): - loss = [loss] - elif isinstance(loss, list) and all([isinstance(l, gluon_loss) for l in loss]): - loss = loss - else: - raise ValueError("loss must be a Loss or a list of Loss, " + if not isinstance(loss, gluon_loss): + raise ValueError("loss must be a Loss, " "refer to gluon.loss.Loss:{}".format(loss)) return loss @@ -166,31 +163,30 @@ def _get_data_and_label(self, batch, ctx, batch_axis=0): label = split_and_load(label, ctx_list=ctx, batch_axis=batch_axis) return data, label - def prepare_loss_and_metrics(self): - """ - Based on loss functions and training metrics in estimator - Create metric wrappers to record loss values, - Create copies of train loss/metric objects to record validation values + def _add_default_training_metrics(self): + if not self._train_metrics: + suggested_metric = _suggest_metric_for_loss(self.loss) + if suggested_metric: + self._train_metrics = [suggested_metric] + loss_name = self.loss.name.rstrip('1234567890') + self._train_metrics.append(metric_loss(loss_name)) - Returns - ------- - train_metrics, val_metrics - """ - if any(not hasattr(self, attribute) for attribute in - ['train_metrics', 'val_metrics']): - # Use default mx.metric.Accuracy() for SoftmaxCrossEntropyLoss() - if not self.train_metrics and any([isinstance(l, SoftmaxCrossEntropyLoss) for l in self.loss]): - self.train_metrics = [Accuracy()] - self.val_metrics = [] - for loss in self.loss: - # remove trailing numbers from loss name to avoid confusion - self.train_metrics.append(metric_loss(loss.name.rstrip('1234567890'))) - for metric in self.train_metrics: - val_metric = copy.deepcopy(metric) - metric.name = "train " + metric.name - val_metric.name = "validation " + val_metric.name - self.val_metrics.append(val_metric) - return self.train_metrics, self.val_metrics + for metric in self._train_metrics: + metric.name = "training " + metric.name + + def _add_validation_metrics(self): + self._val_metrics = [copy.deepcopy(metric) for metric in self._train_metrics] + + for metric in self._val_metrics: + metric.name = "validation " + metric.name + + @property + def train_metrics(self): + return self._train_metrics + + @property + def val_metrics(self): + return self._val_metrics def evaluate_batch(self, val_batch, @@ -209,7 +205,7 @@ def evaluate_batch(self, """ data, label = self._get_data_and_label(val_batch, self.context, batch_axis) pred = [self.net(x) for x in data] - loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)] + loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)] # update metrics for metric in val_metrics: if isinstance(metric, metric_loss): @@ -275,7 +271,7 @@ def fit_batch(self, train_batch, with autograd.record(): pred = [self.net(x) for x in data] - loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)] + loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)] for l in loss: l.backward() @@ -377,63 +373,47 @@ def fit(self, train_data, handler.train_end(estimator_ref) def _prepare_default_handlers(self, val_data, event_handlers): - event_handlers = event_handlers or [] - default_handlers = [] - self.prepare_loss_and_metrics() + event_handlers = _check_event_handlers(event_handlers) + added_default_handlers = [] # no need to add to default handler check as StoppingHandler does not use metrics - event_handlers.append(StoppingHandler(self.max_epoch, self.max_batch)) - default_handlers.append("StoppingHandler") + added_default_handlers.append(StoppingHandler(self.max_epoch, self.max_batch)) if not any(isinstance(handler, MetricHandler) for handler in event_handlers): - event_handlers.append(MetricHandler(train_metrics=self.train_metrics)) - default_handlers.append("MetricHandler") + added_default_handlers.append(MetricHandler(train_metrics=self.train_metrics)) if not any(isinstance(handler, ValidationHandler) for handler in event_handlers): # no validation handler if val_data: - # add default validation handler if validation data found - event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate, - val_metrics=self.val_metrics)) - default_handlers.append("ValidationHandler") val_metrics = self.val_metrics + # add default validation handler if validation data found + added_default_handlers.append(ValidationHandler(val_data=val_data, + eval_fn=self.evaluate, + val_metrics=val_metrics)) else: # set validation metrics to None if no validation data and no validation handler val_metrics = [] if not any(isinstance(handler, LoggingHandler) for handler in event_handlers): - event_handlers.append(LoggingHandler(train_metrics=self.train_metrics, - val_metrics=val_metrics)) - default_handlers.append("LoggingHandler") + added_default_handlers.append(LoggingHandler(train_metrics=self.train_metrics, + val_metrics=val_metrics)) # if there is a mix of user defined event handlers and default event handlers - # they should have the same set of loss and metrics - if default_handlers and len(event_handlers) != len(default_handlers): - msg = "You are training with the following default event handlers: %s. " \ - "They use loss and metrics from estimator.prepare_loss_and_metrics(). " \ - "Please use the same set of metrics for all your other handlers." % \ - ", ".join(default_handlers) + # they should have the same set of metrics + mixing_handlers = event_handlers and added_default_handlers + + event_handlers.extend(added_default_handlers) + + if mixing_handlers: + msg = "The following default event handlers are added: {}.".format( + ", ".join([type(h).__name__ for h in added_default_handlers])) warnings.warn(msg) - # check if all handlers has the same set of references to loss and metrics - references = [] + + + # check if all handlers have the same set of references to metrics + known_metrics = set(self.train_metrics + self.val_metrics) for handler in event_handlers: - for attribute in dir(handler): - if any(keyword in attribute for keyword in ['metric' or 'monitor']): - reference = getattr(handler, attribute) - if isinstance(reference, list): - references += reference - else: - references.append(reference) - # remove None metric references - references = set([ref for ref in references if ref]) - for metric in references: - if metric not in self.train_metrics + self.val_metrics: - msg = "We have added following default handlers for you: %s and used " \ - "estimator.prepare_loss_and_metrics() to pass metrics to " \ - "those handlers. Please use the same set of metrics " \ - "for all your handlers." % \ - ", ".join(default_handlers) - raise ValueError(msg) + _check_handler_metric_ref(handler, known_metrics) event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0)) return event_handlers diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index c5a4f1a3f836..7e143d6f19aa 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -16,7 +16,7 @@ # under the License. # coding: utf-8 -# pylint: disable=wildcard-import, unused-argument +# pylint: disable=wildcard-import, unused-argument, too-many-ancestors """Gluon EventHandlers for Estimators""" import logging @@ -34,33 +34,47 @@ 'StoppingHandler', 'MetricHandler', 'ValidationHandler', 'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler'] +class EventHandler(object): + pass -class TrainBegin(object): + +def _check_event_handlers(handlers): + if isinstance(handlers, EventHandler): + handlers = [handlers] + else: + handlers = handlers or [] + if not all([isinstance(handler, EventHandler) for handler in handlers]): + raise ValueError("handlers must be an EventHandler or a list of EventHandler, " + "got: {}".format(handlers)) + return handlers + + +class TrainBegin(EventHandler): def train_begin(self, estimator, *args, **kwargs): pass -class TrainEnd(object): +class TrainEnd(EventHandler): def train_end(self, estimator, *args, **kwargs): pass -class EpochBegin(object): +class EpochBegin(EventHandler): def epoch_begin(self, estimator, *args, **kwargs): pass -class EpochEnd(object): +class EpochEnd(EventHandler): def epoch_end(self, estimator, *args, **kwargs): return False -class BatchBegin(object): +class BatchBegin(EventHandler): def batch_begin(self, estimator, *args, **kwargs): pass -class BatchEnd(object): +class BatchEnd(EventHandler): def batch_end(self, estimator, *args, **kwargs): return False @@ -242,14 +256,16 @@ def __init__(self, file_name=None, super(LoggingHandler, self).__init__() self.logger = logging.getLogger(__name__) self.logger.setLevel(logging.INFO) - stream_handler = logging.StreamHandler() - self.logger.addHandler(stream_handler) + self._added_logging_handlers = [logging.StreamHandler()] # save logger to file only if file name or location is specified if file_name or file_location: file_name = file_name or 'estimator_log' file_location = file_location or './' file_handler = logging.FileHandler(os.path.join(file_location, file_name), mode=filemode) - self.logger.addHandler(file_handler) + self._added_logging_handlers.append(file_handler) + for handler in self._added_logging_handlers: + self.logger.addHandler(handler) + if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]: raise ValueError("verbose level must be either LOG_PER_EPOCH or " "LOG_PER_BATCH, received %s. " @@ -265,6 +281,12 @@ def __init__(self, file_name=None, # it will also shut down logging at train end self.priority = np.Inf + def __del__(self): + for handler in self._added_logging_handlers: + handler.flush() + self.logger.removeHandler(handler) + handler.close() + def train_begin(self, estimator, *args, **kwargs): self.train_start = time.time() trainer = estimator.trainer @@ -393,8 +415,8 @@ def __init__(self, self.model_prefix = model_prefix self.save_best = save_best if self.save_best and not isinstance(self.monitor, EvalMetric): - raise ValueError("To save best model only, please provide one of the metric objects as monitor, " - "You can get these objects using estimator.prepare_loss_and_metric()") + raise ValueError("To save best model only, please provide one of the metric objects " + "from estimator.train_metrics and estimator.val_metrics as monitor.") self.epoch_period = epoch_period self.batch_period = batch_period self.current_batch = 0 @@ -487,10 +509,10 @@ def _save_checkpoint(self, estimator): monitor_name, monitor_value = self.monitor.get() # check if monitor exists in train stats if np.isnan(monitor_value): - warnings.warn(RuntimeWarning('Skipping save best because %s is not updated, make sure you ' - 'pass one of the metric objects as monitor, ' - 'you can use estimator.prepare_loss_and_metrics to' - 'create all metric objects', monitor_name)) + warnings.warn(RuntimeWarning( + 'Skipping save best because %s is not updated, make sure you pass one of the ' + 'metric objects estimator.train_metrics and estimator.val_metrics as monitor', + monitor_name)) else: if self.monitor_op(monitor_value, self.best): prefix = self.model_prefix + '-best' @@ -517,7 +539,7 @@ def _save_symbol(self, estimator): sym.save(symbol_file) else: self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock " - "to construct your model, can call net.hybridize() before passing to " + "to construct your model, and call net.hybridize() before passing to " "Estimator in order to save model architecture as %s.", symbol_file) def _save_params_and_trainer(self, estimator, file_prefix): @@ -636,8 +658,9 @@ def __init__(self, super(EarlyStoppingHandler, self).__init__() if not isinstance(monitor, EvalMetric): - raise ValueError("Please provide one of the metric objects as monitor, " - "You can create these objects using estimator.prepare_loss_and_metric()") + raise ValueError( + "Please provide one of the metric objects from estimator.train_metrics and " + "estimator.val_metrics as monitor.") if isinstance(monitor, CompositeEvalMetric): raise ValueError("CompositeEvalMetric is not supported for EarlyStoppingHandler, " "please specify a simple metric instead.") @@ -693,9 +716,9 @@ def train_begin(self, estimator, *args, **kwargs): def epoch_end(self, estimator, *args, **kwargs): monitor_name, monitor_value = self.monitor.get() if np.isnan(monitor_value): - warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects' - 'as monitor, you can use estimator.prepare_loss_and_metrics to' - 'create all metric objects', monitor_name)) + warnings.warn(RuntimeWarning( + '%s is not updated, make sure you pass one of the metric objects from' + 'estimator.train_metrics and estimator.val_metrics as monitor.', monitor_name)) else: if self.monitor_op(monitor_value - self.min_delta, self.best): self.best = monitor_value diff --git a/python/mxnet/gluon/contrib/estimator/utils.py b/python/mxnet/gluon/contrib/estimator/utils.py index f5be0878e0d9..d9126a2f6763 100644 --- a/python/mxnet/gluon/contrib/estimator/utils.py +++ b/python/mxnet/gluon/contrib/estimator/utils.py @@ -19,7 +19,8 @@ # pylint: disable=wildcard-import, unused-variable """Gluon Estimator Utility Functions""" -from ....metric import EvalMetric, CompositeEvalMetric +from ...loss import SoftmaxCrossEntropyLoss +from ....metric import Accuracy, EvalMetric, CompositeEvalMetric def _check_metrics(metrics): if isinstance(metrics, CompositeEvalMetric): @@ -30,5 +31,31 @@ def _check_metrics(metrics): metrics = metrics or [] if not all([isinstance(metric, EvalMetric) for metric in metrics]): raise ValueError("metrics must be a Metric or a list of Metric, " - "refer to mxnet.metric.EvalMetric:{}".format(metrics)) + "refer to mxnet.metric.EvalMetric: {}".format(metrics)) return metrics + +def _check_handler_metric_ref(handler, known_metrics): + for attribute in dir(handler): + if any(keyword in attribute for keyword in ['metric' or 'monitor']): + reference = getattr(handler, attribute) + if not reference: + continue + elif isinstance(reference, list): + for metric in reference: + _check_metric_known(handler, metric, known_metrics) + else: + _check_metric_known(handler, reference, known_metrics) + +def _check_metric_known(handler, metric, known_metrics): + if metric not in known_metrics: + raise ValueError( + 'Event handler {} refers to a metric instance {} outside of ' + 'the known training and validation metrics. Please use the metrics from ' + 'estimator.train_metrics and estimator.val_metrics ' + 'instead.'.format(type(handler).__name__, + metric)) + +def _suggest_metric_for_loss(loss): + if isinstance(loss, SoftmaxCrossEntropyLoss): + return Accuracy() + return None diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index ae47d925670f..bae576734a3e 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -96,7 +96,8 @@ def test_validation(): epochs=num_epochs) # using validation handler - train_metrics, val_metrics = est.prepare_loss_and_metrics() + train_metrics = est.train_metrics + val_metrics = est.val_metrics validation_handler = ValidationHandler(val_data=dataloader, eval_fn=est.evaluate, val_metrics=val_metrics) @@ -222,7 +223,6 @@ def test_metric(): loss=loss, trainer=trainer, context=ctx) - est.prepare_loss_and_metrics() assert isinstance(est.train_metrics[0], mx.metric.Accuracy) @@ -343,11 +343,11 @@ def test_default_handlers(): # handler with prepared loss and metrics # use mix of default and user defined handlers - train_metrics, val_metrics = est.prepare_loss_and_metrics() + train_metrics = est.train_metrics + val_metrics = est.val_metrics logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics) with warnings.catch_warnings(record=True) as w: est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) - assert 'You are training with the' in str(w[-1].message) # provide metric handler by default assert 'MetricHandler' in str(w[-1].message) @@ -364,7 +364,8 @@ def test_default_handlers(): est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) # test handler order - train_metrics, val_metrics = est.prepare_loss_and_metrics() + train_metrics = est.train_metrics + val_metrics = est.val_metrics early_stopping = EarlyStoppingHandler(monitor=val_metrics[0]) handlers = est._prepare_default_handlers(val_data=None, event_handlers=[early_stopping]) assert len(handlers) == 4 diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index 7ea5ff3f4b62..b29c72a0f908 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -143,7 +143,8 @@ def test_logging(): ce_loss = loss.SoftmaxCrossEntropyLoss() acc = mx.metric.Accuracy() est = estimator.Estimator(net, loss=ce_loss, metrics=acc) - train_metrics, val_metrics = est.prepare_loss_and_metrics() + train_metrics = est.train_metrics + val_metrics = est.val_metrics logging_handler = event_handler.LoggingHandler(file_name=file_name, file_location=tmpdir, train_metrics=train_metrics,