diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index e1da222ca298..58e39efc2873 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -1350,6 +1350,16 @@ nightly_scala_demo_test_cpu() { bash bin/run_im.sh } +nightly_estimator() { + set -ex + cd /work/mxnet/tests/nightly/estimator + export PYTHONPATH=/work/mxnet/python/ + python test_estimator_cnn.py --type gpu + python test_sentiment_rnn.py --type gpu + python test_estimator_cnn.py --type cpu + python test_sentiment_rnn.py --type cpu +} + # Deploy deploy_docs() { diff --git a/python/mxnet/gluon/contrib/estimator/__init__.py b/python/mxnet/gluon/contrib/estimator/__init__.py new file mode 100644 index 000000000000..58600dadffb4 --- /dev/null +++ b/python/mxnet/gluon/contrib/estimator/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=wildcard-import +"""Gluon Estimator Module""" +from .estimator import * +from .event_handler import * diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py new file mode 100644 index 000000000000..da1a3915caec --- /dev/null +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -0,0 +1,408 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import, unused-variable +"""Gluon Estimator""" + +import copy +import warnings + +from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler +from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd +from .... import gluon, autograd +from ....context import Context, cpu, gpu, num_gpus +from ....metric import EvalMetric, Loss, Accuracy + +__all__ = ['Estimator'] + + +class Estimator(object): + """Estimator Class for easy model training + + :py:class:`Estimator` can be used to facilitate the training & validation process + + + Parameters + ---------- + net : Block + The model used for training. 
+ loss : gluon.loss.Loss or list of gluon.loss.Loss + Loss(objective functions) to calculate during training. + metrics : EvalMetric or list of EvalMetric + Metrics for evaluating models. + initializer : Initializer + Initializer to initialize the network. + trainer : Trainer + Trainer to apply optimizer on network parameters. + context : Context or list of Context + Device(s) to run the training on. + """ + + def __init__(self, net, + loss, + metrics=None, + initializer=None, + trainer=None, + context=None): + + self.net = net + self.loss = self._check_loss(loss) + self.train_metrics = self._check_metrics(metrics) + + self.context = self._check_context(context) + self._initialize(initializer) + self.trainer = self._check_trainer(trainer) + + def _check_loss(self, loss): + if isinstance(loss, gluon.loss.Loss): + loss = [loss] + elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]): + loss = loss + else: + raise ValueError("loss must be a Loss or a list of Loss, " + "refer to gluon.loss.Loss:{}".format(loss)) + return loss + + def _check_metrics(self, metrics): + if isinstance(metrics, EvalMetric): + metrics = [metrics] + else: + metrics = metrics or [] + if not all([isinstance(metric, EvalMetric) for metric in metrics]): + raise ValueError("metrics must be a Metric or a list of Metric, " + "refer to mxnet.metric.EvalMetric:{}".format(metrics)) + return metrics + + def _check_context(self, context): + # infer available context + gpus = num_gpus() + available_gpus = [gpu(i) for i in range(gpus)] + + if context: + # check context values, only accept Context or a list of Context + if isinstance(context, Context): + context = [context] + elif isinstance(context, list) and all([isinstance(c, Context) for c in context]): + context = context + else: + raise ValueError("context must be a Context or a list of Context, " + "for example mx.cpu() or [mx.gpu(0), mx.gpu(1)], " + "refer to mxnet.Context:{}".format(context)) + for ctx in context: + assert ctx in available_gpus or str(ctx).startswith('cpu'), \ + "%s is not available, please make sure " \ + "your context is in one of: mx.cpu(), %s" % \ + (ctx, ", ".join([str(ctx) for ctx in available_gpus])) + else: + # provide default context + if gpus > 0: + # only use 1 GPU by default + if gpus > 1: + warnings.warn("You have multiple GPUs, gpu(0) will be used by default." + "To utilize all your GPUs, specify context as a list of gpus, " + "e.g. context=[mx.gpu(0), mx.gpu(1)] ") + context = [gpu(0)] + else: + context = [cpu()] + return context + + def _initialize(self, initializer): + # initialize the network + if not self._is_initialized(): + # net is partially or not initialized, + # initialize with user specified initializer + # if initializer is None, default initializer will be used + # do not re-init layers already initialized + if initializer: + self.net.initialize(init=initializer, ctx=self.context) + else: + self.net.initialize(ctx=self.context) + elif initializer: + # net is fully initialized, and user passed not None initializer + # do not force reinitialize, give warning + warnings.warn("Network already fully initialized, skipping initialization. " + "You don't need to pass initializer if you already " + "initialized your net. 
" + "You can use net.initialize(init=your_initializer, force_reinit=True)" + "to force re-initialize.") + + def _check_trainer(self, trainer): + # handle trainer + if not trainer: + warnings.warn("No trainer specified, default SGD optimizer " + "with learning rate 0.001 is used.") + trainer = gluon.Trainer(self.net.collect_params(), + 'sgd', {'learning_rate': 0.001}) + elif not isinstance(trainer, gluon.Trainer): + raise ValueError("Trainer must be a Gluon Trainer instance, refer to " + "gluon.Trainer:{}".format(trainer)) + return trainer + + def _is_initialized(self): + param_dict = self.net.collect_params() + for param in param_dict: + try: + param_dict[param].list_ctx() + except RuntimeError: + return False + return True + + def _get_data_and_label(self, batch, ctx, batch_axis=0): + data = batch[0] + label = batch[1] + data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=batch_axis) + label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=batch_axis) + return data, label + + def prepare_loss_and_metrics(self): + """ + Based on loss functions and training metrics in estimator + Create metric wrappers to record loss values, + Create copies of train loss/metric objects to record validation values + Returns train_metrics and val_metrics + + """ + if any(not hasattr(self, attribute) for attribute in + ['train_metrics', 'val_metrics']): + # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss() + if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]): + self.train_metrics = [Accuracy()] + self.val_metrics = [] + for loss in self.loss: + # remove trailing numbers from loss name to avoid confusion + self.train_metrics.append(Loss(loss.name.rstrip('1234567890'))) + for metric in self.train_metrics: + val_metric = copy.deepcopy(metric) + metric.name = "train " + metric.name + val_metric.name = "validation " + val_metric.name + self.val_metrics.append(val_metric) + return self.train_metrics, self.val_metrics + + def evaluate(self, + val_data, + val_metrics, + batch_axis=0): + """Evaluate model on validation data + + Parameters + ---------- + val_data : DataLoader + Validation data loader with data and labels. + val_metrics : EvalMetric or list of EvalMetrics + Metrics to update validation result. + batch_axis : int, default 0 + Batch axis to split the validation data into devices. + """ + if not isinstance(val_data, gluon.data.DataLoader): + raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " + "can transform your DataIter or any NDArray into Gluon DataLoader. " + "Refer to gluon.data.dataloader") + + for metric in val_metrics: + metric.reset() + + for _, batch in enumerate(val_data): + data, label = self._get_data_and_label(batch, self.context, batch_axis) + pred = [self.net(x) for x in data] + loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)] + # update metrics + for metric in val_metrics: + if isinstance(metric, Loss): + metric.update(0, loss) + else: + metric.update(label, pred) + + def fit(self, train_data, + val_data=None, + epochs=None, + event_handlers=None, + batches=None, + batch_axis=0): + """Trains the model with a given :py:class:`DataLoader` for a specified + number of epochs or batches. The batch size is inferred from the + data loader's batch_size. + + Parameters + ---------- + train_data : DataLoader + Training data loader with data and labels. + val_data : DataLoader, default None + Validation data loader with data and labels. 
+ epochs : int, default None + Number of epochs to iterate on the training data. + You can only specify one and only one type of iteration(epochs or batches). + event_handlers : EventHandler or list of EventHandler + List of :py:class:`EventHandlers` to apply during training. + batches : int, default None + Number of batches to iterate on the training data. + You can only specify one and only one type of iteration(epochs or batches). + batch_axis : int, default 0 + Batch axis to split the training data into devices. + """ + if not isinstance(train_data, gluon.data.DataLoader): + raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " + "can transform your DataIter or any NDArray into Gluon DataLoader. " + "Refer to gluon.data.dataloader") + + # must specify one and only one of epochs or batches + if (not epochs) == (not batches): + raise ValueError( + "Fit only support exactly one type of iteration, " + "train by number of epochs or number of batches." + "Please specify one and only one of: epochs or batches.") + + self.max_epoch = epochs + self.max_batch = batches + + # provide default handlers + event_handlers = self._prepare_default_handlers(val_data, event_handlers) + + train_begin, epoch_begin, batch_begin, \ + batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers) + + # pass a reference to all event handlers + estimator_ref = self + # training begin + for handler in train_begin: + handler.train_begin(estimator_ref) + + while True: + # epoch begin + for handler in epoch_begin: + handler.epoch_begin(estimator_ref) + + for i, batch in enumerate(train_data): + data, label = self._get_data_and_label(batch, self.context, batch_axis) + + batch_size = batch[0].shape[0] + + # batch begin + for handler in batch_begin: + handler.batch_begin(estimator_ref, batch=batch) + + with autograd.record(): + pred = [self.net(x) for x in data] + loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)] + + for l in loss: + l.backward() + + self.trainer.step(batch_size) + # batch end + + batch_end_result = [] + for handler in batch_end: + batch_end_result.append(handler.batch_end(estimator_ref, batch=batch, + pred=pred, label=label, loss=loss)) + # if any handler signaled to stop + if any(batch_end_result): + break + + # epoch end + epoch_end_result = [] + for handler in epoch_end: + epoch_end_result.append(handler.epoch_end(estimator_ref)) + # if any handler signaled to stop + if any(epoch_end_result): + break + + # train end + for handler in train_end: + handler.train_end(estimator_ref) + + def _prepare_default_handlers(self, val_data, event_handlers): + event_handlers = event_handlers or [] + default_handlers = [] + train_metrics, val_metrics = self.prepare_loss_and_metrics() + + # no need to add to default handler check as StoppingHandler does not use metrics + event_handlers.append(StoppingHandler(self.max_epoch, self.max_batch)) + + if not any(isinstance(handler, MetricHandler) for handler in event_handlers): + event_handlers.append(MetricHandler(train_metrics=train_metrics)) + default_handlers.append("MetricHandler") + + if val_data and not any(isinstance(handler, ValidationHandler) for handler in event_handlers): + event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate, + val_metrics=val_metrics)) + default_handlers.append("ValidationHandler") + + if not any(isinstance(handler, LoggingHandler) for handler in event_handlers): + event_handlers.append(LoggingHandler(train_metrics=train_metrics, + 
val_metrics=val_metrics))
+            default_handlers.append("LoggingHandler")
+
+        # if there is a mix of user-defined event handlers and default event handlers,
+        # they should have the same set of loss and metrics
+        if default_handlers:
+            msg = "You are training with the following default event handlers: %s. " \
+                  "They use loss and metrics from estimator.prepare_loss_and_metrics(). " \
+                  "Please use the same set of metrics for all your other handlers." % \
+                  ", ".join(default_handlers)
+            warnings.warn(msg)
+            # check if all handlers have the same set of references to loss and metrics
+            references = []
+            for handler in event_handlers:
+                for attribute in dir(handler):
+                    if any(keyword in attribute for keyword in ['metric', 'monitor']):
+                        reference = getattr(handler, attribute)
+                        if isinstance(reference, list):
+                            references += reference
+                        else:
+                            references.append(reference)
+            # remove None metric references
+            references = set([ref for ref in references if ref])
+            for metric in references:
+                if metric not in train_metrics + val_metrics:
+                    msg = "We have added the following default handlers for you: %s and used " \
+                          "estimator.prepare_loss_and_metrics() to pass metrics to " \
+                          "those handlers. Please use the same set of metrics " \
+                          "for all your handlers." % \
+                          ", ".join(default_handlers)
+                    raise ValueError(msg)
+
+        event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
+        return event_handlers
+
+    def _categorize_handlers(self, event_handlers):
+        """
+        Categorize handlers into 6 event lists to avoid calling empty methods.
+        For example, only event handlers implementing train_begin
+        will be called at train begin.
+        """
+
+        train_begin = []
+        epoch_begin = []
+        batch_begin = []
+        batch_end = []
+        epoch_end = []
+        train_end = []
+        for handler in event_handlers:
+            if isinstance(handler, TrainBegin):
+                train_begin.append(handler)
+            if isinstance(handler, EpochBegin):
+                epoch_begin.append(handler)
+            if isinstance(handler, BatchBegin):
+                batch_begin.append(handler)
+            if isinstance(handler, BatchEnd):
+                batch_end.append(handler)
+            if isinstance(handler, EpochEnd):
+                epoch_end.append(handler)
+            if isinstance(handler, TrainEnd):
+                train_end.append(handler)
+        return train_begin, epoch_begin, batch_begin, batch_end, epoch_end, train_end
diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py
new file mode 100644
index 000000000000..ce5890e0bcae
--- /dev/null
+++ b/python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -0,0 +1,705 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
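With estimator.py complete, the public API above can be exercised end to end. The following is a minimal usage sketch, not part of the diff; the toy network, synthetic data, and hyperparameters are illustrative assumptions:

```python
# Minimal usage sketch of the Estimator API added above (illustrative only;
# the toy network, synthetic data, and hyperparameters are not from the PR).
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.contrib.estimator import Estimator

net = gluon.nn.Sequential()
net.add(gluon.nn.Dense(4))
net.initialize(mx.init.Xavier(), ctx=mx.cpu())

# fit() only accepts a Gluon DataLoader, so wrap the arrays in a dataset
dataset = gluon.data.ArrayDataset(mx.nd.random.uniform(shape=(16, 8)),
                                  mx.nd.zeros(shape=(16,)))
train_data = gluon.data.DataLoader(dataset, batch_size=4)

trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
est = Estimator(net=net,
                loss=gluon.loss.SoftmaxCrossEntropyLoss(),
                metrics=mx.metric.Accuracy(),
                trainer=trainer,
                context=mx.cpu())
# exactly one of epochs/batches may be passed; StoppingHandler, MetricHandler,
# and LoggingHandler are appended automatically by _prepare_default_handlers()
est.fit(train_data=train_data, epochs=1)
```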
+ +# coding: utf-8 +# pylint: disable=wildcard-import, unused-argument +"""Gluon EventHandlers for Estimators""" + +import logging +import os +import time +import warnings + +import numpy as np + +from ....metric import EvalMetric, Loss + + +class TrainBegin(object): + def train_begin(self, estimator, *args, **kwargs): + pass + + +class TrainEnd(object): + def train_end(self, estimator, *args, **kwargs): + pass + + +class EpochBegin(object): + def epoch_begin(self, estimator, *args, **kwargs): + pass + + +class EpochEnd(object): + def epoch_end(self, estimator, *args, **kwargs): + return False + + +class BatchBegin(object): + def batch_begin(self, estimator, *args, **kwargs): + pass + + +class BatchEnd(object): + def batch_end(self, estimator, *args, **kwargs): + return False + + +class StoppingHandler(TrainBegin, BatchEnd, EpochEnd): + """Stop conditions to stop training + Stop training if maximum number of batches or epochs + reached. + + Parameters + ---------- + max_epoch : int, default None + Number of maximum epochs to train. + max_batch : int, default None + Number of maximum batches to train. + + """ + + def __init__(self, max_epoch=None, max_batch=None): + self.max_epoch = max_epoch + self.max_batch = max_batch + self.current_batch = 0 + self.current_epoch = 0 + self.stop_training = False + + def train_begin(self, estimator, *args, **kwargs): + self.max_epoch = estimator.max_epoch + self.max_batch = estimator.max_batch + self.current_batch = 0 + self.current_epoch = 0 + + def batch_end(self, estimator, *args, **kwargs): + self.current_batch += 1 + if self.current_batch == self.max_batch: + self.stop_training = True + return self.stop_training + + def epoch_end(self, estimator, *args, **kwargs): + self.current_epoch += 1 + if self.current_epoch == self.max_epoch: + self.stop_training = True + return self.stop_training + + +class MetricHandler(EpochBegin, BatchEnd): + """Metric Handler that update metric values at batch end + + :py:class:`MetricHandler` takes model predictions and true labels + and update the metrics, it also update metric wrapper for loss with loss values. + Validation loss and metrics will be handled by :py:class:`ValidationHandler` + + Parameters + ---------- + train_metrics : List of EvalMetrics + Training metrics to be updated at batch end. + """ + + def __init__(self, train_metrics): + self.train_metrics = train_metrics or [] + # order to be called among all callbacks + # metrics need to be calculated before other callbacks can access them + self.priority = -np.Inf + + def epoch_begin(self, estimator, *args, **kwargs): + for metric in self.train_metrics: + metric.reset() + + def batch_end(self, estimator, *args, **kwargs): + pred = kwargs['pred'] + label = kwargs['label'] + loss = kwargs['loss'] + for metric in self.train_metrics: + if isinstance(metric, Loss): + # metric wrapper for loss values + metric.update(0, loss) + else: + metric.update(label, pred) + + +class ValidationHandler(TrainBegin, BatchEnd, EpochEnd): + """"Validation Handler that evaluate model on validation dataset + + :py:class:`ValidationHandler` takes validation dataset, an evaluation function, + metrics to be evaluated, and how often to run the validation. You can provide custom + evaluation function or use the one provided my :py:class:`Estimator` + + Parameters + ---------- + val_data : DataLoader + Validation data set to run evaluation. + eval_fn : function + A function defines how to run evaluation and + calculate loss and metrics. 
+ val_metrics : List of EvalMetrics + Validation metrics to be updated. + epoch_period : int, default 1 + How often to run validation at epoch end, by default + :py:class:`ValidationHandler` validate every epoch. + batch_period : int, default None + How often to run validation at batch end, by default + :py:class:`ValidationHandler` does not validate at batch end. + """ + + def __init__(self, + val_data, + eval_fn, + val_metrics=None, + epoch_period=1, + batch_period=None): + self.val_data = val_data + self.eval_fn = eval_fn + self.epoch_period = epoch_period + self.batch_period = batch_period + self.val_metrics = val_metrics + self.current_batch = 0 + self.current_epoch = 0 + # order to be called among all callbacks + # validation metrics need to be calculated before other callbacks can access them + self.priority = -np.Inf + self.logger = logging.getLogger(__name__) + + def train_begin(self, estimator, *args, **kwargs): + # reset epoch and batch counter + self.current_batch = 0 + self.current_epoch = 0 + + def batch_end(self, estimator, *args, **kwargs): + self.current_batch += 1 + if self.batch_period and self.current_batch % self.batch_period == 0: + self.eval_fn(val_data=self.val_data, + val_metrics=self.val_metrics) + msg = '[Epoch %d] ValidationHandler: %d batches reached, ' \ + % (self.current_epoch, self.current_batch) + for monitor in self.val_metrics: + name, value = monitor.get() + msg += '%s: %.4f, ' % (name, value) + self.logger.info(msg.rstrip(',')) + + def epoch_end(self, estimator, *args, **kwargs): + self.current_epoch += 1 + if self.epoch_period and self.current_epoch % self.epoch_period == 0: + self.eval_fn(val_data=self.val_data, + val_metrics=self.val_metrics) + + +class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd): + """Basic Logging Handler that applies to every Gluon estimator by default. + + :py:class:`LoggingHandler` logs hyper-parameters, training statistics, + and other useful information during training + + Parameters + ---------- + file_name : str + File name to save the logs. + file_location : str + File location to save the logs. + filemode : str, default 'a' + Logging file mode, default using append mode. + verbose : int, default LOG_PER_EPOCH + Limit the granularity of metrics displayed during training process. + verbose=LOG_PER_EPOCH: display metrics every epoch + verbose=LOG_PER_BATCH: display metrics every batch + train_metrics : list of EvalMetrics + Training metrics to be logged, logged at batch end, epoch end, train end. + val_metrics : list of EvalMetrics + Validation metrics to be logged, logged at epoch end, train end. + """ + + LOG_PER_EPOCH = 1 + LOG_PER_BATCH = 2 + + def __init__(self, file_name=None, + file_location=None, + filemode='a', + verbose=LOG_PER_EPOCH, + train_metrics=None, + val_metrics=None): + super(LoggingHandler, self).__init__() + self.logger = logging.getLogger(__name__) + self.logger.setLevel(logging.INFO) + stream_handler = logging.StreamHandler() + self.logger.addHandler(stream_handler) + # save logger to file only if file name or location is specified + if file_name or file_location: + file_name = file_name or 'estimator_log' + file_location = file_location or './' + file_handler = logging.FileHandler(os.path.join(file_location, file_name), mode=filemode) + self.logger.addHandler(file_handler) + if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]: + raise ValueError("verbose level must be either LOG_PER_EPOCH or " + "LOG_PER_BATCH, received %s. 
" + "E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)" + % verbose) + self.verbose = verbose + self.train_metrics = train_metrics or [] + self.val_metrics = val_metrics or [] + self.batch_index = 0 + self.current_epoch = 0 + self.processed_samples = 0 + # logging handler need to be called at last to make sure all states are updated + # it will also shut down logging at train end + self.priority = np.Inf + + def train_begin(self, estimator, *args, **kwargs): + self.train_start = time.time() + trainer = estimator.trainer + optimizer = trainer.optimizer.__class__.__name__ + lr = trainer.learning_rate + self.logger.info("Training begin: using optimizer %s " + "with current learning rate %.4f ", + optimizer, lr) + if estimator.max_epoch: + self.logger.info("Train for %d epochs.", estimator.max_epoch) + else: + self.logger.info("Train for %d batches.", estimator.max_batch) + # reset all counters + self.current_epoch = 0 + self.batch_index = 0 + self.processed_samples = 0 + + def train_end(self, estimator, *args, **kwargs): + train_time = time.time() - self.train_start + msg = 'Train finished using total %ds with %d epochs. ' % (train_time, self.current_epoch) + # log every result in train stats including train/validation loss & metrics + for metric in self.train_metrics + self.val_metrics: + name, value = metric.get() + msg += '%s: %.4f, ' % (name, value) + self.logger.info(msg.rstrip(', ')) + # make a copy of handler list and remove one by one + # as removing handler will edit the handler list + for handler in self.logger.handlers[:]: + handler.close() + self.logger.removeHandler(handler) + logging.shutdown() + + def batch_begin(self, estimator, *args, **kwargs): + if self.verbose == self.LOG_PER_BATCH: + self.batch_start = time.time() + + def batch_end(self, estimator, *args, **kwargs): + if self.verbose == self.LOG_PER_BATCH: + batch_time = time.time() - self.batch_start + msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index) + self.processed_samples += kwargs['batch'][0].shape[0] + msg += '[Samples %s] ' % (self.processed_samples) + msg += 'time/batch: %.3fs ' % batch_time + for metric in self.train_metrics: + # only log current training loss & metric after each batch + name, value = metric.get() + msg += '%s: %.4f, ' % (name, value) + self.logger.info(msg.rstrip(', ')) + self.batch_index += 1 + + def epoch_begin(self, estimator, *args, **kwargs): + if self.verbose >= self.LOG_PER_EPOCH: + self.epoch_start = time.time() + self.logger.info("[Epoch %d] Begin, current learning rate: %.4f", + self.current_epoch, estimator.trainer.learning_rate) + + def epoch_end(self, estimator, *args, **kwargs): + if self.verbose >= self.LOG_PER_EPOCH: + epoch_time = time.time() - self.epoch_start + msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time) + for monitor in self.train_metrics + self.val_metrics: + name, value = monitor.get() + msg += '%s: %.4f, ' % (name, value) + self.logger.info(msg.rstrip(', ')) + self.current_epoch += 1 + self.batch_index = 0 + + +class CheckpointHandler(TrainBegin, BatchEnd, EpochEnd): + """Save the model after user define period + + :py:class:`CheckpointHandler` saves the network architecture after first batch if the model + can be fully hybridized, saves model parameters and trainer states after user defined period, + default saves every epoch. + + Parameters + ---------- + model_dir : str + File directory to save all the model related files including model architecture, + model parameters, and trainer states. 
+ model_prefix : str default 'model' + Prefix to add for all checkpoint file names. + monitor: EvalMetric, default None + The metrics to monitor and determine if model has improved + verbose: int, default 0 + Verbosity mode, 1 means inform user every time a checkpoint is saved + save_best: bool, default False + If True, monitor must not be None, :py:class:`CheckpointHandler` will save the + model parameters and trainer states with the best monitored value. + mode: str, default 'auto' + One of {auto, min, max}, if `save_best=True`, the comparison to make + and determine if the monitored value has improved. if 'auto' mode, + :py:class:`CheckpointHandler` will try to use min or max based on + the monitored metric name. + epoch_period: int, default 1 + Epoch intervals between saving the network. By default, checkpoints are + saved every epoch. + batch_period: int, default None + Batch intervals between saving the network. + By default, checkpoints are not saved based on the number of batches. + max_checkpoints : int, default 5 + Maximum number of checkpoint files to keep in the model_dir, older checkpoints + will be removed. Best checkpoint file is not counted. + resume_from_checkpoint : bool, default False + Whether to resume training from checkpoint in model_dir. If True and checkpoints + found, :py:class:`CheckpointHandler` will load net parameters and trainer states, + and train the remaining of epochs and batches. + """ + + def __init__(self, + model_dir, + model_prefix='model', + monitor=None, + verbose=0, + save_best=False, + mode='auto', + epoch_period=1, + batch_period=None, + max_checkpoints=5, + resume_from_checkpoint=False): + self.monitor = monitor + self.verbose = verbose + if not os.path.exists(model_dir): + os.makedirs(model_dir) + self.model_dir = model_dir + self.model_prefix = model_prefix + self.save_best = save_best + if self.save_best and not isinstance(self.monitor, EvalMetric): + raise ValueError("To save best model only, please provide one of the metric objects as monitor, " + "You can get these objects using estimator.prepare_loss_and_metric()") + self.epoch_period = epoch_period + self.batch_period = batch_period + self.current_batch = 0 + self.current_epoch = 0 + self.max_checkpoints = max_checkpoints + self.resume_from_checkpoint = resume_from_checkpoint + self.saved_checkpoints = [] + self.logger = logging.getLogger(__name__) + if self.save_best: + if mode not in ['auto', 'min', 'max']: + warnings.warn('ModelCheckpoint mode %s is unknown, ' + 'fallback to auto mode. 
CheckpointHandler will use '
+                              'max mode for f1 and accuracy metric comparison and '
+                              'use min mode otherwise' % (mode),
+                              RuntimeWarning)
+                mode = 'auto'
+
+            if mode == 'min':
+                self.monitor_op = np.less
+                self.best = np.Inf
+            elif mode == 'max':
+                self.monitor_op = np.greater
+                self.best = -np.Inf
+            else:
+                # use greater for accuracy and f1 and less otherwise
+                if 'acc' in self.monitor.get()[0].lower() or 'f1' in self.monitor.get()[0].lower():
+                    self.logger.info("`greater` operator will be used to determine "
+                                     "if %s has improved, please use `min` for mode "
+                                     "if you want otherwise", self.monitor.get()[0])
+                    self.monitor_op = np.greater
+                else:
+                    self.logger.info("`less` operator will be used to determine "
+                                     "if %s has improved, please use `max` for mode "
+                                     "if you want otherwise", self.monitor.get()[0])
+                    self.monitor_op = np.less
+
+    def train_begin(self, estimator, *args, **kwargs):
+        # reset all counters
+        self.current_epoch = 0
+        self.current_batch = 0
+        if self.save_best:
+            self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+        if self.resume_from_checkpoint:
+            error_msg = "To use resume from checkpoint, you must only specify " \
+                        "the same type of period you used for training. " \
+                        "For example, if you are training based on number of epochs, " \
+                        "you must save only based on epochs, and set batch_period to None."
+            if estimator.max_batch:
+                assert self.batch_period, error_msg
+                assert not self.epoch_period, error_msg
+            if estimator.max_epoch:
+                assert self.epoch_period, error_msg
+                assert not self.batch_period, error_msg
+
+            self._resume_from_checkpoint(estimator)
+
+    def batch_end(self, estimator, *args, **kwargs):
+        # only save symbol once after first batch
+        if self.current_batch == 0:
+            self._save_symbol(estimator)
+        if self.batch_period and (self.current_batch + 1) % self.batch_period == 0:
+            self._save_checkpoint(estimator)
+        self.current_batch += 1
+
+    def epoch_end(self, estimator, *args, **kwargs):
+        if self.epoch_period and (self.current_epoch + 1) % self.epoch_period == 0:
+            self._save_checkpoint(estimator)
+        self.current_epoch += 1
+
+    def _save_checkpoint(self, estimator):
+        # if resumed from checkpoint, increment checkpoint number
+        if self.resume_from_checkpoint:
+            save_epoch_number = self.current_epoch + self.trained_epoch + 1
+            if estimator.max_epoch:
+                # checkpoint saved at epoch end, batch number already incremented
+                save_batch_number = self.current_batch + self.trained_batch
+            else:
+                save_batch_number = self.current_batch + self.trained_batch + 1
+        else:
+            save_epoch_number = self.current_epoch
+            save_batch_number = self.current_batch
+        prefix = "%s-epoch%dbatch%d" % (self.model_prefix, save_epoch_number, save_batch_number)
+        self._save_params_and_trainer(estimator, prefix)
+        if self.verbose > 0:
+            self.logger.info('[Epoch %d] CheckpointHandler: trained total %d batches, '
+                             'saving model at %s with prefix: %s',
+                             self.current_epoch, self.current_batch + 1, self.model_dir, prefix)
+
+        if self.save_best:
+            monitor_name, monitor_value = self.monitor.get()
+            # check if monitor exists in train stats
+            if np.isnan(monitor_value):
+                warnings.warn(RuntimeWarning('Skipping save best because %s is not updated, make sure you '
+                                             'pass one of the metric objects as monitor, '
+                                             'you can use estimator.prepare_loss_and_metrics to '
+                                             'create all metric objects' % monitor_name))
+            else:
+                if self.monitor_op(monitor_value, self.best):
+                    prefix = self.model_prefix + '-best'
+                    self._save_params_and_trainer(estimator, prefix)
+                    if self.verbose > 0:
+                        self.logger.info('[Epoch %d] CheckpointHandler: '
+                                         '%s improved from %0.5f to %0.5f, '
+                                         'updating best model at %s with prefix: %s',
+                                         self.current_epoch, monitor_name,
+                                         self.best, monitor_value, self.model_dir, prefix)
+                    # update best only after logging so the old value is reported
+                    self.best = monitor_value
+                else:
+                    if self.verbose > 0:
+                        self.logger.info('[Epoch %d] CheckpointHandler: '
+                                         '%s did not improve from %0.5f, '
+                                         'skipping updating best model',
+                                         self.current_epoch, monitor_name,
+                                         self.best)
+
+    def _save_symbol(self, estimator):
+        symbol_file = os.path.join(self.model_dir, self.model_prefix + '-symbol.json')
+        if hasattr(estimator.net, '_cached_graph'):
+            sym = estimator.net._cached_graph[1]
+            sym.save(symbol_file)
+        else:
+            self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock "
+                             "to construct your model, and call net.hybridize() before passing to "
+                             "Estimator in order to save model architecture as %s.", symbol_file)
+
+    def _save_params_and_trainer(self, estimator, file_prefix):
+        param_file = os.path.join(self.model_dir, file_prefix + '.params')
+        trainer_file = os.path.join(self.model_dir, file_prefix + '.states')
+        estimator.net.save_parameters(param_file)
+        estimator.trainer.save_states(trainer_file)
+
+        # only count checkpoints with epoch or batch number in file name
+        if 'best' not in file_prefix:
+            self.saved_checkpoints.append(file_prefix)
+        # remove old checkpoint when max number of checkpoints reached
+        if len(self.saved_checkpoints) > self.max_checkpoints:
+            prefix = self.saved_checkpoints.pop(0)
+            for fname in os.listdir(self.model_dir):
+                if fname.startswith(prefix):
+                    os.remove(os.path.join(self.model_dir, fname))
+
+    def _resume_from_checkpoint(self, estimator):
+        prefix = self.model_prefix + '-epoch'
+        self.trained_epoch = self._find_max_iteration(
+            dir=self.model_dir,
+            prefix=prefix,
+            start='epoch',
+            end='batch',
+            saved_checkpoints=self.saved_checkpoints)
+        prefix += str(self.trained_epoch)
+        self.trained_batch = self._find_max_iteration(
+            dir=self.model_dir,
+            prefix=prefix,
+            start='batch',
+            end='.params')
+
+        if self.trained_epoch == -1:
+            msg = "CheckpointHandler: No checkpoint found, training from scratch for "
+            if estimator.max_batch:
+                msg += "%d batches" % estimator.max_batch
+            else:
+                msg += "%d epochs" % estimator.max_epoch
+            self.logger.info(msg)
+        else:
+            msg = "CheckpointHandler: Checkpoint resumed from epoch %d batch %d, " \
+                  "continue to train for " % (self.trained_epoch, self.trained_batch)
+            # change maximum number of epoch or batch to train if resumed from epoch checkpoint
+            if estimator.max_epoch:
+                if self.trained_epoch >= estimator.max_epoch - 1:
+                    raise ValueError("Found checkpoint with maximum number of epoch %d reached, please specify "
+                                     "resume_from_checkpoint=False (default value) if you want to train from scratch."
+                                     % estimator.max_epoch)
+                estimator.max_epoch = estimator.max_epoch - self.trained_epoch - 1
+                msg += "%d epochs " % estimator.max_epoch
+            if estimator.max_batch:
+                if self.trained_batch >= estimator.max_batch - 1:
+                    raise ValueError("Found checkpoint with maximum number of batch %d reached, please specify "
+                                     "resume_from_checkpoint=False (default value) if you want to train from scratch."
+ % self.trained_batch) + estimator.max_batch = estimator.max_batch - self.trained_batch - 1 + msg += "%d batches " % estimator.max_batch + # load checkpoint + param_file = "%s-epoch%dbatch%d.params" % (self.model_prefix, self.trained_epoch, self.trained_batch) + param_file = os.path.join(self.model_dir, param_file) + trainer_file = "%s-epoch%dbatch%d.states" % (self.model_prefix, self.trained_epoch, self.trained_batch) + trainer_file = os.path.join(self.model_dir, trainer_file) + assert os.path.exists(param_file), "Failed to load checkpoint, %s does not exist" % param_file + assert os.path.exists(trainer_file), "Failed to load checkpoint, %s does not exist" % trainer_file + estimator.net.load_parameters(param_file, ctx=estimator.context) + estimator.trainer.load_states(trainer_file) + self.logger.warning(msg) + + def _find_max_iteration(self, dir, prefix, start, end, saved_checkpoints=None): + error_msg = "Error parsing checkpoint file, please check your " \ + "checkpoints have the format: " \ + "{model_name}-epoch{epoch_number}batch{batch_number}.params, " \ + "there should also be a .states file for each .params file " + max_iter = -1 + for fname in os.listdir(dir): + if fname.startswith(prefix) and '.params' in fname: + if saved_checkpoints: + # save prefix of existing checkpoints + saved_checkpoints.append(fname[:fname.find('.params')]) + try: + # find trained number of epoch + iter = int(fname[fname.find(start) + len(start): fname.find(end)]) + if iter > max_iter: + max_iter = iter + except ValueError: + raise ValueError(error_msg) + return max_iter + + +class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd): + """Early stop training if monitored value is not improving + + Parameters + ---------- + monitor: EvalMetric + The metric to monitor, and stop training if this metric does not improve. + min_delta: float, default 0 + Minimal change in monitored value to be considered as an improvement. + patience: int, default 0 + Number of epochs to wait for improvement before terminate training. + mode: str, default 'auto' + One of {auto, min, max}, if `save_best_only=True`, the comparison to make + and determine if the monitored value has improved. if 'auto' mode, checkpoint + handler will try to use min or max based on the monitored metric name. + baseline: float + Baseline value to compare the monitored value with. + """ + + def __init__(self, + monitor, + min_delta=0, + patience=0, + mode='auto', + baseline=None): + super(EarlyStoppingHandler, self).__init__() + + if not isinstance(monitor, EvalMetric): + raise ValueError("Please provide one of the metric objects as monitor, " + "You can create these objects using estimator.prepare_loss_and_metric()") + self.monitor = monitor + self.baseline = baseline + self.patience = patience + self.min_delta = min_delta + self.wait = 0 + self.stopped_epoch = 0 + self.current_epoch = 0 + self.stop_training = False + self.logger = logging.getLogger(__name__) + + if mode not in ['auto', 'min', 'max']: + warnings.warn('EarlyStopping mode %s is unknown, ' + 'fallback to auto mode. 
EarlyStoppingHandler will use '
+                          'max mode for f1 and accuracy metric comparison and '
+                          'use min mode otherwise' % (mode),
+                          RuntimeWarning)
+            mode = 'auto'
+
+        if mode == 'min':
+            self.monitor_op = np.less
+        elif mode == 'max':
+            self.monitor_op = np.greater
+        else:
+            if 'acc' in self.monitor.get()[0].lower() or 'f1' in self.monitor.get()[0].lower():
+                self.logger.info("`greater` operator is used to determine "
+                                 "if %s has improved, please use `min` for mode "
+                                 "if you want otherwise", self.monitor.get()[0])
+                self.monitor_op = np.greater
+            else:
+                self.logger.info("`less` operator is used to determine "
+                                 "if %s has improved, please use `max` for mode "
+                                 "if you want otherwise", self.monitor.get()[0])
+                self.monitor_op = np.less
+
+        if self.monitor_op == np.greater:
+            self.min_delta *= 1
+        else:
+            self.min_delta *= -1
+
+    def train_begin(self, estimator, *args, **kwargs):
+        self.wait = 0
+        self.stopped_epoch = 0
+        self.current_epoch = 0
+        self.stop_training = False
+        if self.baseline is not None:
+            self.best = self.baseline
+        else:
+            self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+
+    def epoch_end(self, estimator, *args, **kwargs):
+        monitor_name, monitor_value = self.monitor.get()
+        if np.isnan(monitor_value):
+            warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects '
+                                         'as monitor, you can use estimator.prepare_loss_and_metrics to '
+                                         'create all metric objects' % monitor_name))
+        else:
+            if self.monitor_op(monitor_value - self.min_delta, self.best):
+                self.best = monitor_value
+                self.wait = 0
+            else:
+                self.wait += 1
+                if self.wait >= self.patience:
+                    self.stopped_epoch = self.current_epoch
+                    self.stop_training = True
+        self.current_epoch += 1
+        return self.stop_training
+
+    def train_end(self, estimator, *args, **kwargs):
+        if self.stopped_epoch > 0:
+            self.logger.info('[Epoch %d] EarlyStoppingHandler: early stopping due to %s not improving',
+                             self.stopped_epoch, self.monitor.get()[0])
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 6935c2752e1a..0939490a8307 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -255,6 +255,13 @@ def learning_rate(self):
         else:
             return self._optimizer.learning_rate
 
+    @property
+    def optimizer(self):
+        if isinstance(self._optimizer, opt.Optimizer):
+            return self._optimizer
+        else:
+            raise UserWarning("Optimizer has not been initialized yet")
+
     def set_learning_rate(self, lr):
         """Sets a new learning rate of the optimizer.
 
diff --git a/tests/nightly/JenkinsfileForBinaries b/tests/nightly/JenkinsfileForBinaries
index ea6db1a20cbf..e4b9ff1acbb1 100755
--- a/tests/nightly/JenkinsfileForBinaries
+++ b/tests/nightly/JenkinsfileForBinaries
@@ -141,6 +141,14 @@ core_logic: {
           utils.docker_run('ubuntu_nightly_gpu', 'nightly_tutorial_test_ubuntu_python3_gpu', true, '1500m')
         }
       }
+    },
+    'Gluon estimator: GPU': {
+      node(NODE_LINUX_GPU) {
+        ws('workspace/estimator-test-gpu') {
+          utils.unpack_and_init('gpu', mx_lib)
+          utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator', true)
+        }
+      }
     }
   }
 }
diff --git a/tests/nightly/estimator/test_estimator_cnn.py b/tests/nightly/estimator/test_estimator_cnn.py
new file mode 100644
index 000000000000..c60dc544b347
--- /dev/null
+++ b/tests/nightly/estimator/test_estimator_cnn.py
@@ -0,0 +1,151 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Test gluon estimator on CNN models + +import argparse +import numpy as np +import mxnet as mx +from mxnet import gluon, init, nd +from mxnet.gluon import data +from mxnet.gluon.contrib.estimator import estimator +from mxnet.gluon.model_zoo import vision + +def load_data_mnist(batch_size, resize=None, num_workers=4): + ''' + Load MNIST dataset + ''' + transformer = [] + if resize: + transformer += [data.vision.transforms.Resize(resize)] + transformer += [data.vision.transforms.ToTensor()] + transformer = data.vision.transforms.Compose(transformer) + mnist_train = data.vision.MNIST(train=True) + mnist_test = data.vision.MNIST(train=False) + train_iter = data.DataLoader( + mnist_train.transform_first(transformer), batch_size, shuffle=True, + num_workers=num_workers) + test_iter = data.DataLoader( + mnist_test.transform_first(transformer), batch_size, shuffle=False, + num_workers=num_workers) + return train_iter, test_iter + +def bilinear_kernel(in_channels, out_channels, kernel_size): + ''' + Bilinear interpolation using transposed convolution + https://github.com/d2l-ai/d2l-en/blob/master/chapter_computer-vision/fcn.md + ''' + factor = (kernel_size + 1) // 2 + if kernel_size % 2 == 1: + center = factor - 1 + else: + center = factor - 0.5 + og = np.ogrid[:kernel_size, :kernel_size] + filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor) + weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype='float32') + weight[range(in_channels), range(out_channels), :, :] = filt + return nd.array(weight) + +def get_net(model_name, context): + if model_name == 'FCN': + num_classes = 21 + pretrained_net = vision.resnet18_v2(pretrained=True, ctx=context) + net = gluon.nn.HybridSequential() + for layer in pretrained_net.features[:-2]: + net.add(layer) + net.add(gluon.nn.Conv2D(num_classes, kernel_size=1), + gluon.nn.Conv2DTranspose(num_classes, kernel_size=64, padding=16, strides=32)) + net[-1].initialize(init.Constant(bilinear_kernel(num_classes, num_classes, 64)), ctx=context) + net[-2].initialize(init=init.Xavier(), ctx=context) + input_shape = (1, 3, 320, 480) + label_shape = (1, 320, 480) + loss_axis = 1 + else: + net = vision.get_model(model_name, classes=10) + net.initialize(mx.init.Xavier(), ctx=context) + input_shape = (1, 1, 224, 224) + label_shape = 1 + loss_axis = -1 + return net, input_shape, label_shape, loss_axis + +def test_estimator_cpu(): + ''' + Test estimator by doing one pass over each model with synthetic data + ''' + models = ['resnet18_v1', + 'FCN' + ] + context = mx.cpu() + for model_name in models: + net, input_shape, label_shape, loss_axis = get_net(model_name, context) + train_dataset = gluon.data.dataset.ArrayDataset(mx.nd.random.uniform(shape=input_shape), + mx.nd.zeros(shape=label_shape)) + val_dataset = gluon.data.dataset.ArrayDataset(mx.nd.random.uniform(shape=input_shape), + mx.nd.zeros(shape=label_shape)) + loss = 
gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis) + train_data = gluon.data.DataLoader(train_dataset, batch_size=1) + val_data = gluon.data.DataLoader(val_dataset, batch_size=1) + net.hybridize() + trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) + # Define estimator + est = estimator.Estimator(net=net, + loss=loss, + metrics=mx.metric.Accuracy(), + trainer=trainer, + context=context) + # Call fit() + est.fit(train_data=train_data, + val_data=val_data, + epochs=1) + +def test_estimator_gpu(): + ''' + Test estimator by training resnet18_v1 for 5 epochs on MNIST and verify accuracy + ''' + model_name = 'resnet18_v1' + batch_size = 128 + num_epochs = 5 + context = mx.gpu(0) + net, _, _, _ = get_net(model_name, context) + train_data, test_data = load_data_mnist(batch_size, resize=224) + loss = gluon.loss.SoftmaxCrossEntropyLoss() + net.hybridize() + acc = mx.metric.Accuracy() + trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) + # Define estimator + est = estimator.Estimator(net=net, + loss=loss, + metrics=acc, + trainer=trainer, + context=context) + # Call fit() + est.fit(train_data=train_data, + val_data=test_data, + epochs=num_epochs) + + assert acc.get()[1] > 0.80 + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='test gluon estimator') + parser.add_argument('--type', type=str, default='cpu') + opt = parser.parse_args() + if opt.type == 'cpu': + test_estimator_cpu() + elif opt.type == 'gpu': + test_estimator_gpu() + else: + raise RuntimeError("Unknown test type") diff --git a/tests/nightly/estimator/test_sentiment_rnn.py b/tests/nightly/estimator/test_sentiment_rnn.py new file mode 100644 index 000000000000..404bf83fb86f --- /dev/null +++ b/tests/nightly/estimator/test_sentiment_rnn.py @@ -0,0 +1,276 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
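The nightly tests above only exercise the default handlers. A hedged sketch of how the CheckpointHandler and EarlyStoppingHandler from event_handler.py would be attached, reusing an `est`, `train_data`, and `val_data` as defined in the tests above; the model directory, prefix, patience, and epoch count are illustrative assumptions:

```python
# Sketch: attaching the PR's checkpointing and early-stopping handlers to an
# existing estimator (est, train_data, val_data as in the tests above; the
# model_dir, prefix, patience, and epochs values are illustrative).
from mxnet.gluon.contrib.estimator import CheckpointHandler, EarlyStoppingHandler

train_metrics, val_metrics = est.prepare_loss_and_metrics()
checkpoint = CheckpointHandler(model_dir='./checkpoints',
                               model_prefix='cnn',
                               monitor=val_metrics[0],  # a metric object, not a name
                               save_best=True)
early_stop = EarlyStoppingHandler(monitor=val_metrics[0], patience=3)

# default Metric/Validation/Logging handlers are still appended; passing the
# objects returned by prepare_loss_and_metrics() satisfies the metric
# consistency check in _prepare_default_handlers()
est.fit(train_data=train_data,
        val_data=val_data,
        epochs=20,
        event_handlers=[checkpoint, early_stop])
```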
+ +"""Gluon Text Sentiment Classification Example using RNN/CNN +Example modified from below link: +https://github.com/d2l-ai/d2l-en/blob/master/chapter_natural-language-processing/sentiment-analysis-rnn.md +https://github.com/d2l-ai/d2l-en/blob/master/chapter_natural-language-processing/sentiment-analysis-cnn.md""" + +import argparse +import os +import tarfile +import random +import collections +import mxnet as mx +from mxnet import nd, gluon +from mxnet.contrib import text +from mxnet.gluon import nn, rnn +from mxnet.gluon.contrib.estimator import estimator + + +class TextCNN(nn.Block): + def __init__(self, vocab, embed_size, kernel_sizes, num_channels, + **kwargs): + super(TextCNN, self).__init__(**kwargs) + self.embedding = nn.Embedding(len(vocab), embed_size) + # The embedding layer does not participate in training + self.constant_embedding = nn.Embedding(len(vocab), embed_size) + self.dropout = nn.Dropout(0.5) + self.decoder = nn.Dense(2) + # The max-over-time pooling layer has no weight, so it can share an + # instance + self.pool = nn.GlobalMaxPool1D() + # Create multiple one-dimensional convolutional layers + self.convs = nn.Sequential() + for c, k in zip(num_channels, kernel_sizes): + self.convs.add(nn.Conv1D(c, k, activation='relu')) + + def forward(self, inputs): + # Concatenate the output of two embedding layers with shape of + # (batch size, number of words, word vector dimension) by word vector + embeddings = nd.concat( + self.embedding(inputs), self.constant_embedding(inputs), dim=2) + # According to the input format required by Conv1D, the word vector + # dimension, that is, the channel dimension of the one-dimensional + # convolutional layer, is transformed into the previous dimension + embeddings = embeddings.transpose((0, 2, 1)) + # For each one-dimensional convolutional layer, after max-over-time + # pooling, an NDArray with the shape of (batch size, channel size, 1) + # can be obtained. Use the flatten function to remove the last + # dimension and then concatenate on the channel dimension + encoding = nd.concat(*[nd.flatten( + self.pool(conv(embeddings))) for conv in self.convs], dim=1) + # After applying the dropout method, use a fully connected layer to + # obtain the output + outputs = self.decoder(self.dropout(encoding)) + return outputs + + +class BiRNN(nn.Block): + def __init__(self, vocab, embed_size, num_hiddens, num_layers, **kwargs): + super(BiRNN, self).__init__(**kwargs) + self.embedding = nn.Embedding(len(vocab), embed_size) + # Set Bidirectional to True to get a bidirectional recurrent neural + # network + self.encoder = rnn.LSTM(num_hiddens, num_layers=num_layers, + bidirectional=True, input_size=embed_size) + self.decoder = nn.Dense(2) + + def forward(self, inputs): + # The shape of inputs is (batch size, number of words). Because LSTM + # needs to use sequence as the first dimension, the input is + # transformed and the word feature is then extracted. The output shape + # is (number of words, batch size, word vector dimension). + embeddings = self.embedding(inputs.T) + # The shape of states is (number of words, batch size, 2 * number of + # hidden units). + states = self.encoder(embeddings) + # Concatenate the hidden states of the initial time step and final + # time step to use as the input of the fully connected layer. 
Its + # shape is (batch size, 4 * number of hidden units) + encoding = nd.concat(states[0], states[-1]) + outputs = self.decoder(encoding) + return outputs + + +def download_imdb(data_dir='/tmp/data'): + ''' + Download and extract the IMDB dataset + ''' + url = ('http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz') + sha1 = '01ada507287d82875905620988597833ad4e0903' + if not os.path.exists(data_dir): + os.makedirs(data_dir) + file_path = os.path.join(data_dir, 'aclImdb_v1.tar.gz') + if not os.path.isfile(file_path): + file_path = gluon.utils.download(url, data_dir, sha1_hash=sha1) + with tarfile.open(file_path, 'r') as f: + f.extractall(data_dir) + + +def read_imdb(folder='train'): + ''' + Read the IMDB dataset + ''' + data = [] + for label in ['pos', 'neg']: + folder_name = os.path.join('/tmp/data/aclImdb/', folder, label) + for file in os.listdir(folder_name): + with open(os.path.join(folder_name, file), 'rb') as f: + review = f.read().decode('utf-8').replace('\n', '').lower() + data.append([review, 1 if label == 'pos' else 0]) + random.shuffle(data) + return data + + +def get_tokenized_imdb(data): + ''' + Tokenized the words + ''' + + def tokenizer(text): + return [tok.lower() for tok in text.split(' ')] + + return [tokenizer(review) for review, _ in data] + + +def get_vocab_imdb(data): + ''' + Get the indexed tokens + ''' + tokenized_data = get_tokenized_imdb(data) + counter = collections.Counter([tk for st in tokenized_data for tk in st]) + return text.vocab.Vocabulary(counter, min_freq=5) + + +def preprocess_imdb(data, vocab): + ''' + Make the length of each comment 500 by truncating or adding 0s + ''' + max_l = 500 + + def pad(x): + return x[:max_l] if len(x) > max_l else x + [0] * (max_l - len(x)) + + tokenized_data = get_tokenized_imdb(data) + features = nd.array([pad(vocab.to_indices(x)) for x in tokenized_data]) + labels = nd.array([score for _, score in data]) + return features, labels + + +def run(net, train_dataloader, test_dataloader, **kwargs): + ''' + Train a test sentiment model + ''' + num_epochs = kwargs['epochs'] + ctx = kwargs['ctx'] + batch_size = kwargs['batch_size'] + lr = kwargs['lr'] + + # Define trainer + trainer = mx.gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr}) + # Define loss and evaluation metrics + loss = gluon.loss.SoftmaxCrossEntropyLoss() + acc = mx.metric.Accuracy() + + # Define estimator + est = estimator.Estimator(net=net, loss=loss, metrics=acc, + trainer=trainer, context=ctx) + # Begin training + est.fit(train_data=train_dataloader, val_data=test_dataloader, + epochs=num_epochs) + return acc + + +def test_estimator_cpu(**kwargs): + ''' + Test estimator by doing one pass over each model with synthetic data + ''' + models = ['TextCNN', 'BiRNN'] + ctx = kwargs['ctx'] + batch_size = kwargs['batch_size'] + embed_size = kwargs['embed_size'] + + train_data = mx.nd.random.randint(low=0, high=100, shape=(2 * batch_size, 500)) + train_label = mx.nd.random.randint(low=0, high=2, shape=(2 * batch_size,)) + val_data = mx.nd.random.randint(low=0, high=100, shape=(batch_size, 500)) + val_label = mx.nd.random.randint(low=0, high=2, shape=(batch_size,)) + + train_dataloader = gluon.data.DataLoader(dataset=gluon.data.ArrayDataset(train_data, train_label), + batch_size=batch_size, shuffle=True) + val_dataloader = gluon.data.DataLoader(dataset=gluon.data.ArrayDataset(val_data, val_label), + batch_size=batch_size) + vocab_list = mx.nd.zeros(shape=(100,)) + + # Get the model + for model in models: + if model == 'TextCNN': + kernel_sizes, 
nums_channels = [3, 4, 5], [100, 100, 100] + net = TextCNN(vocab_list, embed_size, kernel_sizes, nums_channels) + else: + num_hiddens, num_layers = 100, 2 + net = BiRNN(vocab_list, embed_size, num_hiddens, num_layers) + net.initialize(mx.init.Xavier(), ctx=ctx) + + run(net, train_dataloader, val_dataloader, **kwargs) + + +def test_estimator_gpu(**kwargs): + ''' + Test estimator by training Bidirectional RNN for 5 epochs on the IMDB dataset + and verify accuracy + ''' + ctx = kwargs['ctx'] + batch_size = kwargs['batch_size'] + num_epochs = kwargs['epochs'] + embed_size = kwargs['embed_size'] + + # data + download_imdb() + train_data, test_data = read_imdb('train'), read_imdb('test') + vocab = get_vocab_imdb(train_data) + + train_set = gluon.data.ArrayDataset(*preprocess_imdb(train_data, vocab)) + test_set = gluon.data.ArrayDataset(*preprocess_imdb(test_data, vocab)) + train_dataloader = gluon.data.DataLoader(train_set, batch_size, shuffle=True) + test_dataloader = gluon.data.DataLoader(test_set, batch_size) + + # Model + num_hiddens, num_layers = 100, 2 + net = BiRNN(vocab, embed_size, num_hiddens, num_layers) + net.initialize(mx.init.Xavier(), ctx=ctx) + + glove_embedding = text.embedding.create( + 'glove', pretrained_file_name='glove.6B.100d.txt', vocabulary=vocab) + + net.embedding.weight.set_data(glove_embedding.idx_to_vec) + net.embedding.collect_params().setattr('grad_req', 'null') + + acc = run(net, train_dataloader, test_dataloader, **kwargs) + + assert acc.get()[1] > 0.70 + + +parser = argparse.ArgumentParser(description='test gluon estimator') +parser.add_argument('--type', type=str, default='cpu') +opt = parser.parse_args() +kwargs = { + 'batch_size': 64, + 'lr': 0.01, + 'embed_size': 100 +} + +if opt.type == 'cpu': + kwargs['ctx'] = mx.cpu() + kwargs['epochs'] = 1 + test_estimator_cpu(**kwargs) +elif opt.type == 'gpu': + kwargs['ctx'] = mx.gpu() + kwargs['epochs'] = 5 + test_estimator_gpu(**kwargs) +else: + raise RuntimeError("Unknown test type") diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py new file mode 100644 index 000000000000..d2e8c082aa08 --- /dev/null +++ b/tests/python/unittest/test_gluon_estimator.py @@ -0,0 +1,371 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
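Because fit() dispatches on the TrainBegin/EpochEnd/... base classes via _categorize_handlers, a custom handler only needs to implement the events it cares about. A hedged sketch follows; the handler name and the halving schedule are invented for illustration, and `est` and `train_data` are assumed to exist as in the tests above:

```python
# Sketch of a user-defined handler: subclass only the events you need.
# LRDecayHandler and its schedule are illustrative, not part of the PR.
from mxnet.gluon.contrib.estimator import EpochEnd

class LRDecayHandler(EpochEnd):
    """Halve the learning rate every `period` epochs."""
    def __init__(self, period=2):
        self.period = period
        self.current_epoch = 0

    def epoch_end(self, estimator, *args, **kwargs):
        self.current_epoch += 1
        if self.current_epoch % self.period == 0:
            trainer = estimator.trainer
            trainer.set_learning_rate(trainer.learning_rate / 2)
        return False  # returning True would signal fit() to stop training

est.fit(train_data=train_data, epochs=4,
        event_handlers=[LRDecayHandler(period=2)])
```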
+parser = argparse.ArgumentParser(description='test gluon estimator')
+parser.add_argument('--type', type=str, default='cpu')
+opt = parser.parse_args()
+kwargs = {
+    'batch_size': 64,
+    'lr': 0.01,
+    'embed_size': 100
+}
+
+if opt.type == 'cpu':
+    kwargs['ctx'] = mx.cpu()
+    kwargs['epochs'] = 1
+    test_estimator_cpu(**kwargs)
+elif opt.type == 'gpu':
+    kwargs['ctx'] = mx.gpu()
+    kwargs['epochs'] = 5
+    test_estimator_gpu(**kwargs)
+else:
+    raise RuntimeError("Unknown test type")
diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py
new file mode 100644
index 000000000000..d2e8c082aa08
--- /dev/null
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -0,0 +1,371 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+''' Unit tests for Gluon Estimator '''
+
+import sys
+import unittest
+import warnings
+
+import mxnet as mx
+from mxnet import gluon
+from mxnet.gluon import nn
+from mxnet.gluon.contrib.estimator import *
+from nose.tools import assert_raises
+
+
+def _get_test_network():
+    net = nn.Sequential()
+    net.add(nn.Dense(4, activation='relu', flatten=False))
+    return net
+
+
+def _get_test_data():
+    batch_size = 4
+    in_data = mx.nd.random.uniform(shape=(10, 3))
+    out_data = mx.nd.random.uniform(shape=(10, 4))
+    # Input dataloader
+    dataset = gluon.data.dataset.ArrayDataset(in_data, out_data)
+    dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size)
+    dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size)
+    return dataloader, dataiter
+
+
+def test_fit():
+    ''' test estimator with different train data types '''
+    net = _get_test_network()
+    dataloader, dataiter = _get_test_data()
+    num_epochs = 1
+    ctx = mx.cpu()
+    loss = gluon.loss.L2Loss()
+    acc = mx.metric.Accuracy()
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=acc,
+                    trainer=trainer,
+                    context=ctx)
+
+    est.fit(train_data=dataloader,
+            epochs=num_epochs)
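+
+    # Estimator.fit only accepts a gluon DataLoader for train_data; an
+    # mx.io.DataIter or a raw NDArray is rejected with a ValueError, as the
+    # two cases below verify.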
+    with assert_raises(ValueError):
+        est.fit(train_data=dataiter,
+                epochs=num_epochs)
+
+    # Input NDArray
+    with assert_raises(ValueError):
+        est.fit(train_data=[mx.nd.ones(shape=(10, 3))],
+                epochs=num_epochs)
+
+
+def test_validation():
+    ''' test different validation data types '''
+    net = _get_test_network()
+    dataloader, dataiter = _get_test_data()
+    num_epochs = 1
+    ctx = mx.cpu()
+    loss = gluon.loss.L2Loss()
+    acc = mx.metric.Accuracy()
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=acc,
+                    trainer=trainer,
+                    context=ctx)
+    # Input dataloader
+    est.fit(train_data=dataloader,
+            val_data=dataloader,
+            epochs=num_epochs)
+
+    # using validation handler
+    train_metrics, val_metrics = est.prepare_loss_and_metrics()
+    validation_handler = ValidationHandler(val_data=dataloader, eval_fn=est.evaluate,
+                                           val_metrics=val_metrics)
+    est.fit(train_data=dataloader,
+            epochs=num_epochs,
+            event_handlers=[validation_handler])
+
+    with assert_raises(ValueError):
+        est.fit(train_data=dataiter,
+                val_data=dataiter,
+                epochs=num_epochs)
+    # Input NDArray
+    with assert_raises(ValueError):
+        est.fit(train_data=[mx.nd.ones(shape=(10, 3))],
+                val_data=[mx.nd.ones(shape=(10, 3))],
+                epochs=num_epochs)
+
+
+@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
+def test_initializer():
+    ''' test with no initializer, inconsistent initializer '''
+    net = _get_test_network()
+    train_data, _ = _get_test_data()
+    num_epochs = 1
+    ctx = mx.cpu()
+
+    loss = gluon.loss.L2Loss()
+    acc = mx.metric.Accuracy()
+    # no initializer
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=acc,
+                    context=ctx)
+    est.fit(train_data=train_data,
+            epochs=num_epochs)
+
+    # different initializer for net and estimator
+    net = _get_test_network()
+    net.initialize(mx.init.Xavier(), ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    # catch reinit warning
+    with warnings.catch_warnings(record=True) as w:
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics=acc,
+                        initializer=mx.init.MSRAPrelu(),
+                        trainer=trainer,
+                        context=ctx)
+        assert 'Network already fully initialized' in str(w[-1].message)
+    # net partially initialized, fine tuning use case
+    net = gluon.model_zoo.vision.resnet18_v1(pretrained=True, ctx=ctx)
+    net.output = gluon.nn.Dense(10)  # last layer not initialized
+    est = Estimator(net, loss=loss, metrics=acc, context=ctx)
+    dataset = gluon.data.ArrayDataset(mx.nd.zeros((10, 3, 224, 224)), mx.nd.zeros((10, 10)))
+    train_data = gluon.data.DataLoader(dataset=dataset, batch_size=5)
+    est.fit(train_data=train_data,
+            epochs=num_epochs)
+
+
+@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
+def test_trainer():
+    ''' test with no trainer and invalid trainer '''
+    net = _get_test_network()
+    train_data, _ = _get_test_data()
+    num_epochs = 1
+    ctx = mx.cpu()
+
+    loss = gluon.loss.L2Loss()
+    acc = mx.metric.Accuracy()
+    net.initialize(ctx=ctx)
+    # input no trainer
+    with warnings.catch_warnings(record=True) as w:
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics=acc,
+                        context=ctx)
+        assert 'No trainer specified' in str(w[-1].message)
+    est.fit(train_data=train_data,
+            epochs=num_epochs)
+
+    # input invalid trainer
+    trainer = 'sgd'
+    with assert_raises(ValueError):
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics=acc,
+                        trainer=trainer,
+                        context=ctx)
+
+
+def test_metric():
+    ''' test with no metric, list of metrics, invalid metric '''
+    net = _get_test_network()
+    train_data, _ = _get_test_data()
+    num_epochs = 1
+    ctx = mx.cpu()
+
+    loss = gluon.loss.L2Loss()
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    # input no metric
+    est = Estimator(net=net,
+                    loss=loss,
+                    trainer=trainer,
+                    context=ctx)
+    est.fit(train_data=train_data,
+            epochs=num_epochs)
+    # input list of metrics
+    metrics = [mx.metric.Accuracy(), mx.metric.Accuracy()]
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=metrics,
+                    trainer=trainer,
+                    context=ctx)
+    est.fit(train_data=train_data,
+            epochs=num_epochs)
+    # input invalid metric
+    with assert_raises(ValueError):
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics='acc',
+                        trainer=trainer,
+                        context=ctx)
+    # test default metric: a SoftmaxCrossEntropyLoss implies Accuracy
+    loss = gluon.loss.SoftmaxCrossEntropyLoss()
+    est = Estimator(net=net,
+                    loss=loss,
+                    trainer=trainer,
+                    context=ctx)
+    est.prepare_loss_and_metrics()
+    assert isinstance(est.train_metrics[0], mx.metric.Accuracy)
+
+
+def test_loss():
+    ''' test with invalid loss '''
+    net = _get_test_network()
+    ctx = mx.cpu()
+    acc = mx.metric.Accuracy()
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    # input invalid loss
+    with assert_raises(ValueError):
+        est = Estimator(net=net,
+                        loss='mse',
+                        metrics=acc,
+                        trainer=trainer,
+                        context=ctx)
+
+
+def test_context():
+    ''' test with no context, list of context, invalid context '''
+    net = _get_test_network()
+    loss = gluon.loss.L2Loss()
+    metrics = mx.metric.Accuracy()
+    # input no context
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=metrics)
+    # input list of context
+    gpus = mx.context.num_gpus()
+    ctx = [mx.gpu(i) for i in range(gpus)] if gpus > 0 else [mx.cpu()]
+    net = _get_test_network()
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=metrics,
+                    context=ctx)
+    # input invalid context
+    with assert_raises(ValueError):
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics=metrics,
+                        context='cpu')
+
+    with assert_raises(AssertionError):
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics=metrics,
+                        context=[mx.gpu(0), mx.gpu(100)])
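+
+
+# Event handlers may mix in any of the lifecycle interfaces (TrainBegin,
+# EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd); a handler is placed
+# into every bucket whose interface it implements, so one handler can appear
+# in several of the lists returned by _categorize_handlers, as the counts
+# asserted below show.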
+def test_categorize_handlers():
+    class CustomHandler1(TrainBegin):
+
+        def train_begin(self):
+            print("custom train begin")
+
+    class CustomHandler2(EpochBegin, BatchBegin, TrainEnd):
+
+        def epoch_begin(self):
+            print("custom epoch begin")
+
+        def batch_begin(self):
+            print("custom batch begin")
+
+        def train_end(self):
+            print("custom train end")
+
+    class CustomHandler3(EpochBegin, BatchBegin, BatchEnd, TrainEnd):
+
+        def epoch_begin(self):
+            print("custom epoch begin")
+
+        def batch_begin(self):
+            print("custom batch begin")
+
+        def batch_end(self):
+            print("custom batch end")
+
+        def train_end(self):
+            print("custom train end")
+
+    net = nn.Sequential()
+    net.add(nn.Dense(10))
+    loss = gluon.loss.SoftmaxCrossEntropyLoss()
+    est = Estimator(net, loss=loss)
+    event_handlers = [CustomHandler1(), CustomHandler2(), CustomHandler3()]
+    train_begin, epoch_begin, batch_begin, \
+        batch_end, epoch_end, train_end = est._categorize_handlers(event_handlers)
+    assert len(train_begin) == 1
+    assert len(epoch_begin) == 2
+    assert len(batch_begin) == 2
+    assert len(batch_end) == 1
+    assert len(train_end) == 2
+
+
+@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
+def test_default_handlers():
+    net = _get_test_network()
+    train_data, _ = _get_test_data()
+
+    num_epochs = 1
+    ctx = mx.cpu()
+
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+
+    train_rmse = mx.metric.RMSE()
+    loss = gluon.loss.L2Loss()
+
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=train_rmse,
+                    trainer=trainer,
+                    context=ctx)
+    # no handler: default handlers are added, with a warning
+    with warnings.catch_warnings(record=True) as w:
+        est.fit(train_data=train_data, epochs=num_epochs)
+        assert 'You are training with the' in str(w[-1].message)
+
+    # handler with prepared loss and metrics
+    # use mix of default and user defined handlers
+    train_metrics, val_metrics = est.prepare_loss_and_metrics()
+    logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics)
+    with warnings.catch_warnings(record=True) as w:
+        est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging])
+        assert 'You are training with the' in str(w[-1].message)
+        # a MetricHandler is provided by default
+        assert 'MetricHandler' in str(w[-1].message)
+
+    # handler with all user defined metrics
+    # use mix of default and user defined handlers
+    metric = MetricHandler(train_metrics=[train_rmse])
+    logging = LoggingHandler(train_metrics=[train_rmse], val_metrics=[mx.metric.RMSE("val rmse")])
+    est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[metric, logging])
+
+    # handler with mixed metrics: some handlers use metrics prepared by the
+    # estimator, others use metrics the user prepared; this must raise
+    logging = LoggingHandler(train_metrics=train_metrics, val_metrics=[mx.metric.RMSE("val rmse")])
+    with assert_raises(ValueError):
+        est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging])
+
+    # test handler order
+    train_metrics, val_metrics = est.prepare_loss_and_metrics()
+    early_stopping = EarlyStoppingHandler(monitor=val_metrics[0])
+    handlers = est._prepare_default_handlers(val_data=None, event_handlers=[early_stopping])
+    assert len(handlers) == 4
+    assert isinstance(handlers[0], MetricHandler)
+    assert isinstance(handlers[3], LoggingHandler)
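+
+# Default handler ordering: MetricHandler comes first so metrics are updated
+# before any other handler reads them, and LoggingHandler comes last so it
+# reports the final state of each batch and epoch; user handlers such as
+# EarlyStoppingHandler run in between.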
diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py
new file mode 100644
index 000000000000..7ea5ff3f4b62
--- /dev/null
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+
+import mxnet as mx
+from common import TemporaryDirectory
+from mxnet import nd
+from mxnet.gluon import nn, loss
+from mxnet.gluon.contrib.estimator import estimator, event_handler
+
+
+def _get_test_network(net=None):
+    # avoid a mutable default argument: a shared default nn.Sequential()
+    # would accumulate three more layers on every call
+    if net is None:
+        net = nn.Sequential()
+    net.add(nn.Dense(128, activation='relu', flatten=False),
+            nn.Dense(64, activation='relu'),
+            nn.Dense(10, activation='relu'))
+    return net
+
+
+def _get_test_data():
+    data = nd.ones((32, 100))
+    label = nd.zeros((32, 1))
+    data_arr = mx.gluon.data.dataset.ArrayDataset(data, label)
+    return mx.gluon.data.DataLoader(data_arr, batch_size=8)
+
+
+def test_checkpoint_handler():
+    with TemporaryDirectory() as tmpdir:
+        model_prefix = 'test_epoch'
+        file_path = os.path.join(tmpdir, model_prefix)
+        test_data = _get_test_data()
+
+        net = _get_test_network()
+        ce_loss = loss.SoftmaxCrossEntropyLoss()
+        acc = mx.metric.Accuracy()
+        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+        checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
+                                                             model_prefix=model_prefix,
+                                                             monitor=acc,
+                                                             save_best=True,
+                                                             epoch_period=1)
+        est.fit(test_data, event_handlers=[checkpoint_handler], epochs=1)
+        assert checkpoint_handler.current_epoch == 1
+        assert checkpoint_handler.current_batch == 4
+        assert os.path.isfile(file_path + '-best.params')
+        assert os.path.isfile(file_path + '-best.states')
+        assert os.path.isfile(file_path + '-epoch0batch4.params')
+        assert os.path.isfile(file_path + '-epoch0batch4.states')
+
+        model_prefix = 'test_batch'
+        file_path = os.path.join(tmpdir, model_prefix)
+        net = _get_test_network(nn.HybridSequential())
+        net.hybridize()
+        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+        checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
+                                                             model_prefix=model_prefix,
+                                                             epoch_period=None,
+                                                             batch_period=2,
+                                                             max_checkpoints=2)
+        est.fit(test_data, event_handlers=[checkpoint_handler], batches=10)
+        assert checkpoint_handler.current_batch == 10
+        assert checkpoint_handler.current_epoch == 3
+        # only the last max_checkpoints=2 checkpoints are kept
+        assert not os.path.isfile(file_path + 'best.params')
+        assert not os.path.isfile(file_path + 'best.states')
+        assert not os.path.isfile(file_path + '-epoch0batch0.params')
+        assert not os.path.isfile(file_path + '-epoch0batch0.states')
+        assert os.path.isfile(file_path + '-symbol.json')
+        assert os.path.isfile(file_path + '-epoch1batch7.params')
+        assert os.path.isfile(file_path + '-epoch1batch7.states')
+        assert os.path.isfile(file_path + '-epoch2batch9.params')
+        assert os.path.isfile(file_path + '-epoch2batch9.states')
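+
+# Checkpoint files follow the naming scheme
+# <prefix>-epoch{e}batch{b}.params / .states; save_best=True additionally
+# maintains a <prefix>-best pair for the monitored metric, and a hybridized
+# network also exports <prefix>-symbol.json.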
+
+def test_resume_checkpoint():
+    with TemporaryDirectory() as tmpdir:
+        model_prefix = 'test_net'
+        file_path = os.path.join(tmpdir, model_prefix)
+        test_data = _get_test_data()
+
+        net = _get_test_network()
+        ce_loss = loss.SoftmaxCrossEntropyLoss()
+        acc = mx.metric.Accuracy()
+        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+        checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
+                                                             model_prefix=model_prefix,
+                                                             monitor=acc,
+                                                             max_checkpoints=1)
+        est.fit(test_data, event_handlers=[checkpoint_handler], epochs=2)
+        assert os.path.isfile(file_path + '-epoch1batch8.params')
+        assert os.path.isfile(file_path + '-epoch1batch8.states')
+        checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
+                                                             model_prefix=model_prefix,
+                                                             monitor=acc,
+                                                             max_checkpoints=1,
+                                                             resume_from_checkpoint=True)
+        est.fit(test_data, event_handlers=[checkpoint_handler], epochs=5)
+        # should only continue training for 3 more epochs,
+        # so the last checkpoint file is epoch4
+        assert est.max_epoch == 3
+        assert os.path.isfile(file_path + '-epoch4batch20.states')
+
+
+def test_early_stopping():
+    test_data = _get_test_data()
+
+    net = _get_test_network()
+    ce_loss = loss.SoftmaxCrossEntropyLoss()
+    acc = mx.metric.Accuracy()
+    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+    early_stopping = event_handler.EarlyStoppingHandler(monitor=acc,
+                                                        patience=0,
+                                                        mode='min')
+    est.fit(test_data, event_handlers=[early_stopping], epochs=5)
+    assert early_stopping.current_epoch == 2
+    assert early_stopping.stopped_epoch == 1
+
+    early_stopping = event_handler.EarlyStoppingHandler(monitor=acc,
+                                                        patience=2,
+                                                        mode='auto')
+    est.fit(test_data, event_handlers=[early_stopping], epochs=1)
+    assert early_stopping.current_epoch == 1
+
+
+def test_logging():
+    with TemporaryDirectory() as tmpdir:
+        test_data = _get_test_data()
+        file_name = 'test_log'
+        output_dir = os.path.join(tmpdir, file_name)
+
+        net = _get_test_network()
+        ce_loss = loss.SoftmaxCrossEntropyLoss()
+        acc = mx.metric.Accuracy()
+        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+        train_metrics, val_metrics = est.prepare_loss_and_metrics()
+        logging_handler = event_handler.LoggingHandler(file_name=file_name,
+                                                       file_location=tmpdir,
+                                                       train_metrics=train_metrics,
+                                                       val_metrics=val_metrics)
+        est.fit(test_data, event_handlers=[logging_handler], epochs=3)
+        assert logging_handler.batch_index == 0
+        assert logging_handler.current_epoch == 3
+        assert os.path.isfile(output_dir)
+
+
+def test_custom_handler():
+    class CustomStopHandler(event_handler.TrainBegin,
+                            event_handler.BatchEnd,
+                            event_handler.EpochEnd):
+        def __init__(self, batch_stop=None, epoch_stop=None):
+            self.batch_stop = batch_stop
+            self.epoch_stop = epoch_stop
+            self.num_batch = 0
+            self.num_epoch = 0
+            self.stop_training = False
+
+        def train_begin(self, estimator, *args, **kwargs):
+            self.num_batch = 0
+            self.num_epoch = 0
+
+        def batch_end(self, estimator, *args, **kwargs):
+            self.num_batch += 1
+            if self.num_batch == self.batch_stop:
+                self.stop_training = True
+            return self.stop_training
+
+        def epoch_end(self, estimator, *args, **kwargs):
+            self.num_epoch += 1
+            if self.num_epoch == self.epoch_stop:
+                self.stop_training = True
+            return self.stop_training
+
+    # total data size is 32 and batch size is 8,
+    # so there are 4 batches per epoch
+    test_data = _get_test_data()
+    net = _get_test_network()
+    ce_loss = loss.SoftmaxCrossEntropyLoss()
+    acc = mx.metric.Accuracy()
+    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+    custom_handler = CustomStopHandler(3, 2)
+    est.fit(test_data, event_handlers=[custom_handler], epochs=3)
+    assert custom_handler.num_batch == 3
+    assert custom_handler.num_epoch == 1
+    custom_handler = CustomStopHandler(100, 5)
+    est.fit(test_data, event_handlers=[custom_handler], epochs=10)
+    assert custom_handler.num_batch == 5 * 4
+    assert custom_handler.num_epoch == 5
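+
+# Worked check for the final assertions above: 32 samples / batch size 8
+# = 4 batches per epoch, so stopping at the epoch limit of 5 gives
+# 5 * 4 = 20 batches, while the first run hits its batch_stop of 3 partway
+# through epoch 1 and therefore records num_epoch == 1.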