From ffccae1760be0710f653e8c37f45f618474dca04 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Fri, 12 Apr 2019 11:06:09 -0700 Subject: [PATCH 01/10] improve event handlers --- .../gluon/contrib/estimator/estimator.py | 242 +++++++----------- .../gluon/contrib/estimator/event_handler.py | 186 +++++++++----- 2 files changed, 216 insertions(+), 212 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index f7c97c43cd4b..2d66b4c645e0 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -20,8 +20,7 @@ """Gluon Estimator""" import copy -import warnings -from .event_handler import EventHandler, LoggingHandler +from .event_handler import * from .... import gluon, autograd from ....context import Context, cpu, gpu, num_gpus from ....metric import EvalMetric, Loss, Accuracy @@ -57,46 +56,41 @@ def __init__(self, net, context=None): self.net = net + self.loss = self._check_loss(loss) + self.train_metrics = self._check_metrics(metrics) + # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss() + if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]): + self.train_metrics = [Accuracy()] + + self.context = self._check_context(context) + self._initialize(initializer) + self.trainer = self._check_trainer(trainer) + + def _check_loss(self, loss): if isinstance(loss, gluon.loss.Loss): - self.loss = [loss] - elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]): - self.loss = loss - else: + loss = [loss] + elif isinstance(loss, list) and not all([isinstance(l, gluon.loss.Loss) for l in loss]): raise ValueError("loss must be a Loss or a list of Loss, " "refer to gluon.loss.Loss:{}".format(loss)) + return loss + def _check_metrics(self, metrics): if isinstance(metrics, EvalMetric): - self.train_metrics = [metrics] + metrics = [metrics] else: self.train_metrics = metrics or [] if not all([isinstance(metric, EvalMetric) for metric in self.train_metrics]): raise ValueError("metrics must be a Metric or a list of Metric, " "refer to mxnet.metric.EvalMetric:{}".format(metrics)) + return metrics - # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss() - if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]): - self.train_metrics = [Accuracy()] - - # Use same metrics for validation - self.val_metrics = copy.deepcopy(self.train_metrics) - - # store training statistics - self.train_stats = {} - - # separate train and validation - self.train_loss_metrics = [] - self.val_loss_metrics = [] - # using the metric wrapper for loss to record loss value - for l in self.loss: - self.train_loss_metrics.append(Loss(l.name)) - self.val_loss_metrics.append(Loss(l.name)) - + def _check_context(self, context): # handle context if isinstance(context, Context): - self.context = [context] + context = [context] elif isinstance(context, list) and all([isinstance(c, Context) for c in context]): - self.context = context + context = context elif not context: if num_gpus() > 0: # only use 1 GPU by default @@ -104,40 +98,41 @@ def __init__(self, net, warnings.warn("You have multiple GPUs, gpu(0) will be used by default." "To utilize all your GPUs, specify context as a list of gpus, " "e.g. context=[mx.gpu(0), mx.gpu(1)] ") - self.context = [gpu(0)] + context = [gpu(0)] else: - self.context = [cpu()] + context = [cpu()] else: raise ValueError("context must be a Context or a list of Context, " "refer to mxnet.Context:{}".format(context)) + return context + def _initialize(self, initializer): # initialize the network - self.initializer = initializer - if self.initializer: + if initializer: if self._is_initialized(): # if already initialized, re-init with user specified initializer warnings.warn("Network already initialized, re-initializing with %s. " "You don't need to pass initializer if you already " - "initialized your net." % type(self.initializer).__name__) - self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=True) + "initialized your net." % type(initializer).__name__) + self.net.initialize(init=initializer, ctx=self.context, force_reinit=True) else: # initialize with user specified initializer - self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=False) + self.net.initialize(init=initializer, ctx=self.context, force_reinit=False) else: if not self._is_initialized(): self.net.initialize(ctx=self.context) + def _check_trainer(self, trainer): # handle trainer if not trainer: warnings.warn("No trainer specified, default SGD optimizer " "with learning rate 0.001 is used.") - self.trainer = gluon.Trainer(self.net.collect_params(), + trainer = gluon.Trainer(self.net.collect_params(), 'sgd', {'learning_rate': 0.001}) elif not isinstance(trainer, gluon.Trainer): raise ValueError("Trainer must be a Gluon Trainer instance, refer to " "gluon.Trainer:{}".format(trainer)) - else: - self.trainer = trainer + return trainer def _is_initialized(self): param_dict = self.net.collect_params() @@ -148,20 +143,31 @@ def _is_initialized(self): return False return True - def _batch_fn(self, batch, ctx, is_iterator=False): - if is_iterator: - data = batch.data[0] - label = batch.label[0] - else: - data = batch[0] - label = batch[1] + def _get_data_and_label(self, batch, ctx): + data = batch[0] + label = batch[1] data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=0) label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0) return data, label + def prepare_loss_and_metrics(self): + train_loss = [] + val_loss = [] + val_metrics = [] + for loss in self.loss: + train_loss.append(Loss("Train " + loss.name)) + val_loss.append(Loss("Val " + loss.name)) + for metric in self.train_metrics: + metric.name = "Train " + metric.name + val_metric = copy.deepcopy(metric) + val_metric.name = "Val " + val_metric.name + val_metrics.append(val_metric) + return train_loss, self.train_metrics, val_loss, val_metrics + def evaluate(self, val_data, - batch_fn=None): + val_loss, + val_metrics): """Evaluate model on validation data Parameters @@ -173,38 +179,27 @@ def evaluate(self, from a data batch and load into contexts(devices) """ - for metric in self.val_metrics + self.val_loss_metrics: + for metric in val_loss + val_metrics: metric.reset() for _, batch in enumerate(val_data): - if not batch_fn: - if isinstance(val_data, gluon.data.DataLoader): - data, label = self._batch_fn(batch, self.context) - else: - raise ValueError("You are using a custom iteration, please also provide " - "batch_fn to extract data and label. Alternatively, you " - "can provide the data as gluon.data.DataLoader.") - else: - data, label = batch_fn(batch, self.context) + if not isinstance(val_data, gluon.data.DataLoader): + raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " + "can transform your DataIter or any NDArray into Gluon DataLoader. " + "Refer to gluon.data.dataloader") + data, label = self._get_data_and_label(batch, self.context) pred = [self.net(x) for x in data] - losses = [] - for loss in self.loss: - losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)]) + loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)] # update metrics - for metric in self.val_metrics: + for metric in val_metrics: metric.update(label, pred) - name, value = metric.get() - self.train_stats['val_' + name] = value - for loss, loss_metric, in zip(losses, self.val_loss_metrics): - loss_metric.update(0, [l for l in loss]) - name, value = loss_metric.get() - self.train_stats['val_' + name] = value + for metric in val_loss: + metric.update(0, loss) def fit(self, train_data, val_data=None, epochs=1, - event_handlers=None, - batch_fn=None): + event_handlers=None): """Trains the model on a given dataset for a specified number of epochs. Also, the batch size is inferred from the DataLoader's batch_size. @@ -227,54 +222,40 @@ def fit(self, train_data, from a data batch and load into contexts(devices) """ - self.max_epoch = epochs - self.stop_training = False - self.processed_samples = None - self.batch_idx = 0 - event_handlers = event_handlers or [] # provide default logging handler - if not event_handlers or \ - not any(isinstance(handler, LoggingHandler) for handler in event_handlers): - event_handlers.append(LoggingHandler()) - warnings.warn("No Event Handler specified, default `LoggingHandler()` " - "is used with verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH. " - "Please look at gluon.estimator.event_handler for more detail.") + if not event_handlers: + train_loss, train_metrics, val_loss, val_metrics = self.prepare_loss_and_metrics() + event_handlers.append(MetricHandler(train_metrics=train_metrics, train_loss=train_loss)) + if val_data: + event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate, + val_loss=val_loss, val_metrics=val_metrics)) + event_handlers.append(LoggingHandler(train_metrics=train_metrics + train_loss, + val_metrics=val_metrics + val_loss)) + warnings.warn("No Event Handler specified, default %s are used. " + "Please look at gluon.contrib.estimator.event_handler for more detail." % + ", ".join([handler.__class__.__name__ for handler in event_handlers])) + + event_handlers.sort(key= lambda handler : getattr(handler, 'rank', 0), reverse=True) train_begin, epoch_begin, batch_begin, \ batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers) - # passing estimator to event handlers so they can access estimator information - # when a event is triggered - for handler in event_handlers: - handler.estimator = self - # training begin for handler in train_begin: - handler.train_begin() + handler.train_begin(trainer=self.trainer, epochs=epochs) - for epoch in range(self.max_epoch): + for epoch in range(epochs): # epoch begin - self.current_epoch = epoch - # Number of samples trained after every batch - completed_samples = 0 - for handler in epoch_begin: - handler.epoch_begin() - - for metric in self.train_metrics + self.train_loss_metrics: - metric.reset() + handler.epoch_begin(trainer=self.trainer, epochs=epochs) for i, batch in enumerate(train_data): - if not batch_fn: - if isinstance(train_data, gluon.data.DataLoader): - data, label = self._batch_fn(batch, self.context) - else: - raise ValueError("You are using a custom iteration, please also provide " - "batch_fn to extract data and label. Alternatively, you " - "can provide the data as gluon.data.DataLoader") - else: - data, label = batch_fn(batch, self.context) + if not isinstance(train_data, gluon.data.DataLoader): + raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " + "can transform your DataIter or any NDArray into Gluon DataLoader. " + "Refer to gluon.data.dataloader") + data, label = self._get_data_and_label(batch, self.context) batch_size = batch[0].shape[0] @@ -284,49 +265,22 @@ def fit(self, train_data, with autograd.record(): pred = [self.net(x) for x in data] - losses = [] - for loss in self.loss: - losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)]) - - for loss in losses: - for l in loss: - l.backward() - - # update train metrics - for metric in self.train_metrics: - metric.update(label, pred) - # get metric name and current value and update train stats - name, value = metric.get() - self.train_stats['train_' + name] = value - - # update loss - for loss, loss_metric, in zip(losses, self.train_loss_metrics): - loss_metric.update(0, [l for l in loss]) - name, value = loss_metric.get() - self.train_stats['train_' + name] = value - - completed_samples += batch_size - - self.batch_idx = i - # record trained samples v.s. total samples if using Gluon DataLoader - if isinstance(train_data, gluon.data.DataLoader): - self.processed_samples = "{}/{}".format(completed_samples, - len(train_data._dataset)) + loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)] + + for l in loss: + l.backward() self.trainer.step(batch_size) # batch end for handler in batch_end: - handler.batch_end() - - if val_data: - self.evaluate(val_data, batch_fn) + if handler.batch_end(batch_size=batch_size, + pred=pred, + label=label, + loss=loss): break # epoch end for handler in epoch_end: - handler.epoch_end() - - if self.stop_training: - break + if handler.epoch_end(): break # train end for handler in train_end: @@ -346,16 +300,16 @@ def _categorize_handlers(self, event_handlers): epoch_end = [] train_end = [] for handler in event_handlers: - if not handler.__class__.train_begin == EventHandler.train_begin: + if isinstance(handler, TrainBegin): train_begin.append(handler) - if not handler.__class__.epoch_begin == EventHandler.epoch_begin: + if isinstance(handler, EpochBegin): epoch_begin.append(handler) - if not handler.__class__.batch_begin == EventHandler.batch_begin: + if isinstance(handler, BatchBegin): batch_begin.append(handler) - if not handler.__class__.batch_end == EventHandler.batch_end: + if isinstance(handler, BatchEnd): batch_end.append(handler) - if not handler.__class__.epoch_end == EventHandler.epoch_end: + if isinstance(handler, EpochEnd): epoch_end.append(handler) - if not handler.__class__.train_end == EventHandler.train_end: + if isinstance(handler, TrainEnd): train_end.append(handler) return train_begin, epoch_begin, batch_begin, batch_end, epoch_end, train_end diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 53c0bf5bde86..2d49ecc0d1ef 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -19,7 +19,6 @@ # pylint: disable=wildcard-import """Gluon EventHandlers for Estimators""" -__all__ = ['EventHandler', 'LoggingHandler'] import logging import os import time @@ -28,50 +27,89 @@ import numpy as np -class EventHandler(object): - """Basic for event handlers - - :py:class:`EventHandler` can perform user defined functions at - different stages of training: train begin, epoch begin, batch begin, - batch end, epoch end, train end. - - Parameters - ---------- - estimator : Estimator - The :py:class:`Estimator` to get training statistics - """ - - def __init__(self): - self._estimator = None - - @property - def estimator(self): - return self._estimator - - @estimator.setter - def estimator(self, estimator): - self._estimator = estimator - - def train_begin(self): - pass - - def train_end(self): +class TrainBegin(object): + def train_begin(self, *args, **kwargs): pass - def batch_begin(self): +class TrainEnd(object): + def train_end(self, *args, **kwargs): pass - def batch_end(self): +class EpochBegin(object): + def epoch_begin(self, *args, **kwargs): pass - def epoch_begin(self): - pass +class EpochEnd(object): + def epoch_end(self, *args, **kwargs): + return False - def epoch_end(self): +class BatchBegin(object): + def batch_begin(self, *args, **kwargs): pass - -class LoggingHandler(EventHandler): +class BatchEnd(object): + def batch_end(self, *args, **kwargs): + return False + + +class MetricHandler(EpochBegin, BatchEnd): + def __init__(self, train_loss, train_metrics): + self.train_loss = train_loss + self.train_metrics = train_metrics + # order to be called among all callbacks + # metrics need to be calculated before other callbacks can access them + self.rank = 1 + + def epoch_begin(self, *args, **kwargs): + for metric in self.train_loss + self.train_metrics: + metric.reset() + + def batch_end(self, *args, **kwargs): + pred = kwargs['pred'] + label = kwargs['label'] + loss = kwargs['loss'] + for metric in self.train_metrics: + metric.update(label, pred) + for metric in self.train_loss: + metric.update(0, loss) + +class ValidationHandler(BatchEnd, EpochEnd): + def __init__(self, + val_data, + eval_fn, + val_loss, + val_metrics=None, + epoch_period=1, + batch_period=None): + self.val_data = val_data + self.eval_fn = eval_fn + self.epoch_period = epoch_period + self.batch_period = batch_period + self.val_loss = val_loss + self.val_metrics = val_metrics + self.num_batches = 0 + self.num_epochs = 0 + # order to be called among all callbacks + # validation metrics need to be calculated before other callbacks can access them + self.rank = 1 + + def batch_end(self, *args, **kwargs): + if self.batch_period and self.num_batches % self.batch_period == 0: + self.eval_fn(val_data=self.val_data, + val_loss= self.val_loss, + val_metrics=self.val_metrics) + self.num_batches += 1 + + def epoch_end(self, *args, **kwargs): + if self.num_epochs % self.epoch_period == 0: + self.eval_fn(val_data=self.val_data, + val_loss= self.val_loss, + val_metrics=self.val_metrics) + + self.num_epochs += 1 + + +class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd): """Basic Logging Handler that applies to every Gluon estimator by default. :py:class:`LoggingHandler` logs hyper-parameters, training statistics, @@ -94,7 +132,11 @@ class LoggingHandler(EventHandler): LOG_VERBOSITY_PER_EPOCH = 1 LOG_VERBOSITY_PER_BATCH = 2 - def __init__(self, file_name=None, file_location=None, verbose=LOG_VERBOSITY_PER_EPOCH): + def __init__(self, file_name=None, + file_location=None, + verbose=LOG_VERBOSITY_PER_EPOCH, + train_metrics=None, + val_metrics=None): super(LoggingHandler, self).__init__() self.logger = logging.getLogger(__name__) self.logger.setLevel(logging.INFO) @@ -112,59 +154,67 @@ def __init__(self, file_name=None, file_location=None, verbose=LOG_VERBOSITY_PER file_location = file_location or './' file_handler = logging.FileHandler(os.path.join(file_location, file_name)) self.logger.addHandler(file_handler) + self.train_metrics = train_metrics + self.val_metrics = val_metrics + self.batch_index = 0 + self.current_epoch = 0 + self.processed_samples = 0 - def train_begin(self): - self.train_start = time.time() - self.logger.info("Training begin: using optimizer %s " - "with current learning rate %.4f ", - self.estimator.trainer.optimizer.__class__.__name__, - self.estimator.trainer.learning_rate) - self.logger.info("Train for %d epochs.", self.estimator.max_epoch) - def train_end(self): + def train_begin(self, *args, **kwargs): + self.train_start = time.time() + if 'trainer' in kwargs: + optimizer = kwargs['trainer'].optimizer.__class__.__name__ + lr = kwargs['trainer'].learning_rate + self.logger.info("Training begin: using optimizer %s " + "with current learning rate %.4f ", + optimizer, lr) + if 'epochs' in kwargs: + self.logger.info("Train for %d epochs.", kwargs['epochs']) + + def train_end(self, *args, **kwargs): train_time = time.time() - self.train_start - epoch = self.estimator.current_epoch - msg = 'Train finished using total %ds at epoch %d. ' % (train_time, epoch) + msg = 'Train finished using total %ds with %d epochs.' % (train_time, self.current_epoch) # log every result in train stats including train/validation loss & metrics - for key in self.estimator.train_stats: - msg += '%s : %.4f ' % (key, self.estimator.train_stats[key]) + for metric in self.train_metrics + self.val_metrics: + name, value = metric.get() + msg += '%s : %.4f ' % (name, value) self.logger.info(msg) - def batch_begin(self): + def batch_begin(self, *args, **kwargs): if self.verbose == self.LOG_VERBOSITY_PER_BATCH: self.batch_start = time.time() - def batch_end(self): + def batch_end(self, *args, **kwargs): if self.verbose == self.LOG_VERBOSITY_PER_BATCH: batch_time = time.time() - self.batch_start - epoch = self.estimator.current_epoch - batch = self.estimator.batch_idx - msg = '[Epoch %d] [Batch %d] ' % (epoch, batch) - if self.estimator.processed_samples: - msg += '[Samples %s] ' % (self.estimator.processed_samples) + msg = '[Epoch %d] [Batch %d] ' % (self.current_epoch, self.batch_index) + self.processed_samples += kwargs['batch_size'] + msg += '[Samples %s] ' % (self.processed_samples) msg += 'time/batch: %.3fs ' % batch_time - for key in self.estimator.train_stats: + for metric in self.train_metrics: # only log current training loss & metric after each batch - if key.startswith('train_'): - msg += key + ': ' + '%.4f ' % self.estimator.train_stats[key] + name, value = metric.get() + msg += '%s : %.4f ' % (name, value) self.logger.info(msg) + self.batch_index += 1 - def epoch_begin(self): + def epoch_begin(self, *args, **kwargs): if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH: self.epoch_start = time.time() - def epoch_end(self): + def epoch_end(self, *args, **kwargs): if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH: epoch_time = time.time() - self.epoch_start - epoch = self.estimator.current_epoch - msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time) - # log every result in train stats including train/validation loss & metrics - for key in self.estimator.train_stats: - msg += '%s : %.4f ' % (key, self.estimator.train_stats[key]) + msg = '\n[Epoch %d] finished in %.3fs: ' % (self.current_epoch, epoch_time) + for monitor in self.train_metrics + self.val_metrics: + name, value = monitor.get() + msg += '%s : %.4f ' % (name, value) self.logger.info(msg) + self.current_epoch += 1 -class CheckpointHandler(EventHandler): +class CheckpointHandler(object): """Save the model after every epoch. :py:class:`CheckpointHandler` save the network parameters every epoch @@ -260,7 +310,7 @@ def epoch_end(self, ): self.estimator.net.save_parameters(self.filepath) -class EarlyStoppingHandler(EventHandler): +class EarlyStoppingHandler(object): """Early stop training if monitored value is not improving Parameters From 81812480bdbcc202eccd28ad5f2e8c32c3a8ee47 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Fri, 12 Apr 2019 14:33:03 -0700 Subject: [PATCH 02/10] update tests --- .../gluon/contrib/estimator/estimator.py | 61 ++++--- .../gluon/contrib/estimator/event_handler.py | 157 ++++++++++-------- tests/nightly/estimator/test_estimator_cnn.py | 2 +- tests/nightly/estimator/test_sentiment_rnn.py | 6 +- tests/python/unittest/test_gluon_estimator.py | 15 +- .../unittest/test_gluon_event_handler.py | 27 +-- 6 files changed, 152 insertions(+), 116 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 2d66b4c645e0..361f3da40fe6 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -20,6 +20,7 @@ """Gluon Estimator""" import copy + from .event_handler import * from .... import gluon, autograd from ....context import Context, cpu, gpu, num_gpus @@ -70,7 +71,9 @@ def __init__(self, net, def _check_loss(self, loss): if isinstance(loss, gluon.loss.Loss): loss = [loss] - elif isinstance(loss, list) and not all([isinstance(l, gluon.loss.Loss) for l in loss]): + elif isinstance(loss, list) or all([isinstance(l, gluon.loss.Loss) for l in loss]): + loss = loss + else: raise ValueError("loss must be a Loss or a list of Loss, " "refer to gluon.loss.Loss:{}".format(loss)) return loss @@ -79,8 +82,8 @@ def _check_metrics(self, metrics): if isinstance(metrics, EvalMetric): metrics = [metrics] else: - self.train_metrics = metrics or [] - if not all([isinstance(metric, EvalMetric) for metric in self.train_metrics]): + metrics = metrics or [] + if not all([isinstance(metric, EvalMetric) for metric in metrics]): raise ValueError("metrics must be a Metric or a list of Metric, " "refer to mxnet.metric.EvalMetric:{}".format(metrics)) return metrics @@ -128,7 +131,7 @@ def _check_trainer(self, trainer): warnings.warn("No trainer specified, default SGD optimizer " "with learning rate 0.001 is used.") trainer = gluon.Trainer(self.net.collect_params(), - 'sgd', {'learning_rate': 0.001}) + 'sgd', {'learning_rate': 0.001}) elif not isinstance(trainer, gluon.Trainer): raise ValueError("Trainer must be a Gluon Trainer instance, refer to " "gluon.Trainer:{}".format(trainer)) @@ -151,18 +154,27 @@ def _get_data_and_label(self, batch, ctx): return data, label def prepare_loss_and_metrics(self): - train_loss = [] - val_loss = [] - val_metrics = [] - for loss in self.loss: - train_loss.append(Loss("Train " + loss.name)) - val_loss.append(Loss("Val " + loss.name)) - for metric in self.train_metrics: - metric.name = "Train " + metric.name - val_metric = copy.deepcopy(metric) - val_metric.name = "Val " + val_metric.name - val_metrics.append(val_metric) - return train_loss, self.train_metrics, val_loss, val_metrics + """ + Based on loss functions and training metrics in estimator + Create metric wrappers to record loss values, + Create copies of train loss/metric objects to record validation values + """ + if all(hasattr(self, attribute) for attribute in + ['train_loss', 'train_metrics', 'val_loss', 'val_metrics']): + return self.train_loss, self.train_metrics, self.val_loss, self.val_metrics + else: + self.train_loss = [] + self.val_loss = [] + self.val_metrics = [] + for loss in self.loss: + self.train_loss.append(Loss("Train " + ''.join([i for i in loss.name if not i.isdigit()]))) + self.val_loss.append(Loss("Validation " + ''.join([i for i in loss.name if not i.isdigit()]))) + for metric in self.train_metrics: + val_metric = copy.deepcopy(metric) + metric.name = "Train " + metric.name + val_metric.name = "Validation " + val_metric.name + self.val_metrics.append(val_metric) + return self.train_loss, self.train_metrics, self.val_loss, self.val_metrics def evaluate(self, val_data, @@ -236,19 +248,20 @@ def fit(self, train_data, "Please look at gluon.contrib.estimator.event_handler for more detail." % ", ".join([handler.__class__.__name__ for handler in event_handlers])) - event_handlers.sort(key= lambda handler : getattr(handler, 'rank', 0), reverse=True) + event_handlers.sort(key=lambda handler: getattr(handler, 'rank', 0), reverse=True) train_begin, epoch_begin, batch_begin, \ batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers) # training begin for handler in train_begin: - handler.train_begin(trainer=self.trainer, epochs=epochs) + # we only have net, trainer, and epochs to train information + handler.train_begin(net=self.net, trainer=self.trainer, epochs=epochs) for epoch in range(epochs): # epoch begin for handler in epoch_begin: - handler.epoch_begin(trainer=self.trainer, epochs=epochs) + handler.epoch_begin(net=self.net, trainer=self.trainer, epochs=epochs) for i, batch in enumerate(train_data): if not isinstance(train_data, gluon.data.DataLoader): @@ -261,7 +274,7 @@ def fit(self, train_data, # batch begin for handler in batch_begin: - handler.batch_begin() + handler.batch_begin(net=self.net, trainer=self.trainer, epochs=epochs, batch=batch) with autograd.record(): pred = [self.net(x) for x in data] @@ -273,14 +286,12 @@ def fit(self, train_data, self.trainer.step(batch_size) # batch end for handler in batch_end: - if handler.batch_end(batch_size=batch_size, - pred=pred, - label=label, - loss=loss): break + if handler.batch_end(net=self.net, trainer=self.trainer, epochs=epochs, + batch_size=batch_size, pred=pred, label=label, loss=loss): break # epoch end for handler in epoch_end: - if handler.epoch_end(): break + if handler.epoch_end(net=self.net, trainer=self.trainer, epochs=epochs): break # train end for handler in train_end: diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 2d49ecc0d1ef..9653b593866f 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -24,29 +24,34 @@ import time import warnings -import numpy as np +from ....metric import * class TrainBegin(object): def train_begin(self, *args, **kwargs): pass + class TrainEnd(object): def train_end(self, *args, **kwargs): pass + class EpochBegin(object): def epoch_begin(self, *args, **kwargs): pass + class EpochEnd(object): def epoch_end(self, *args, **kwargs): return False + class BatchBegin(object): def batch_begin(self, *args, **kwargs): pass + class BatchEnd(object): def batch_end(self, *args, **kwargs): return False @@ -73,6 +78,7 @@ def batch_end(self, *args, **kwargs): for metric in self.train_loss: metric.update(0, loss) + class ValidationHandler(BatchEnd, EpochEnd): def __init__(self, val_data, @@ -96,14 +102,14 @@ def __init__(self, def batch_end(self, *args, **kwargs): if self.batch_period and self.num_batches % self.batch_period == 0: self.eval_fn(val_data=self.val_data, - val_loss= self.val_loss, + val_loss=self.val_loss, val_metrics=self.val_metrics) self.num_batches += 1 def epoch_end(self, *args, **kwargs): if self.num_epochs % self.epoch_period == 0: self.eval_fn(val_data=self.val_data, - val_loss= self.val_loss, + val_loss=self.val_loss, val_metrics=self.val_metrics) self.num_epochs += 1 @@ -154,13 +160,12 @@ def __init__(self, file_name=None, file_location = file_location or './' file_handler = logging.FileHandler(os.path.join(file_location, file_name)) self.logger.addHandler(file_handler) - self.train_metrics = train_metrics - self.val_metrics = val_metrics + self.train_metrics = train_metrics or [] + self.val_metrics = val_metrics or [] self.batch_index = 0 self.current_epoch = 0 self.processed_samples = 0 - def train_begin(self, *args, **kwargs): self.train_start = time.time() if 'trainer' in kwargs: @@ -212,9 +217,10 @@ def epoch_end(self, *args, **kwargs): msg += '%s : %.4f ' % (name, value) self.logger.info(msg) self.current_epoch += 1 + self.batch_index = 0 -class CheckpointHandler(object): +class CheckpointHandler(BatchEnd, EpochEnd): """Save the model after every epoch. :py:class:`CheckpointHandler` save the network parameters every epoch @@ -226,7 +232,7 @@ class CheckpointHandler(object): filepath : str file name to save the parameters, it can contain directories, for example: ./saved_model/resnet.params - monitor: str + monitor: EvalMetric the metrics to monitor verbose: int, default 0 verbosity mode @@ -241,18 +247,23 @@ class CheckpointHandler(object): def __init__(self, filepath, - monitor='val_accuracy', + monitor=None, verbose=0, save_best_only=False, mode='auto', - period=1): - super(CheckpointHandler, self).__init__() + epoch_period=1, + batch_period=None): self.monitor = monitor self.verbose = verbose self.filepath = filepath self.save_best_only = save_best_only - self.period = period - self.epochs_since_last_save = 0 + if self.save_best_only and not isinstance(self.monitor, EvalMetric): + raise ValueError("To save best model only, please provide one of the metric objects as monitor, " + "You can create these objects using estimator.prepare_loss_and_metric()") + self.epoch_period = epoch_period + self.batch_period = batch_period + self.num_batches = 0 + self.num_epochs = 0 self.logger = logging.getLogger(__name__) if mode not in ['auto', 'min', 'max']: @@ -262,55 +273,63 @@ def __init__(self, mode = 'auto' if mode == 'min': - self.monitor_op = np.less - self.best = np.Inf + self.monitor_op = numpy.less + self.best = numpy.Inf elif mode == 'max': - self.monitor_op = np.greater - self.best = -np.Inf + self.monitor_op = numpy.greater + self.best = -numpy.Inf else: # use greater for accuracy and less otherwise - if 'acc' in self.monitor: - self.monitor_op = np.greater - self.best = -np.Inf + if 'acc' in self.monitor.get()[0].lower(): + self.monitor_op = numpy.greater + self.best = -numpy.Inf else: - self.monitor_op = np.less - self.best = np.Inf + self.monitor_op = numpy.less + self.best = numpy.Inf + + def batch_end(self, *args, **kwargs): + net = kwargs['net'] + self._save_checkpoint(net, "Batch", self.num_batches) + self.num_batches += 1 + + def epoch_end(self, *args, **kwargs): + net = kwargs['net'] + self._save_checkpoint(net, "Epoch", self.num_epochs) + self.num_epochs += 1 - def epoch_end(self, ): - epoch = self.estimator.current_epoch + def _save_checkpoint(self, net, period_name, period_value): # add extension for weights if '.params' not in self.filepath: self.filepath += '.params' - self.epochs_since_last_save += 1 - if self.epochs_since_last_save >= self.period: - self.epochs_since_last_save = 0 + if self.num_epochs % self.epoch_period == 0: if self.save_best_only: + monitor_name, monitor_value = self.monitor.get() # check if monitor exists in train stats - if self.monitor not in self.estimator.train_stats: - warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure the monitor value' - 'starts with `train_ `or `val_` and contains loss/metric name, ', - 'for example val_accuracy', self.monitor)) - self.estimator.net.save_parameters(self.filepath) + if numpy.isnan(monitor_value): + warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects' + 'as monitor, you can use estimator.prepare_loss_and_metrics to' + 'create all metric objects', monitor_name)) + net.save_parameters(self.filepath) else: - current = self.estimator.train_stats[self.monitor] - if self.monitor_op(current, self.best): + if self.monitor_op(monitor_value, self.best): if self.verbose > 0: - self.logger.info('\n[Epoch %d] %s improved from %0.5f to %0.5f,' + self.logger.info('\n[%s %d] %s improved from %0.5f to %0.5f,' ' saving model to %s', - epoch, self.monitor, self.best, current, self.filepath) - self.best = current - self.estimator.net.save_parameters(self.filepath) + period_name, period_value, monitor_name, + self.best, monitor_value, self.filepath) + self.best = monitor_value + net.save_parameters(self.filepath) else: if self.verbose > 0: - self.logger.info('\n[Epoch %d] %s did not improve from %0.5f, skipping save model', - epoch, self.monitor, self.best) + self.logger.info('\n[%s %d] %s did not improve from %0.5f, skipping save model', + period_name, period_value, monitor_name, self.best) else: if self.verbose > 0: - logging.info('\nEpoch %d: saving model to %s', epoch, self.filepath) - self.estimator.net.save_parameters(self.filepath) + logging.info('\n%s %d: saving model to %s', period_name, period_value, self.filepath) + net.save_parameters(self.filepath) -class EarlyStoppingHandler(object): +class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd): """Early stop training if monitored value is not improving Parameters @@ -331,19 +350,24 @@ class EarlyStoppingHandler(object): """ def __init__(self, - monitor='val_accuracy', + monitor, min_delta=0, patience=0, mode='auto', baseline=None): super(EarlyStoppingHandler, self).__init__() + if not isinstance(monitor, EvalMetric): + raise ValueError("Please provide one of the metric objects as monitor, " + "You can create these objects using estimator.prepare_loss_and_metric()") self.monitor = monitor self.baseline = baseline self.patience = patience self.min_delta = min_delta self.wait = 0 self.stopped_epoch = 0 + self.num_epochs = 0 + self.stop_training = False self.logger = logging.getLogger(__name__) if mode not in ['auto', 'min', 'max']: @@ -352,45 +376,46 @@ def __init__(self, mode = 'auto' if mode == 'min': - self.monitor_op = np.less + self.monitor_op = numpy.less elif mode == 'max': - self.monitor_op = np.greater + self.monitor_op = numpy.greater else: - if 'acc' in self.monitor: - self.monitor_op = np.greater + if 'acc' in self.monitor.get()[0].lower(): + self.monitor_op = numpy.greater else: - self.monitor_op = np.less + self.monitor_op = numpy.less - if self.monitor_op == np.greater: + if self.monitor_op == numpy.greater: self.min_delta *= 1 else: self.min_delta *= -1 - def train_begin(self): + def train_begin(self, *args, **kwargs): self.wait = 0 self.stopped_epoch = 0 if self.baseline is not None: self.best = self.baseline else: - self.best = np.Inf if self.monitor_op == np.less else -np.Inf - - def epoch_end(self): - epoch = self.estimator.current_epoch - if self.monitor not in self.estimator.train_stats: - warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure the monitor value' - 'starts with `train_ `or `val_` and contains loss/metric name, ', - 'for example val_accuracy', self.monitor)) + self.best = numpy.Inf if self.monitor_op == numpy.less else -numpy.Inf + + def epoch_end(self, *args, **kwargs): + monitor_name, monitor_value = self.monitor.get() + if numpy.isnan(monitor_value): + warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects' + 'as monitor, you can use estimator.prepare_loss_and_metrics to' + 'create all metric objects', monitor_name)) else: - current = self.estimator.train_stats[self.monitor] - if self.monitor_op(current - self.min_delta, self.best): - self.best = current + if self.monitor_op(monitor_value - self.min_delta, self.best): + self.best = monitor_value self.wait = 0 else: self.wait += 1 if self.wait >= self.patience: - self.stopped_epoch = epoch - self.estimator.stop_training = True + self.stopped_epoch = self.num_epochs + self.stop_training = True + return self.stop_training - def train_end(self): + def train_end(self, *args, **kwargs): if self.stopped_epoch > 0: - self.logger.info('Epoch %d: early stopping due to %s not improving', self.stopped_epoch, self.monitor) + self.logger.info('Epoch %d: early stopping due to %s not improving', + self.stopped_epoch, self.monitor.get()[0]) diff --git a/tests/nightly/estimator/test_estimator_cnn.py b/tests/nightly/estimator/test_estimator_cnn.py index 7d0018b0eedd..c60dc544b347 100644 --- a/tests/nightly/estimator/test_estimator_cnn.py +++ b/tests/nightly/estimator/test_estimator_cnn.py @@ -137,7 +137,7 @@ def test_estimator_gpu(): val_data=test_data, epochs=num_epochs) - assert est.train_stats['train_'+acc.name] > 0.80 + assert acc.get()[1] > 0.80 if __name__ == '__main__': parser = argparse.ArgumentParser(description='test gluon estimator') diff --git a/tests/nightly/estimator/test_sentiment_rnn.py b/tests/nightly/estimator/test_sentiment_rnn.py index 5fd93c1286fa..404bf83fb86f 100644 --- a/tests/nightly/estimator/test_sentiment_rnn.py +++ b/tests/nightly/estimator/test_sentiment_rnn.py @@ -183,7 +183,7 @@ def run(net, train_dataloader, test_dataloader, **kwargs): # Begin training est.fit(train_data=train_dataloader, val_data=test_dataloader, epochs=num_epochs) - return est + return acc def test_estimator_cpu(**kwargs): @@ -250,9 +250,9 @@ def test_estimator_gpu(**kwargs): net.embedding.weight.set_data(glove_embedding.idx_to_vec) net.embedding.collect_params().setattr('grad_req', 'null') - est = run(net, train_dataloader, test_dataloader, **kwargs) + acc = run(net, train_dataloader, test_dataloader, **kwargs) - assert est.train_stats['train_accuracy'] > 0.70 + assert acc.get()[1] > 0.70 parser = argparse.ArgumentParser(description='test gluon estimator') diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index 13fcd960d439..6cc23e62ed11 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -19,12 +19,11 @@ import sys import unittest -import warnings import mxnet as mx from mxnet import gluon from mxnet.gluon import nn -from mxnet.gluon.contrib.estimator import Estimator, EventHandler +from mxnet.gluon.contrib.estimator import * from nose.tools import assert_raises @@ -267,16 +266,12 @@ def test_context(): def test_categorize_handlers(): - class CustomHandler1(EventHandler): - def __init__(self): - super(CustomHandler1, self).__init__() + class CustomHandler1(TrainBegin): def train_begin(self): print("custom train begin") - class CustomHandler2(EventHandler): - def __init__(self): - super(CustomHandler2, self).__init__() + class CustomHandler2(EpochBegin, BatchBegin, TrainEnd): def epoch_begin(self): print("custom epoch begin") @@ -287,9 +282,7 @@ def batch_begin(self): def train_end(self): print("custom train end") - class CustomHandler3(EventHandler): - def __init__(self): - super(CustomHandler3, self).__init__() + class CustomHandler3(EpochBegin, BatchBegin, BatchEnd, TrainEnd): def epoch_begin(self): print("custom epoch begin") diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index dd2e60d43f2b..f1ccccf84946 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -17,18 +17,21 @@ import os import tempfile + import mxnet as mx from mxnet import nd from mxnet.gluon import nn, loss from mxnet.gluon.contrib.estimator import estimator, event_handler + def _get_test_network(): net = nn.Sequential() net.add(nn.Dense(128, activation='relu', in_units=100, flatten=False), - nn.Dense(64, activation='relu', in_units=128), - nn.Dense(10, activation='relu', in_units=64)) + nn.Dense(64, activation='relu', in_units=128), + nn.Dense(10, activation='relu', in_units=64)) return net + def _get_test_data(): data = nd.ones((32, 100)) label = nd.random.randint(0, 10, (32, 1)) @@ -39,46 +42,48 @@ def _get_test_data(): def test_checkpoint_handler(): tmpdir = tempfile.mkdtemp() file_path = os.path.join(tmpdir, "model.params") - test_data = _get_test_data() + test_data = _get_test_data() save_best_only = False mode = 'auto' net = _get_test_network() ce_loss = loss.SoftmaxCrossEntropyLoss() + ce_loss_metric = mx.metric.Loss(ce_loss.name) acc = mx.metric.Accuracy() est = estimator.Estimator(net, loss=ce_loss, metrics=acc) checkpoint_handler = [event_handler.CheckpointHandler(file_path, + monitor=acc, save_best_only=save_best_only, mode=mode)] est.fit(test_data, event_handlers=checkpoint_handler, epochs=1) assert os.path.isfile(file_path) os.remove(file_path) + def test_early_stopping(): test_data = _get_test_data() mode = 'max' - monitor = 'train_accuracy' patience = 0 net = _get_test_network() ce_loss = loss.SoftmaxCrossEntropyLoss() acc = mx.metric.Accuracy() est = estimator.Estimator(net, loss=ce_loss, metrics=acc) - early_stopping = [event_handler.EarlyStoppingHandler(monitor, + early_stopping = [event_handler.EarlyStoppingHandler(monitor=acc, patience=patience, mode=mode)] est.fit(test_data, event_handlers=early_stopping, epochs=3) mode = 'auto' - monitor = 'train_accuracy' patience = 2 - early_stopping = [event_handler.EarlyStoppingHandler(monitor, + early_stopping = [event_handler.EarlyStoppingHandler(monitor=acc, patience=patience, - mode=mode)] + mode=mode)] est.fit(test_data, event_handlers=early_stopping, epochs=1) + def test_logging(): tmpdir = tempfile.mkdtemp() test_data = _get_test_data() @@ -87,9 +92,11 @@ def test_logging(): net = _get_test_network() ce_loss = loss.SoftmaxCrossEntropyLoss() + ce_loss_metric = mx.metric.Loss(ce_loss.name) acc = mx.metric.Accuracy() est = estimator.Estimator(net, loss=ce_loss, metrics=acc) - logging_handler = [event_handler.LoggingHandler(file_name=file_name, file_location=tmpdir)] + logging_handler = [event_handler.LoggingHandler(file_name=file_name, + file_location=tmpdir, train_metrics=[acc, ce_loss_metric])] est.fit(test_data, event_handlers=logging_handler, epochs=1) assert os.path.isfile(output_dir) - os.remove(output_dir) \ No newline at end of file + os.remove(output_dir) From 581a38fdd8e0c593476f4d40ab897431eef31200 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Mon, 15 Apr 2019 16:33:22 -0700 Subject: [PATCH 03/10] passing weakref of estimator --- .../gluon/contrib/estimator/estimator.py | 56 ++++++----- .../gluon/contrib/estimator/event_handler.py | 92 +++++++++---------- 2 files changed, 68 insertions(+), 80 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 361f3da40fe6..512acc11d864 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -20,6 +20,7 @@ """Gluon Estimator""" import copy +import weakref from .event_handler import * from .... import gluon, autograd @@ -160,25 +161,22 @@ def prepare_loss_and_metrics(self): Create copies of train loss/metric objects to record validation values """ if all(hasattr(self, attribute) for attribute in - ['train_loss', 'train_metrics', 'val_loss', 'val_metrics']): - return self.train_loss, self.train_metrics, self.val_loss, self.val_metrics + ['train_metrics', 'val_metrics']): + return self.train_metrics, self.val_metrics else: - self.train_loss = [] - self.val_loss = [] self.val_metrics = [] for loss in self.loss: - self.train_loss.append(Loss("Train " + ''.join([i for i in loss.name if not i.isdigit()]))) - self.val_loss.append(Loss("Validation " + ''.join([i for i in loss.name if not i.isdigit()]))) + self.train_metrics.append(Loss("Train " + ''.join([i for i in loss.name if not i.isdigit()]))) + self.val_metrics.append(Loss("Validation " + ''.join([i for i in loss.name if not i.isdigit()]))) for metric in self.train_metrics: val_metric = copy.deepcopy(metric) metric.name = "Train " + metric.name val_metric.name = "Validation " + val_metric.name self.val_metrics.append(val_metric) - return self.train_loss, self.train_metrics, self.val_loss, self.val_metrics + return self.train_metrics, self.val_metrics def evaluate(self, val_data, - val_loss, val_metrics): """Evaluate model on validation data @@ -186,12 +184,11 @@ def evaluate(self, ---------- val_data : DataLoader validation data with data and labels - batch_fn : function - custom batch function to extract data and label - from a data batch and load into contexts(devices) + val_metrics : EvalMetric or list of EvalMetrics + metrics to update validation result """ - for metric in val_loss + val_metrics: + for metric in val_metrics: metric.reset() for _, batch in enumerate(val_data): @@ -204,9 +201,10 @@ def evaluate(self, loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)] # update metrics for metric in val_metrics: - metric.update(label, pred) - for metric in val_loss: - metric.update(0, loss) + if isinstance(metric, Loss): + metric.update(0, loss) + else: + metric.update(label, pred) def fit(self, train_data, val_data=None, @@ -233,17 +231,17 @@ def fit(self, train_data, custom batch function to extract data and label from a data batch and load into contexts(devices) """ - + self.max_epochs = epochs event_handlers = event_handlers or [] # provide default logging handler if not event_handlers: - train_loss, train_metrics, val_loss, val_metrics = self.prepare_loss_and_metrics() - event_handlers.append(MetricHandler(train_metrics=train_metrics, train_loss=train_loss)) + train_metrics, val_metrics = self.prepare_loss_and_metrics() + event_handlers.append(MetricHandler(train_metrics=train_metrics)) if val_data: event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate, - val_loss=val_loss, val_metrics=val_metrics)) - event_handlers.append(LoggingHandler(train_metrics=train_metrics + train_loss, - val_metrics=val_metrics + val_loss)) + val_metrics=val_metrics)) + event_handlers.append(LoggingHandler(train_metrics=train_metrics, + val_metrics=val_metrics)) warnings.warn("No Event Handler specified, default %s are used. " "Please look at gluon.contrib.estimator.event_handler for more detail." % ", ".join([handler.__class__.__name__ for handler in event_handlers])) @@ -253,15 +251,16 @@ def fit(self, train_data, train_begin, epoch_begin, batch_begin, \ batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers) + # only pass a weak reference to all event handlers + estimator_ref = weakref.proxy(self) # training begin for handler in train_begin: - # we only have net, trainer, and epochs to train information - handler.train_begin(net=self.net, trainer=self.trainer, epochs=epochs) + handler.train_begin(estimator_ref) for epoch in range(epochs): # epoch begin for handler in epoch_begin: - handler.epoch_begin(net=self.net, trainer=self.trainer, epochs=epochs) + handler.epoch_begin(estimator_ref) for i, batch in enumerate(train_data): if not isinstance(train_data, gluon.data.DataLoader): @@ -274,7 +273,7 @@ def fit(self, train_data, # batch begin for handler in batch_begin: - handler.batch_begin(net=self.net, trainer=self.trainer, epochs=epochs, batch=batch) + handler.batch_begin(estimator_ref, batch=batch) with autograd.record(): pred = [self.net(x) for x in data] @@ -286,16 +285,15 @@ def fit(self, train_data, self.trainer.step(batch_size) # batch end for handler in batch_end: - if handler.batch_end(net=self.net, trainer=self.trainer, epochs=epochs, - batch_size=batch_size, pred=pred, label=label, loss=loss): break + if handler.batch_end(estimator_ref, batch=batch, pred=pred, label=label, loss=loss): break # epoch end for handler in epoch_end: - if handler.epoch_end(net=self.net, trainer=self.trainer, epochs=epochs): break + if handler.epoch_end(estimator_ref): break # train end for handler in train_end: - handler.train_end() + handler.train_end(estimator_ref) def _categorize_handlers(self, event_handlers): """ diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 9653b593866f..8700204545c5 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -23,67 +23,67 @@ import os import time import warnings - +import weakref from ....metric import * class TrainBegin(object): - def train_begin(self, *args, **kwargs): + def train_begin(self, estimator, *args, **kwargs): pass class TrainEnd(object): - def train_end(self, *args, **kwargs): + def train_end(self, estimator, *args, **kwargs): pass class EpochBegin(object): - def epoch_begin(self, *args, **kwargs): + def epoch_begin(self, estimator, *args, **kwargs): pass class EpochEnd(object): - def epoch_end(self, *args, **kwargs): + def epoch_end(self, estimator, *args, **kwargs): return False class BatchBegin(object): - def batch_begin(self, *args, **kwargs): + def batch_begin(self, estimator, *args, **kwargs): pass class BatchEnd(object): - def batch_end(self, *args, **kwargs): + def batch_end(self, estimator, *args, **kwargs): return False class MetricHandler(EpochBegin, BatchEnd): - def __init__(self, train_loss, train_metrics): - self.train_loss = train_loss + def __init__(self, train_metrics): self.train_metrics = train_metrics # order to be called among all callbacks # metrics need to be calculated before other callbacks can access them self.rank = 1 - def epoch_begin(self, *args, **kwargs): - for metric in self.train_loss + self.train_metrics: + def epoch_begin(self, estimator, *args, **kwargs): + for metric in self.train_metrics: metric.reset() - def batch_end(self, *args, **kwargs): + def batch_end(self, estimator, *args, **kwargs): pred = kwargs['pred'] label = kwargs['label'] loss = kwargs['loss'] for metric in self.train_metrics: - metric.update(label, pred) - for metric in self.train_loss: - metric.update(0, loss) + if isinstance(metric, Loss): + # metric wrapper for loss values + metric.update(0, loss) + else: + metric.update(label, pred) class ValidationHandler(BatchEnd, EpochEnd): def __init__(self, val_data, eval_fn, - val_loss, val_metrics=None, epoch_period=1, batch_period=None): @@ -91,7 +91,6 @@ def __init__(self, self.eval_fn = eval_fn self.epoch_period = epoch_period self.batch_period = batch_period - self.val_loss = val_loss self.val_metrics = val_metrics self.num_batches = 0 self.num_epochs = 0 @@ -99,17 +98,15 @@ def __init__(self, # validation metrics need to be calculated before other callbacks can access them self.rank = 1 - def batch_end(self, *args, **kwargs): + def batch_end(self, estimator, *args, **kwargs): if self.batch_period and self.num_batches % self.batch_period == 0: self.eval_fn(val_data=self.val_data, - val_loss=self.val_loss, val_metrics=self.val_metrics) self.num_batches += 1 - def epoch_end(self, *args, **kwargs): + def epoch_end(self, estimator, *args, **kwargs): if self.num_epochs % self.epoch_period == 0: self.eval_fn(val_data=self.val_data, - val_loss=self.val_loss, val_metrics=self.val_metrics) self.num_epochs += 1 @@ -123,8 +120,6 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat Parameters ---------- - estimator : Estimator - The :py:class:`Estimator` to get training statistics file_name : str file name to save the logs file_location: str @@ -166,18 +161,17 @@ def __init__(self, file_name=None, self.current_epoch = 0 self.processed_samples = 0 - def train_begin(self, *args, **kwargs): + def train_begin(self, estimator, *args, **kwargs): self.train_start = time.time() - if 'trainer' in kwargs: - optimizer = kwargs['trainer'].optimizer.__class__.__name__ - lr = kwargs['trainer'].learning_rate - self.logger.info("Training begin: using optimizer %s " - "with current learning rate %.4f ", - optimizer, lr) - if 'epochs' in kwargs: - self.logger.info("Train for %d epochs.", kwargs['epochs']) - - def train_end(self, *args, **kwargs): + trainer = estimator.trainer + optimizer = trainer.optimizer.__class__.__name__ + lr = trainer.learning_rate + self.logger.info("Training begin: using optimizer %s " + "with current learning rate %.4f ", + optimizer, lr) + self.logger.info("Train for %d epochs.", estimator.max_epochs) + + def train_end(self, estimator, *args, **kwargs): train_time = time.time() - self.train_start msg = 'Train finished using total %ds with %d epochs.' % (train_time, self.current_epoch) # log every result in train stats including train/validation loss & metrics @@ -186,15 +180,15 @@ def train_end(self, *args, **kwargs): msg += '%s : %.4f ' % (name, value) self.logger.info(msg) - def batch_begin(self, *args, **kwargs): + def batch_begin(self, estimator, *args, **kwargs): if self.verbose == self.LOG_VERBOSITY_PER_BATCH: self.batch_start = time.time() - def batch_end(self, *args, **kwargs): + def batch_end(self, estimator, *args, **kwargs): if self.verbose == self.LOG_VERBOSITY_PER_BATCH: batch_time = time.time() - self.batch_start msg = '[Epoch %d] [Batch %d] ' % (self.current_epoch, self.batch_index) - self.processed_samples += kwargs['batch_size'] + self.processed_samples += kwargs['batch'][0].shape[0] msg += '[Samples %s] ' % (self.processed_samples) msg += 'time/batch: %.3fs ' % batch_time for metric in self.train_metrics: @@ -204,11 +198,11 @@ def batch_end(self, *args, **kwargs): self.logger.info(msg) self.batch_index += 1 - def epoch_begin(self, *args, **kwargs): + def epoch_begin(self, estimator, *args, **kwargs): if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH: self.epoch_start = time.time() - def epoch_end(self, *args, **kwargs): + def epoch_end(self, estimator, *args, **kwargs): if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH: epoch_time = time.time() - self.epoch_start msg = '\n[Epoch %d] finished in %.3fs: ' % (self.current_epoch, epoch_time) @@ -227,8 +221,6 @@ class CheckpointHandler(BatchEnd, EpochEnd): Parameters ---------- - estimator : Estimator - The :py:class:`Estimator` to get training statistics filepath : str file name to save the parameters, it can contain directories, for example: ./saved_model/resnet.params @@ -287,14 +279,12 @@ def __init__(self, self.monitor_op = numpy.less self.best = numpy.Inf - def batch_end(self, *args, **kwargs): - net = kwargs['net'] - self._save_checkpoint(net, "Batch", self.num_batches) + def batch_end(self, estimator, *args, **kwargs): + self._save_checkpoint(estimator.net, "Batch", self.num_batches) self.num_batches += 1 - def epoch_end(self, *args, **kwargs): - net = kwargs['net'] - self._save_checkpoint(net, "Epoch", self.num_epochs) + def epoch_end(self, estimator, *args, **kwargs): + self._save_checkpoint(estimator.net, "Epoch", self.num_epochs) self.num_epochs += 1 def _save_checkpoint(self, net, period_name, period_value): @@ -336,7 +326,7 @@ class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd): ---------- estimator : Estimator The :py:class:`Estimator` to get training statistics - monitor: str + monitor: EvalMetric the metrics to monitor min_delta: float, default 0 minimal change in monitored value to be considered as an improvement @@ -390,7 +380,7 @@ def __init__(self, else: self.min_delta *= -1 - def train_begin(self, *args, **kwargs): + def train_begin(self, estimator, *args, **kwargs): self.wait = 0 self.stopped_epoch = 0 if self.baseline is not None: @@ -398,7 +388,7 @@ def train_begin(self, *args, **kwargs): else: self.best = numpy.Inf if self.monitor_op == numpy.less else -numpy.Inf - def epoch_end(self, *args, **kwargs): + def epoch_end(self, estimator, *args, **kwargs): monitor_name, monitor_value = self.monitor.get() if numpy.isnan(monitor_value): warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects' @@ -415,7 +405,7 @@ def epoch_end(self, *args, **kwargs): self.stop_training = True return self.stop_training - def train_end(self, *args, **kwargs): + def train_end(self, estimator, *args, **kwargs): if self.stopped_epoch > 0: self.logger.info('Epoch %d: early stopping due to %s not improving', self.stopped_epoch, self.monitor.get()[0]) From f805fbd0c5769de5bd6a6adf30fe38a27f779faa Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Wed, 17 Apr 2019 12:18:05 -0700 Subject: [PATCH 04/10] fix unit test --- .../gluon/contrib/estimator/event_handler.py | 31 +++++++++---------- tests/python/unittest/test_gluon_estimator.py | 2 +- .../unittest/test_gluon_event_handler.py | 31 +++++++++---------- 3 files changed, 31 insertions(+), 33 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 8700204545c5..aef38690be96 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -16,14 +16,13 @@ # under the License. # coding: utf-8 -# pylint: disable=wildcard-import """Gluon EventHandlers for Estimators""" import logging import os import time import warnings -import weakref +import numpy as np from ....metric import * @@ -265,19 +264,19 @@ def __init__(self, mode = 'auto' if mode == 'min': - self.monitor_op = numpy.less - self.best = numpy.Inf + self.monitor_op = np.less + self.best = np.Inf elif mode == 'max': - self.monitor_op = numpy.greater - self.best = -numpy.Inf + self.monitor_op = np.greater + self.best = -np.Inf else: # use greater for accuracy and less otherwise if 'acc' in self.monitor.get()[0].lower(): - self.monitor_op = numpy.greater - self.best = -numpy.Inf + self.monitor_op = np.greater + self.best = -np.Inf else: - self.monitor_op = numpy.less - self.best = numpy.Inf + self.monitor_op = np.less + self.best = np.Inf def batch_end(self, estimator, *args, **kwargs): self._save_checkpoint(estimator.net, "Batch", self.num_batches) @@ -366,16 +365,16 @@ def __init__(self, mode = 'auto' if mode == 'min': - self.monitor_op = numpy.less + self.monitor_op = np.less elif mode == 'max': - self.monitor_op = numpy.greater + self.monitor_op = np.greater else: if 'acc' in self.monitor.get()[0].lower(): - self.monitor_op = numpy.greater + self.monitor_op = np.greater else: - self.monitor_op = numpy.less + self.monitor_op = np.less - if self.monitor_op == numpy.greater: + if self.monitor_op == np.greater: self.min_delta *= 1 else: self.min_delta *= -1 @@ -386,7 +385,7 @@ def train_begin(self, estimator, *args, **kwargs): if self.baseline is not None: self.best = self.baseline else: - self.best = numpy.Inf if self.monitor_op == numpy.less else -numpy.Inf + self.best = np.Inf if self.monitor_op == np.less else -np.Inf def epoch_end(self, estimator, *args, **kwargs): monitor_name, monitor_value = self.monitor.get() diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index 6cc23e62ed11..379f029c2c72 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -251,7 +251,7 @@ def test_context(): metrics=metrics) # input list of context gpus = mx.context.num_gpus() - ctx = [mx.gpu(i) for i in gpus] if gpus > 0 else [mx.cpu()] + ctx = [mx.gpu(i) for i in range(gpus)] if gpus > 0 else [mx.cpu()] net = get_model() est = Estimator(net=net, loss=loss, diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index f1ccccf84946..e151281ea9bd 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -22,7 +22,7 @@ from mxnet import nd from mxnet.gluon import nn, loss from mxnet.gluon.contrib.estimator import estimator, event_handler - +from common import TemporaryDirectory def _get_test_network(): net = nn.Sequential() @@ -85,18 +85,17 @@ def test_early_stopping(): def test_logging(): - tmpdir = tempfile.mkdtemp() - test_data = _get_test_data() - file_name = 'test_log' - output_dir = os.path.join(tmpdir, file_name) - - net = _get_test_network() - ce_loss = loss.SoftmaxCrossEntropyLoss() - ce_loss_metric = mx.metric.Loss(ce_loss.name) - acc = mx.metric.Accuracy() - est = estimator.Estimator(net, loss=ce_loss, metrics=acc) - logging_handler = [event_handler.LoggingHandler(file_name=file_name, - file_location=tmpdir, train_metrics=[acc, ce_loss_metric])] - est.fit(test_data, event_handlers=logging_handler, epochs=1) - assert os.path.isfile(output_dir) - os.remove(output_dir) + with TemporaryDirectory() as tmpdir: + test_data = _get_test_data() + file_name = 'test_log' + output_dir = os.path.join(tmpdir, file_name) + + net = _get_test_network() + ce_loss = loss.SoftmaxCrossEntropyLoss() + ce_loss_metric = mx.metric.Loss(ce_loss.name) + acc = mx.metric.Accuracy() + est = estimator.Estimator(net, loss=ce_loss, metrics=acc) + logging_handler = [event_handler.LoggingHandler(file_name=file_name, + file_location=tmpdir, train_metrics=[acc, ce_loss_metric])] + est.fit(test_data, event_handlers=logging_handler, epochs=1) + assert os.path.isfile(output_dir) \ No newline at end of file From 962436aa4203a93b78ff3bf8be6b7a7b608629bb Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Wed, 17 Apr 2019 12:44:59 -0700 Subject: [PATCH 05/10] fix test --- python/mxnet/gluon/contrib/estimator/event_handler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index aef38690be96..604d9e1386b6 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -23,7 +23,7 @@ import time import warnings import numpy as np -from ....metric import * +from ....metric import EvalMetric, Loss class TrainBegin(object): @@ -294,7 +294,7 @@ def _save_checkpoint(self, net, period_name, period_value): if self.save_best_only: monitor_name, monitor_value = self.monitor.get() # check if monitor exists in train stats - if numpy.isnan(monitor_value): + if np.isnan(monitor_value): warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects' 'as monitor, you can use estimator.prepare_loss_and_metrics to' 'create all metric objects', monitor_name)) @@ -389,7 +389,7 @@ def train_begin(self, estimator, *args, **kwargs): def epoch_end(self, estimator, *args, **kwargs): monitor_name, monitor_value = self.monitor.get() - if numpy.isnan(monitor_value): + if np.isnan(monitor_value): warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects' 'as monitor, you can use estimator.prepare_loss_and_metrics to' 'create all metric objects', monitor_name)) From c67dd541a246d3becb0913a05ef04b222050a1e6 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Wed, 17 Apr 2019 16:28:10 -0700 Subject: [PATCH 06/10] fix pylint --- python/mxnet/gluon/contrib/estimator/estimator.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 512acc11d864..e6692c05ae17 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -160,10 +160,8 @@ def prepare_loss_and_metrics(self): Create metric wrappers to record loss values, Create copies of train loss/metric objects to record validation values """ - if all(hasattr(self, attribute) for attribute in + if any(not hasattr(self, attribute) for attribute in ['train_metrics', 'val_metrics']): - return self.train_metrics, self.val_metrics - else: self.val_metrics = [] for loss in self.loss: self.train_metrics.append(Loss("Train " + ''.join([i for i in loss.name if not i.isdigit()]))) @@ -173,7 +171,7 @@ def prepare_loss_and_metrics(self): metric.name = "Train " + metric.name val_metric.name = "Validation " + val_metric.name self.val_metrics.append(val_metric) - return self.train_metrics, self.val_metrics + return self.train_metrics, self.val_metrics def evaluate(self, val_data, From 29066d10da0dbf6d8ded60a27989f3e0061a725b Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Thu, 18 Apr 2019 09:27:31 -0700 Subject: [PATCH 07/10] fix test --- .../gluon/contrib/estimator/estimator.py | 13 +++-- .../gluon/contrib/estimator/event_handler.py | 58 +++++++++++++++++-- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index e6692c05ae17..c8b4072528e5 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -20,9 +20,11 @@ """Gluon Estimator""" import copy +import warnings import weakref -from .event_handler import * +from .event_handler import MetricHandler, ValidationHandler, LoggingHandler +from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd from .... import gluon, autograd from ....context import Context, cpu, gpu, num_gpus from ....metric import EvalMetric, Loss, Accuracy @@ -47,7 +49,7 @@ class Estimator(object): trainer : Trainer Trainer to apply optimizer on network parameters context : Context or list of Context - devices to run the training on + device(s) to run the training on """ def __init__(self, net, @@ -283,11 +285,14 @@ def fit(self, train_data, self.trainer.step(batch_size) # batch end for handler in batch_end: - if handler.batch_end(estimator_ref, batch=batch, pred=pred, label=label, loss=loss): break + if handler.batch_end(estimator_ref, batch=batch, + pred=pred, label=label, loss=loss): + break # epoch end for handler in epoch_end: - if handler.epoch_end(estimator_ref): break + if handler.epoch_end(estimator_ref): + break # train end for handler in train_end: diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 604d9e1386b6..8982ab2cdb22 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -22,7 +22,9 @@ import os import time import warnings + import numpy as np + from ....metric import EvalMetric, Loss @@ -57,11 +59,23 @@ def batch_end(self, estimator, *args, **kwargs): class MetricHandler(EpochBegin, BatchEnd): + """Metric Handler that update metric values at batch end + + :py:class:`MetricHandler` takes model predictions and true labels + and update the metrics, it also update metric wrapper for loss with loss values + Validation loss and metrics will be handled by :py:class:`ValidationHandler` + + Parameters + ---------- + train_metrics : List of EvalMetrics + training metrics to be updated at batch end + """ + def __init__(self, train_metrics): - self.train_metrics = train_metrics + self.train_metrics = train_metrics or [] # order to be called among all callbacks # metrics need to be calculated before other callbacks can access them - self.rank = 1 + self.priority = -np.Inf def epoch_begin(self, estimator, *args, **kwargs): for metric in self.train_metrics: @@ -80,6 +94,29 @@ def batch_end(self, estimator, *args, **kwargs): class ValidationHandler(BatchEnd, EpochEnd): + """"Validation Handler that evaluate model on validation dataset + + :py:class:`ValidationHandler` takes validation dataset, an evaluation function, + metrics to be evaluated, and how often to run the validation. You can provide custom + evaluation function or use the one provided my :py:class:`Estimator` + + Parameters + ---------- + val_data : DataLoader + validation data set to run evaluation + eval_fn : function + a function defines how to run evaluation and + calculate loss and metrics + val_metrics : List of EvalMetrics + validation metrics to be updated + epoch_period : int, default 1 + how often to run validation at epoch end, by default + validate every epoch + batch_period : int, default None + how often to run validation at batch end, by default + does not validate at batch end + """ + def __init__(self, val_data, eval_fn, @@ -95,7 +132,7 @@ def __init__(self, self.num_epochs = 0 # order to be called among all callbacks # validation metrics need to be calculated before other callbacks can access them - self.rank = 1 + self.priority = -np.Inf def batch_end(self, estimator, *args, **kwargs): if self.batch_period and self.num_batches % self.batch_period == 0: @@ -121,12 +158,16 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat ---------- file_name : str file name to save the logs - file_location: str + file_location : str file location to save the logs - verbose: int, default LOG_VERBOSITY_PER_EPOCH + verbose : int, default LOG_VERBOSITY_PER_EPOCH Limit the granularity of metrics displayed during training process verbose=LOG_VERBOSITY_PER_EPOCH: display metrics every epoch verbose=LOG_VERBOSITY_PER_BATCH: display metrics every batch + train_metrics : list of EvalMetrics + training metrics to be logged, logged at batch end, epoch end, train end + val_metrics : list of EvalMetrics + validation metrics to be logged, logged at epoch end, train end """ LOG_VERBOSITY_PER_EPOCH = 1 @@ -159,6 +200,9 @@ def __init__(self, file_name=None, self.batch_index = 0 self.current_epoch = 0 self.processed_samples = 0 + # logging handler need to be called at last to make sure all states are updated + # it will also shut down logging at train end + self.priority = np.Inf def train_begin(self, estimator, *args, **kwargs): self.train_start = time.time() @@ -178,6 +222,10 @@ def train_end(self, estimator, *args, **kwargs): name, value = metric.get() msg += '%s : %.4f ' % (name, value) self.logger.info(msg) + for handler in self.logger.handlers: + handler.close() + self.logger.removeHandler(handler) + logging.shutdown() def batch_begin(self, estimator, *args, **kwargs): if self.verbose == self.LOG_VERBOSITY_PER_BATCH: From 57c9b0af23cc186524468c429ad803c40157a98c Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Thu, 18 Apr 2019 10:05:58 -0700 Subject: [PATCH 08/10] fix pylint --- python/mxnet/gluon/contrib/estimator/estimator.py | 2 +- python/mxnet/gluon/contrib/estimator/event_handler.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index c8b4072528e5..2e2cdac3b49a 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -16,7 +16,7 @@ # under the License. # coding: utf-8 -# pylint: disable=wildcard-import +# pylint: disable=wildcard-import, unused-variable """Gluon Estimator""" import copy diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 8982ab2cdb22..220aa31ab5dd 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -16,6 +16,7 @@ # under the License. # coding: utf-8 +# pylint: disable=wildcard-import, unused-argument """Gluon EventHandlers for Estimators""" import logging From 82c094deced2d4dc95b40e1cf2e42dc0d6eb30bd Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Thu, 18 Apr 2019 13:48:26 -0700 Subject: [PATCH 09/10] move default metric logic --- python/mxnet/gluon/contrib/estimator/estimator.py | 7 +++---- tests/python/unittest/test_gluon_estimator.py | 1 + 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 2e2cdac3b49a..78672d2f381a 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -63,10 +63,6 @@ def __init__(self, net, self.loss = self._check_loss(loss) self.train_metrics = self._check_metrics(metrics) - # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss() - if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]): - self.train_metrics = [Accuracy()] - self.context = self._check_context(context) self._initialize(initializer) self.trainer = self._check_trainer(trainer) @@ -164,6 +160,9 @@ def prepare_loss_and_metrics(self): """ if any(not hasattr(self, attribute) for attribute in ['train_metrics', 'val_metrics']): + # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss() + if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]): + self.train_metrics = [Accuracy()] self.val_metrics = [] for loss in self.loss: self.train_metrics.append(Loss("Train " + ''.join([i for i in loss.name if not i.isdigit()]))) diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index 379f029c2c72..6f19f435531b 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -221,6 +221,7 @@ def test_metric(): loss=loss, trainer=trainer, context=ctx) + est.prepare_loss_and_metrics() assert isinstance(est.train_metrics[0], mx.metric.Accuracy) From 119c60aa20af164cb339413ad5addfee974c23f8 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Thu, 18 Apr 2019 16:40:06 -0700 Subject: [PATCH 10/10] combine nightly tests --- ci/docker/runtime_functions.sh | 18 +++--------------- tests/nightly/Jenkinsfile | 12 ++++++------ tests/nightly/JenkinsfileForBinaries | 16 ---------------- 3 files changed, 9 insertions(+), 37 deletions(-) diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index 59ff22155ac8..b194ebb15b50 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -1296,31 +1296,19 @@ nightly_scala_demo_test_cpu() { bash bin/run_im.sh } -nightly_estimator_cnn_gpu() { +nightly_estimator_gpu() { set -ex cd /work/mxnet/tests/nightly/estimator export PYTHONPATH=/work/mxnet/python/ python test_estimator_cnn.py --type gpu -} - -nightly_estimator_cnn_cpu() { - set -ex - cd /work/mxnet/tests/nightly/estimator - export PYTHONPATH=/work/mxnet/python/ - python test_estimator_cnn.py --type cpu -} - -nightly_estimator_rnn_gpu() { - set -ex - cd /work/mxnet/tests/nightly/estimator - export PYTHONPATH=/work/mxnet/python/ python test_sentiment_rnn.py --type gpu } -nightly_estimator_rnn_cpu() { +nightly_estimator_cpu() { set -ex cd /work/mxnet/tests/nightly/estimator export PYTHONPATH=/work/mxnet/python/ + python test_estimator_cnn.py --type cpu python test_sentiment_rnn.py --type cpu } diff --git a/tests/nightly/Jenkinsfile b/tests/nightly/Jenkinsfile index a65da2d0b87e..1be084c9d3f5 100755 --- a/tests/nightly/Jenkinsfile +++ b/tests/nightly/Jenkinsfile @@ -137,19 +137,19 @@ core_logic: { } } }, - 'estimator: RNN GPU': { + 'Gluon estimator: GPU': { node(NODE_LINUX_GPU) { - ws('workspace/estimator-test-rnn-gpu') { + ws('workspace/estimator-test-gpu') { utils.unpack_and_init('gpu', mx_lib) - utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator_test_rnn_gpu', true) + utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator_gpu', true) } } }, - 'estimator: RNN CPU': { + 'Gluon estimator: CPU': { node(NODE_LINUX_CPU) { - ws('workspace/estimator-test-rnn-cpu') { + ws('workspace/estimator-test-cpu') { utils.unpack_and_init('cpu', mx_lib) - utils.docker_run('ubuntu_nightly_cpu', 'nightly_estimator_test_rnn_cpu', false) + utils.docker_run('ubuntu_nightly_cpu', 'nightly_estimator_cpu', false) } } } diff --git a/tests/nightly/JenkinsfileForBinaries b/tests/nightly/JenkinsfileForBinaries index 53572c85f513..53e1c30e188f 100755 --- a/tests/nightly/JenkinsfileForBinaries +++ b/tests/nightly/JenkinsfileForBinaries @@ -106,22 +106,6 @@ core_logic: { utils.docker_run('ubuntu_nightly_gpu', 'nightly_tutorial_test_ubuntu_python3_gpu', true, '1500m') } } - }, - 'estimator: CNN GPU': { - node(NODE_LINUX_GPU) { - ws('workspace/estimator-test-cnn-gpu') { - utils.unpack_and_init('gpu', mx_lib) - utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator_test_cnn_gpu', true) - } - } - }, - 'estimator: CNN CPU': { - node(NODE_LINUX_CPU) { - ws('workspace/estimator-test-cnn-cpu') { - utils.unpack_and_init('cpu', mx_lib) - utils.docker_run('ubuntu_nightly_cpu', 'nightly_estimator_test_cnn_cpu', true) - } - } } } }