diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index f43f17520654..da1a3915caec 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -21,9 +21,8 @@ import copy import warnings -import weakref -from .event_handler import MetricHandler, ValidationHandler, LoggingHandler +from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd from .... import gluon, autograd from ....context import Context, cpu, gpu, num_gpus @@ -40,16 +39,18 @@ class Estimator(object): Parameters ---------- + net : Block + The model used for training. loss : gluon.loss.Loss or list of gluon.loss.Loss - Loss(objective functions) to calculate during training + Loss(objective functions) to calculate during training. metrics : EvalMetric or list of EvalMetric - Metrics for evaluating models + Metrics for evaluating models. initializer : Initializer - initializer to initialize the network + Initializer to initialize the network. trainer : Trainer - Trainer to apply optimizer on network parameters + Trainer to apply optimizer on network parameters. context : Context or list of Context - device(s) to run the training on + Device(s) to run the training on. """ def __init__(self, net, @@ -70,7 +71,7 @@ def __init__(self, net, def _check_loss(self, loss): if isinstance(loss, gluon.loss.Loss): loss = [loss] - elif isinstance(loss, list) or all([isinstance(l, gluon.loss.Loss) for l in loss]): + elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]): loss = loss else: raise ValueError("loss must be a Loss or a list of Loss, " @@ -122,19 +123,23 @@ def _check_context(self, context): def _initialize(self, initializer): # initialize the network - if initializer: - if self._is_initialized(): - # if already initialized, re-init with user specified initializer - warnings.warn("Network already initialized, re-initializing with %s. " - "You don't need to pass initializer if you already " - "initialized your net." % type(initializer).__name__) - self.net.initialize(init=initializer, ctx=self.context, force_reinit=True) + if not self._is_initialized(): + # net is partially or not initialized, + # initialize with user specified initializer + # if initializer is None, default initializer will be used + # do not re-init layers already initialized + if initializer: + self.net.initialize(init=initializer, ctx=self.context) else: - # initialize with user specified initializer - self.net.initialize(init=initializer, ctx=self.context, force_reinit=False) - else: - if not self._is_initialized(): self.net.initialize(ctx=self.context) + elif initializer: + # net is fully initialized, and user passed not None initializer + # do not force reinitialize, give warning + warnings.warn("Network already fully initialized, skipping initialization. " + "You don't need to pass initializer if you already " + "initialized your net. 
" + "You can use net.initialize(init=your_initializer, force_reinit=True)" + "to force re-initialize.") def _check_trainer(self, trainer): # handle trainer @@ -157,11 +162,11 @@ def _is_initialized(self): return False return True - def _get_data_and_label(self, batch, ctx): + def _get_data_and_label(self, batch, ctx, batch_axis=0): data = batch[0] label = batch[1] - data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=0) - label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0) + data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=batch_axis) + label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=batch_axis) return data, label def prepare_loss_and_metrics(self): @@ -183,33 +188,36 @@ def prepare_loss_and_metrics(self): self.train_metrics.append(Loss(loss.name.rstrip('1234567890'))) for metric in self.train_metrics: val_metric = copy.deepcopy(metric) - metric.name = "Train " + metric.name - val_metric.name = "Validation " + val_metric.name + metric.name = "train " + metric.name + val_metric.name = "validation " + val_metric.name self.val_metrics.append(val_metric) return self.train_metrics, self.val_metrics def evaluate(self, val_data, - val_metrics): + val_metrics, + batch_axis=0): """Evaluate model on validation data Parameters ---------- val_data : DataLoader - validation data with data and labels + Validation data loader with data and labels. val_metrics : EvalMetric or list of EvalMetrics - metrics to update validation result + Metrics to update validation result. + batch_axis : int, default 0 + Batch axis to split the validation data into devices. """ + if not isinstance(val_data, gluon.data.DataLoader): + raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " + "can transform your DataIter or any NDArray into Gluon DataLoader. " + "Refer to gluon.data.dataloader") for metric in val_metrics: metric.reset() for _, batch in enumerate(val_data): - if not isinstance(val_data, gluon.data.DataLoader): - raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " - "can transform your DataIter or any NDArray into Gluon DataLoader. " - "Refer to gluon.data.dataloader") - data, label = self._get_data_and_label(batch, self.context) + data, label = self._get_data_and_label(batch, self.context, batch_axis) pred = [self.net(x) for x in data] loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)] # update metrics @@ -221,30 +229,45 @@ def evaluate(self, def fit(self, train_data, val_data=None, - epochs=1, - event_handlers=None): - """Trains the model on a given dataset for a specified - number of epochs. Also, the batch size is inferred from the - DataLoader's batch_size. + epochs=None, + event_handlers=None, + batches=None, + batch_axis=0): + """Trains the model with a given :py:class:`DataLoader` for a specified + number of epochs or batches. The batch size is inferred from the + data loader's batch_size. Parameters ---------- train_data : DataLoader - training data with data and labels - val_data : DataLoader - validation data with data and labels - epochs : int, default 1 - number of epochs to iterate on the training data. - batch_size : int - number of samples per gradient update. - default will be 32 per device + Training data loader with data and labels. + val_data : DataLoader, default None + Validation data loader with data and labels. + epochs : int, default None + Number of epochs to iterate on the training data. 
+ You can only specify one and only one type of iteration(epochs or batches). event_handlers : EventHandler or list of EventHandler - list of EventHandlers to apply during training - batch_fn : function - custom batch function to extract data and label - from a data batch and load into contexts(devices) + List of :py:class:`EventHandlers` to apply during training. + batches : int, default None + Number of batches to iterate on the training data. + You can only specify one and only one type of iteration(epochs or batches). + batch_axis : int, default 0 + Batch axis to split the training data into devices. """ - self.max_epochs = epochs + if not isinstance(train_data, gluon.data.DataLoader): + raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " + "can transform your DataIter or any NDArray into Gluon DataLoader. " + "Refer to gluon.data.dataloader") + + # must specify one and only one of epochs or batches + if (not epochs) == (not batches): + raise ValueError( + "Fit only support exactly one type of iteration, " + "train by number of epochs or number of batches." + "Please specify one and only one of: epochs or batches.") + + self.max_epoch = epochs + self.max_batch = batches # provide default handlers event_handlers = self._prepare_default_handlers(val_data, event_handlers) @@ -252,23 +275,19 @@ def fit(self, train_data, train_begin, epoch_begin, batch_begin, \ batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers) - # only pass a weak reference to all event handlers - estimator_ref = weakref.proxy(self) + # pass a reference to all event handlers + estimator_ref = self # training begin for handler in train_begin: handler.train_begin(estimator_ref) - for epoch in range(epochs): + while True: # epoch begin for handler in epoch_begin: handler.epoch_begin(estimator_ref) for i, batch in enumerate(train_data): - if not isinstance(train_data, gluon.data.DataLoader): - raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " - "can transform your DataIter or any NDArray into Gluon DataLoader. 
" - "Refer to gluon.data.dataloader") - data, label = self._get_data_and_label(batch, self.context) + data, label = self._get_data_and_label(batch, self.context, batch_axis) batch_size = batch[0].shape[0] @@ -285,15 +304,22 @@ def fit(self, train_data, self.trainer.step(batch_size) # batch end + + batch_end_result = [] for handler in batch_end: - if handler.batch_end(estimator_ref, batch=batch, - pred=pred, label=label, loss=loss): - break + batch_end_result.append(handler.batch_end(estimator_ref, batch=batch, + pred=pred, label=label, loss=loss)) + # if any handler signaled to stop + if any(batch_end_result): + break # epoch end + epoch_end_result = [] for handler in epoch_end: - if handler.epoch_end(estimator_ref): - break + epoch_end_result.append(handler.epoch_end(estimator_ref)) + # if any handler signaled to stop + if any(epoch_end_result): + break # train end for handler in train_end: @@ -304,6 +330,9 @@ def _prepare_default_handlers(self, val_data, event_handlers): default_handlers = [] train_metrics, val_metrics = self.prepare_loss_and_metrics() + # no need to add to default handler check as StoppingHandler does not use metrics + event_handlers.append(StoppingHandler(self.max_epoch, self.max_batch)) + if not any(isinstance(handler, MetricHandler) for handler in event_handlers): event_handlers.append(MetricHandler(train_metrics=train_metrics)) default_handlers.append("MetricHandler") @@ -319,13 +348,14 @@ def _prepare_default_handlers(self, val_data, event_handlers): default_handlers.append("LoggingHandler") # if there is a mix of user defined event handlers and default event handlers - # they should have the save set of loss and metrics + # they should have the same set of loss and metrics if default_handlers: msg = "You are training with the following default event handlers: %s. " \ "They use loss and metrics from estimator.prepare_loss_and_metrics(). " \ "Please use the same set of metrics for all your other handlers." % \ ", ".join(default_handlers) warnings.warn(msg) + # check if all handlers has the same set of references to loss and metrics references = [] for handler in event_handlers: for attribute in dir(handler): @@ -335,8 +365,10 @@ def _prepare_default_handlers(self, val_data, event_handlers): references += reference else: references.append(reference) + # remove None metric references + references = set([ref for ref in references if ref]) for metric in references: - if metric and metric not in train_metrics + val_metrics: + if metric not in train_metrics + val_metrics: msg = "We have added following default handlers for you: %s and used " \ "estimator.prepare_loss_and_metrics() to pass metrics to " \ "those handlers. Please use the same set of metrics " \ diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index d8c3c6eaa6aa..ce5890e0bcae 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -59,17 +59,57 @@ def batch_end(self, estimator, *args, **kwargs): return False +class StoppingHandler(TrainBegin, BatchEnd, EpochEnd): + """Stop conditions to stop training + Stop training if maximum number of batches or epochs + reached. + + Parameters + ---------- + max_epoch : int, default None + Number of maximum epochs to train. + max_batch : int, default None + Number of maximum batches to train. 
+ + """ + + def __init__(self, max_epoch=None, max_batch=None): + self.max_epoch = max_epoch + self.max_batch = max_batch + self.current_batch = 0 + self.current_epoch = 0 + self.stop_training = False + + def train_begin(self, estimator, *args, **kwargs): + self.max_epoch = estimator.max_epoch + self.max_batch = estimator.max_batch + self.current_batch = 0 + self.current_epoch = 0 + + def batch_end(self, estimator, *args, **kwargs): + self.current_batch += 1 + if self.current_batch == self.max_batch: + self.stop_training = True + return self.stop_training + + def epoch_end(self, estimator, *args, **kwargs): + self.current_epoch += 1 + if self.current_epoch == self.max_epoch: + self.stop_training = True + return self.stop_training + + class MetricHandler(EpochBegin, BatchEnd): """Metric Handler that update metric values at batch end :py:class:`MetricHandler` takes model predictions and true labels - and update the metrics, it also update metric wrapper for loss with loss values + and update the metrics, it also update metric wrapper for loss with loss values. Validation loss and metrics will be handled by :py:class:`ValidationHandler` Parameters ---------- train_metrics : List of EvalMetrics - training metrics to be updated at batch end + Training metrics to be updated at batch end. """ def __init__(self, train_metrics): @@ -94,7 +134,7 @@ def batch_end(self, estimator, *args, **kwargs): metric.update(label, pred) -class ValidationHandler(BatchEnd, EpochEnd): +class ValidationHandler(TrainBegin, BatchEnd, EpochEnd): """"Validation Handler that evaluate model on validation dataset :py:class:`ValidationHandler` takes validation dataset, an evaluation function, @@ -104,18 +144,18 @@ class ValidationHandler(BatchEnd, EpochEnd): Parameters ---------- val_data : DataLoader - validation data set to run evaluation + Validation data set to run evaluation. eval_fn : function - a function defines how to run evaluation and - calculate loss and metrics + A function defines how to run evaluation and + calculate loss and metrics. val_metrics : List of EvalMetrics - validation metrics to be updated + Validation metrics to be updated. epoch_period : int, default 1 - how often to run validation at epoch end, by default - validate every epoch + How often to run validation at epoch end, by default + :py:class:`ValidationHandler` validate every epoch. batch_period : int, default None - how often to run validation at batch end, by default - does not validate at batch end + How often to run validation at batch end, by default + :py:class:`ValidationHandler` does not validate at batch end. 
""" def __init__(self, @@ -129,25 +169,36 @@ def __init__(self, self.epoch_period = epoch_period self.batch_period = batch_period self.val_metrics = val_metrics - self.num_batches = 0 - self.num_epochs = 0 + self.current_batch = 0 + self.current_epoch = 0 # order to be called among all callbacks # validation metrics need to be calculated before other callbacks can access them self.priority = -np.Inf + self.logger = logging.getLogger(__name__) + + def train_begin(self, estimator, *args, **kwargs): + # reset epoch and batch counter + self.current_batch = 0 + self.current_epoch = 0 def batch_end(self, estimator, *args, **kwargs): - if self.batch_period and self.num_batches % self.batch_period == 0: + self.current_batch += 1 + if self.batch_period and self.current_batch % self.batch_period == 0: self.eval_fn(val_data=self.val_data, val_metrics=self.val_metrics) - self.num_batches += 1 + msg = '[Epoch %d] ValidationHandler: %d batches reached, ' \ + % (self.current_epoch, self.current_batch) + for monitor in self.val_metrics: + name, value = monitor.get() + msg += '%s: %.4f, ' % (name, value) + self.logger.info(msg.rstrip(',')) def epoch_end(self, estimator, *args, **kwargs): - if self.num_epochs % self.epoch_period == 0: + self.current_epoch += 1 + if self.epoch_period and self.current_epoch % self.epoch_period == 0: self.eval_fn(val_data=self.val_data, val_metrics=self.val_metrics) - self.num_epochs += 1 - class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd): """Basic Logging Handler that applies to every Gluon estimator by default. @@ -158,25 +209,28 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat Parameters ---------- file_name : str - file name to save the logs + File name to save the logs. file_location : str - file location to save the logs - verbose : int, default LOG_VERBOSITY_PER_EPOCH - Limit the granularity of metrics displayed during training process - verbose=LOG_VERBOSITY_PER_EPOCH: display metrics every epoch - verbose=LOG_VERBOSITY_PER_BATCH: display metrics every batch + File location to save the logs. + filemode : str, default 'a' + Logging file mode, default using append mode. + verbose : int, default LOG_PER_EPOCH + Limit the granularity of metrics displayed during training process. + verbose=LOG_PER_EPOCH: display metrics every epoch + verbose=LOG_PER_BATCH: display metrics every batch train_metrics : list of EvalMetrics - training metrics to be logged, logged at batch end, epoch end, train end + Training metrics to be logged, logged at batch end, epoch end, train end. val_metrics : list of EvalMetrics - validation metrics to be logged, logged at epoch end, train end + Validation metrics to be logged, logged at epoch end, train end. """ - LOG_VERBOSITY_PER_EPOCH = 1 - LOG_VERBOSITY_PER_BATCH = 2 + LOG_PER_EPOCH = 1 + LOG_PER_BATCH = 2 def __init__(self, file_name=None, file_location=None, - verbose=LOG_VERBOSITY_PER_EPOCH, + filemode='a', + verbose=LOG_PER_EPOCH, train_metrics=None, val_metrics=None): super(LoggingHandler, self).__init__() @@ -184,18 +238,18 @@ def __init__(self, file_name=None, self.logger.setLevel(logging.INFO) stream_handler = logging.StreamHandler() self.logger.addHandler(stream_handler) - if verbose not in [self.LOG_VERBOSITY_PER_EPOCH, self.LOG_VERBOSITY_PER_BATCH]: - raise ValueError("verbose level must be either LOG_VERBOSITY_PER_EPOCH or " - "LOG_VERBOSITY_PER_BATCH, received %s. 
" - "E.g: LoggingHandler(verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH)" - % verbose) - self.verbose = verbose # save logger to file only if file name or location is specified if file_name or file_location: file_name = file_name or 'estimator_log' file_location = file_location or './' - file_handler = logging.FileHandler(os.path.join(file_location, file_name)) + file_handler = logging.FileHandler(os.path.join(file_location, file_name), mode=filemode) self.logger.addHandler(file_handler) + if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]: + raise ValueError("verbose level must be either LOG_PER_EPOCH or " + "LOG_PER_BATCH, received %s. " + "E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)" + % verbose) + self.verbose = verbose self.train_metrics = train_metrics or [] self.val_metrics = val_metrics or [] self.batch_index = 0 @@ -213,158 +267,339 @@ def train_begin(self, estimator, *args, **kwargs): self.logger.info("Training begin: using optimizer %s " "with current learning rate %.4f ", optimizer, lr) - self.logger.info("Train for %d epochs.", estimator.max_epochs) + if estimator.max_epoch: + self.logger.info("Train for %d epochs.", estimator.max_epoch) + else: + self.logger.info("Train for %d batches.", estimator.max_batch) + # reset all counters + self.current_epoch = 0 + self.batch_index = 0 + self.processed_samples = 0 def train_end(self, estimator, *args, **kwargs): train_time = time.time() - self.train_start - msg = 'Train finished using total %ds with %d epochs.' % (train_time, self.current_epoch) + msg = 'Train finished using total %ds with %d epochs. ' % (train_time, self.current_epoch) # log every result in train stats including train/validation loss & metrics for metric in self.train_metrics + self.val_metrics: name, value = metric.get() - msg += '%s : %.4f ' % (name, value) - self.logger.info(msg) - for handler in self.logger.handlers: + msg += '%s: %.4f, ' % (name, value) + self.logger.info(msg.rstrip(', ')) + # make a copy of handler list and remove one by one + # as removing handler will edit the handler list + for handler in self.logger.handlers[:]: handler.close() self.logger.removeHandler(handler) logging.shutdown() def batch_begin(self, estimator, *args, **kwargs): - if self.verbose == self.LOG_VERBOSITY_PER_BATCH: + if self.verbose == self.LOG_PER_BATCH: self.batch_start = time.time() def batch_end(self, estimator, *args, **kwargs): - if self.verbose == self.LOG_VERBOSITY_PER_BATCH: + if self.verbose == self.LOG_PER_BATCH: batch_time = time.time() - self.batch_start - msg = '[Epoch %d] [Batch %d] ' % (self.current_epoch, self.batch_index) + msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index) self.processed_samples += kwargs['batch'][0].shape[0] msg += '[Samples %s] ' % (self.processed_samples) msg += 'time/batch: %.3fs ' % batch_time for metric in self.train_metrics: # only log current training loss & metric after each batch name, value = metric.get() - msg += '%s : %.4f ' % (name, value) - self.logger.info(msg) - self.batch_index += 1 + msg += '%s: %.4f, ' % (name, value) + self.logger.info(msg.rstrip(', ')) + self.batch_index += 1 def epoch_begin(self, estimator, *args, **kwargs): - if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH: + if self.verbose >= self.LOG_PER_EPOCH: self.epoch_start = time.time() + self.logger.info("[Epoch %d] Begin, current learning rate: %.4f", + self.current_epoch, estimator.trainer.learning_rate) def epoch_end(self, estimator, *args, **kwargs): - if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH: + if 
self.verbose >= self.LOG_PER_EPOCH: epoch_time = time.time() - self.epoch_start - msg = '\n[Epoch %d] finished in %.3fs: ' % (self.current_epoch, epoch_time) + msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time) for monitor in self.train_metrics + self.val_metrics: name, value = monitor.get() - msg += '%s : %.4f ' % (name, value) - self.logger.info(msg) - self.current_epoch += 1 - self.batch_index = 0 + msg += '%s: %.4f, ' % (name, value) + self.logger.info(msg.rstrip(', ')) + self.current_epoch += 1 + self.batch_index = 0 -class CheckpointHandler(BatchEnd, EpochEnd): - """Save the model after every epoch. +class CheckpointHandler(TrainBegin, BatchEnd, EpochEnd): + """Save the model after user define period - :py:class:`CheckpointHandler` save the network parameters every epoch + :py:class:`CheckpointHandler` saves the network architecture after first batch if the model + can be fully hybridized, saves model parameters and trainer states after user defined period, + default saves every epoch. Parameters ---------- - filepath : str - file name to save the parameters, it can contain directories, - for example: ./saved_model/resnet.params - monitor: EvalMetric - the metrics to monitor + model_dir : str + File directory to save all the model related files including model architecture, + model parameters, and trainer states. + model_prefix : str default 'model' + Prefix to add for all checkpoint file names. + monitor: EvalMetric, default None + The metrics to monitor and determine if model has improved verbose: int, default 0 - verbosity mode - save_best_only: bool - if True, only save the parameters if monitored value improved + Verbosity mode, 1 means inform user every time a checkpoint is saved + save_best: bool, default False + If True, monitor must not be None, :py:class:`CheckpointHandler` will save the + model parameters and trainer states with the best monitored value. mode: str, default 'auto' - one of {auto, min, max}, if `save_best_only=True`, the comparison to make - and determine if the monitored value has improved - period: int, default 1 - intervals between saving the network + One of {auto, min, max}, if `save_best=True`, the comparison to make + and determine if the monitored value has improved. if 'auto' mode, + :py:class:`CheckpointHandler` will try to use min or max based on + the monitored metric name. + epoch_period: int, default 1 + Epoch intervals between saving the network. By default, checkpoints are + saved every epoch. + batch_period: int, default None + Batch intervals between saving the network. + By default, checkpoints are not saved based on the number of batches. + max_checkpoints : int, default 5 + Maximum number of checkpoint files to keep in the model_dir, older checkpoints + will be removed. Best checkpoint file is not counted. + resume_from_checkpoint : bool, default False + Whether to resume training from checkpoint in model_dir. If True and checkpoints + found, :py:class:`CheckpointHandler` will load net parameters and trainer states, + and train the remaining of epochs and batches. 
""" def __init__(self, - filepath, + model_dir, + model_prefix='model', monitor=None, verbose=0, - save_best_only=False, + save_best=False, mode='auto', epoch_period=1, - batch_period=None): + batch_period=None, + max_checkpoints=5, + resume_from_checkpoint=False): self.monitor = monitor self.verbose = verbose - self.filepath = filepath - self.save_best_only = save_best_only - if self.save_best_only and not isinstance(self.monitor, EvalMetric): + if not os.path.exists(model_dir): + os.makedirs(model_dir) + self.model_dir = model_dir + self.model_prefix = model_prefix + self.save_best = save_best + if self.save_best and not isinstance(self.monitor, EvalMetric): raise ValueError("To save best model only, please provide one of the metric objects as monitor, " "You can get these objects using estimator.prepare_loss_and_metric()") self.epoch_period = epoch_period self.batch_period = batch_period - self.num_batches = 0 - self.num_epochs = 0 + self.current_batch = 0 + self.current_epoch = 0 + self.max_checkpoints = max_checkpoints + self.resume_from_checkpoint = resume_from_checkpoint + self.saved_checkpoints = [] self.logger = logging.getLogger(__name__) - - if mode not in ['auto', 'min', 'max']: - warnings.warn('ModelCheckpoint mode %s is unknown, ' - 'fallback to auto mode.' % (mode), - RuntimeWarning) - mode = 'auto' - - if mode == 'min': - self.monitor_op = np.less - self.best = np.Inf - elif mode == 'max': - self.monitor_op = np.greater - self.best = -np.Inf - else: - # use greater for accuracy and less otherwise - if 'acc' in self.monitor.get()[0].lower(): + if self.save_best: + if mode not in ['auto', 'min', 'max']: + warnings.warn('ModelCheckpoint mode %s is unknown, ' + 'fallback to auto mode. CheckpointHandler will use' + 'max mode for f1 and accuracy metric comparison and ' + 'use min mode other wise' % (mode), + RuntimeWarning) + mode = 'auto' + + if mode == 'min': + self.monitor_op = np.less + self.best = np.Inf + elif mode == 'max': self.monitor_op = np.greater self.best = -np.Inf else: - self.monitor_op = np.less - self.best = np.Inf + # use greater for accuracy and f1 and less otherwise + if 'acc' or 'f1' in self.monitor.get()[0].lower(): + self.logger.info("`greater` operator will be used to determine " + "if %s has improved, please use `min` for mode " + "if you want otherwise", self.monitor.get()[0]) + self.monitor_op = np.greater + else: + self.logger.info("`less` operator will be used to determine " + "if %s has improved, please use `max` for mode " + "if you want otherwise", self.monitor.get()[0]) + self.monitor_op = np.less + + def train_begin(self, estimator, *args, **kwargs): + # reset all counters + self.current_epoch = 0 + self.current_batch = 0 + if self.save_best: + self.best = np.Inf if self.monitor_op == np.less else -np.Inf + if self.resume_from_checkpoint: + error_msg = "To use resume from checkpoint, you must only specify " \ + "the same type of period you used for training." \ + "For example, if you are training based on number of epochs," \ + "you must save only based on epochs, and set batch_period to None." 
+ if estimator.max_batch: + assert self.batch_period, error_msg + assert not self.epoch_period, error_msg + if estimator.max_epoch: + assert self.epoch_period, error_msg + assert not self.batch_period, error_msg + + self._resume_from_checkpoint(estimator) def batch_end(self, estimator, *args, **kwargs): - self._save_checkpoint(estimator.net, "Batch", self.num_batches) - self.num_batches += 1 + # only save symbol once after first batch + if self.current_batch == 0: + self._save_symbol(estimator) + if self.batch_period and (self.current_batch + 1) % self.batch_period == 0: + self._save_checkpoint(estimator) + self.current_batch += 1 def epoch_end(self, estimator, *args, **kwargs): - self._save_checkpoint(estimator.net, "Epoch", self.num_epochs) - self.num_epochs += 1 - - def _save_checkpoint(self, net, period_name, period_value): - # add extension for weights - if '.params' not in self.filepath: - self.filepath += '.params' - if self.num_epochs % self.epoch_period == 0: - if self.save_best_only: - monitor_name, monitor_value = self.monitor.get() - # check if monitor exists in train stats - if np.isnan(monitor_value): - warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects' - 'as monitor, you can use estimator.prepare_loss_and_metrics to' - 'create all metric objects', monitor_name)) - net.save_parameters(self.filepath) + if self.epoch_period and (self.current_epoch + 1) % self.epoch_period == 0: + self._save_checkpoint(estimator) + self.current_epoch += 1 + + def _save_checkpoint(self, estimator): + # if resumed from checkpoint, increment checkpoint number + if self.resume_from_checkpoint: + save_epoch_number = self.current_epoch + self.trained_epoch + 1 + if estimator.max_epoch: + # checkpoint saved at epoch end, batch number already incremented + save_batch_number = self.current_batch + self.trained_batch + else: + save_batch_number = self.current_batch + self.trained_batch + 1 + else: + save_epoch_number = self.current_epoch + save_batch_number = self.current_batch + prefix = "%s-epoch%dbatch%d" % (self.model_prefix, save_epoch_number, save_batch_number) + self._save_params_and_trainer(estimator, prefix) + if self.verbose > 0: + self.logger.info('[Epoch %d] CheckpointHandler: trained total %d batches, ' + 'saving model at %s with prefix: %s', + self.current_epoch, self.current_batch + 1, self.model_dir, prefix) + + if self.save_best: + monitor_name, monitor_value = self.monitor.get() + # check if monitor exists in train stats + if np.isnan(monitor_value): + warnings.warn(RuntimeWarning('Skipping save best because %s is not updated, make sure you ' + 'pass one of the metric objects as monitor, ' + 'you can use estimator.prepare_loss_and_metrics to' + 'create all metric objects', monitor_name)) + else: + if self.monitor_op(monitor_value, self.best): + prefix = self.model_prefix + '-best' + self._save_params_and_trainer(estimator, prefix) + self.best = monitor_value + if self.verbose > 0: + self.logger.info('[Epoch %d] CheckpointHandler: ' + '%s improved from %0.5f to %0.5f, ' + 'updating best model at %s with prefix: %s', + self.current_epoch, monitor_name, + self.best, monitor_value, self.model_dir, prefix) else: - if self.monitor_op(monitor_value, self.best): - if self.verbose > 0: - self.logger.info('\n[%s %d] %s improved from %0.5f to %0.5f,' - ' saving model to %s', - period_name, period_value, monitor_name, - self.best, monitor_value, self.filepath) - self.best = monitor_value - net.save_parameters(self.filepath) - else: - if self.verbose > 0: - 
self.logger.info('\n[%s %d] %s did not improve from %0.5f, skipping save model', - period_name, period_value, monitor_name, self.best) + if self.verbose > 0: + self.logger.info('[Epoch %d] CheckpointHandler: ' + '%s did not improve from %0.5f, ' + 'skipping updating best model', + self.current_batch, monitor_name, + self.best) + + def _save_symbol(self, estimator): + symbol_file = os.path.join(self.model_dir, self.model_prefix + '-symbol.json') + if hasattr(estimator.net, '_cached_graph'): + sym = estimator.net._cached_graph[1] + sym.save(symbol_file) + else: + self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock" + "to construct your model, can call net.hybridize() before passing to" + "Estimator in order to save model architecture as %s.", symbol_file) + + def _save_params_and_trainer(self, estimator, file_prefix): + param_file = os.path.join(self.model_dir, file_prefix + '.params') + trainer_file = os.path.join(self.model_dir, file_prefix + '.states') + estimator.net.save_parameters(param_file) + estimator.trainer.save_states(trainer_file) + + # only count checkpoints with epoch or batch number in file name + if 'best' not in file_prefix: + self.saved_checkpoints.append(file_prefix) + # remove old checkpoint when max number of checkpoints reached + if len(self.saved_checkpoints) > self.max_checkpoints: + prefix = self.saved_checkpoints.pop(0) + for fname in os.listdir(self.model_dir): + if fname.startswith(prefix): + os.remove(os.path.join(self.model_dir, fname)) + + def _resume_from_checkpoint(self, estimator): + prefix = self.model_prefix + '-epoch' + self.trained_epoch = self._find_max_iteration( + dir=self.model_dir, + prefix=prefix, + start='epoch', + end='batch', + saved_checkpoints=self.saved_checkpoints) + prefix += str(self.trained_epoch) + self.trained_batch = self._find_max_iteration( + dir=self.model_dir, + prefix=prefix, + start='batch', + end='.params') + + if self.trained_epoch == -1: + msg = "CheckpointHandler: No checkpoint found, training from scratch for " + if estimator.max_batch: + msg += "%d batches" % estimator.max_batch else: - if self.verbose > 0: - logging.info('\n%s %d: saving model to %s', period_name, period_value, self.filepath) - net.save_parameters(self.filepath) + msg += "%d epochs" % estimator.max_epoch + self.logger.info(msg) + else: + msg = "CheckpointHandler: Checkpoint resumed from epoch %d batch %d, " \ + "continue to train for " % (self.trained_epoch, self.trained_batch) + # change maximum number of epoch or batch to train if resumed from epoch checkpoint + if estimator.max_epoch: + if self.trained_epoch >= estimator.max_epoch - 1: + raise ValueError("Found checkpoint with maximum number of epoch %d reached, please specify " + "resume_from_checkpoint=False (default value) if you wan to train from scratch." + % estimator.max_epoch) + estimator.max_epoch = estimator.max_epoch - self.trained_epoch - 1 + msg += "%d epochs " % estimator.max_epoch + if estimator.max_batch: + if self.trained_batch >= estimator.max_batch - 1: + raise ValueError("Found checkpoint with maximum number of batch %d reached, please specify" + "resume_from_checkpoint=False (default value) if you wan to train from scratch." 
+ % self.trained_batch) + estimator.max_batch = estimator.max_batch - self.trained_batch - 1 + msg += "%d batches " % estimator.max_batch + # load checkpoint + param_file = "%s-epoch%dbatch%d.params" % (self.model_prefix, self.trained_epoch, self.trained_batch) + param_file = os.path.join(self.model_dir, param_file) + trainer_file = "%s-epoch%dbatch%d.states" % (self.model_prefix, self.trained_epoch, self.trained_batch) + trainer_file = os.path.join(self.model_dir, trainer_file) + assert os.path.exists(param_file), "Failed to load checkpoint, %s does not exist" % param_file + assert os.path.exists(trainer_file), "Failed to load checkpoint, %s does not exist" % trainer_file + estimator.net.load_parameters(param_file, ctx=estimator.context) + estimator.trainer.load_states(trainer_file) + self.logger.warning(msg) + + def _find_max_iteration(self, dir, prefix, start, end, saved_checkpoints=None): + error_msg = "Error parsing checkpoint file, please check your " \ + "checkpoints have the format: " \ + "{model_name}-epoch{epoch_number}batch{batch_number}.params, " \ + "there should also be a .states file for each .params file " + max_iter = -1 + for fname in os.listdir(dir): + if fname.startswith(prefix) and '.params' in fname: + if saved_checkpoints: + # save prefix of existing checkpoints + saved_checkpoints.append(fname[:fname.find('.params')]) + try: + # find trained number of epoch + iter = int(fname[fname.find(start) + len(start): fname.find(end)]) + if iter > max_iter: + max_iter = iter + except ValueError: + raise ValueError(error_msg) + return max_iter class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd): @@ -372,19 +607,18 @@ class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd): Parameters ---------- - estimator : Estimator - The :py:class:`Estimator` to get training statistics monitor: EvalMetric - the metrics to monitor + The metric to monitor, and stop training if this metric does not improve. min_delta: float, default 0 - minimal change in monitored value to be considered as an improvement + Minimal change in monitored value to be considered as an improvement. patience: int, default 0 - number of epochs to wait for improvement before terminate training + Number of epochs to wait for improvement before terminate training. mode: str, default 'auto' - one of {auto, min, max}, the comparison to make - and determine if the monitored value has improved + One of {auto, min, max}, if `save_best_only=True`, the comparison to make + and determine if the monitored value has improved. if 'auto' mode, checkpoint + handler will try to use min or max based on the monitored metric name. baseline: float - baseline value to compare the monitored value with + Baseline value to compare the monitored value with. """ def __init__(self, @@ -404,13 +638,16 @@ def __init__(self, self.min_delta = min_delta self.wait = 0 self.stopped_epoch = 0 - self.num_epochs = 0 + self.current_epoch = 0 self.stop_training = False self.logger = logging.getLogger(__name__) if mode not in ['auto', 'min', 'max']: - warnings.warn(RuntimeWarning('EarlyStopping mode %s is unknown, ' - 'fallback to auto mode.', mode)) + warnings.warn('EarlyStopping mode %s is unknown, ' + 'fallback to auto mode. 
CheckpointHandler will use' + 'max mode for f1 and accuracy metric comparison and ' + 'use min mode other wise' % (mode), + RuntimeWarning) mode = 'auto' if mode == 'min': @@ -418,9 +655,15 @@ def __init__(self, elif mode == 'max': self.monitor_op = np.greater else: - if 'acc' in self.monitor.get()[0].lower(): + if 'acc' or 'f1' in self.monitor.get()[0].lower(): + self.logger.info("`greater` operator is used to determine " + "if %s has improved, please use `min` for mode " + "if you want otherwise", self.monitor.get()[0]) self.monitor_op = np.greater else: + self.logger.info("`less` operator is used to determine " + "if %s has improved, please use `max` for mode " + "if you want otherwise", self.monitor.get()[0]) self.monitor_op = np.less if self.monitor_op == np.greater: @@ -431,6 +674,8 @@ def __init__(self, def train_begin(self, estimator, *args, **kwargs): self.wait = 0 self.stopped_epoch = 0 + self.current_epoch = 0 + self.stop_training = False if self.baseline is not None: self.best = self.baseline else: @@ -449,11 +694,12 @@ def epoch_end(self, estimator, *args, **kwargs): else: self.wait += 1 if self.wait >= self.patience: - self.stopped_epoch = self.num_epochs + self.stopped_epoch = self.current_epoch self.stop_training = True + self.current_epoch += 1 return self.stop_training def train_end(self, estimator, *args, **kwargs): if self.stopped_epoch > 0: - self.logger.info('Epoch %d: early stopping due to %s not improving', + self.logger.info('[Epoch %d] EarlyStoppingHanlder: early stopping due to %s not improving', self.stopped_epoch, self.monitor.get()[0]) diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index b25baa255165..d2e8c082aa08 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -139,7 +139,13 @@ def test_initializer(): initializer=mx.init.MSRAPrelu(), trainer=trainer, context=ctx) - assert 'Network already initialized' in str(w[-1].message) + assert 'Network already fully initialized' in str(w[-1].message) + # net partially initialized, fine tuning use case + net = gluon.model_zoo.vision.resnet18_v1(pretrained=True, ctx=ctx) + net.output = gluon.nn.Dense(10) #last layer not initialized + est = Estimator(net, loss=loss, metrics=acc, context=ctx) + dataset = gluon.data.ArrayDataset(mx.nd.zeros((10, 3, 224, 224)), mx.nd.zeros((10, 10))) + train_data = gluon.data.DataLoader(dataset=dataset, batch_size=5) est.fit(train_data=train_data, epochs=num_epochs) @@ -335,6 +341,7 @@ def test_default_handlers(): assert 'You are training with the' in str(w[-1].message) # handler with prepared loss and metrics + # use mix of default and user defined handlers train_metrics, val_metrics = est.prepare_loss_and_metrics() logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics) with warnings.catch_warnings(record=True) as w: @@ -344,21 +351,21 @@ def test_default_handlers(): assert 'MetricHandler' in str(w[-1].message) # handler with all user defined metrics - val_metrics = [mx.metric.RMSE("val acc")] + # use mix of default and user defined handlers metric = MetricHandler(train_metrics=[train_acc]) - logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics) + logging = LoggingHandler(train_metrics=[train_acc], val_metrics=[mx.metric.RMSE("val acc")]) est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[metric, logging]) # handler with mixed metrics, some handler use metrics prepared by estimator # some handler use 
metrics user prepared - val_metrics = [mx.metric.RMSE("val acc")] - logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics) + logging = LoggingHandler(train_metrics=train_metrics, val_metrics=[mx.metric.RMSE("val acc")]) with assert_raises(ValueError): est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) # test handler order + train_metrics, val_metrics = est.prepare_loss_and_metrics() early_stopping = EarlyStoppingHandler(monitor=val_metrics[0]) handlers = est._prepare_default_handlers(val_data=None, event_handlers=[early_stopping]) - assert len(handlers) == 3 + assert len(handlers) == 4 assert isinstance(handlers[0], MetricHandler) - assert isinstance(handlers[2], LoggingHandler) + assert isinstance(handlers[3], LoggingHandler) diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index cdb4264e18a0..7ea5ff3f4b62 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -16,7 +16,6 @@ # under the License. import os -import tempfile import mxnet as mx from common import TemporaryDirectory @@ -25,64 +24,113 @@ from mxnet.gluon.contrib.estimator import estimator, event_handler -def _get_test_network(): - net = nn.Sequential() - net.add(nn.Dense(128, activation='relu', in_units=100, flatten=False), - nn.Dense(64, activation='relu', in_units=128), - nn.Dense(10, activation='relu', in_units=64)) +def _get_test_network(net=nn.Sequential()): + net.add(nn.Dense(128, activation='relu', flatten=False), + nn.Dense(64, activation='relu'), + nn.Dense(10, activation='relu')) return net def _get_test_data(): data = nd.ones((32, 100)) - label = nd.random.randint(0, 10, (32, 1)) + label = nd.zeros((32, 1)) data_arr = mx.gluon.data.dataset.ArrayDataset(data, label) - return mx.gluon.data.DataLoader(data_arr, batch_size=32) + return mx.gluon.data.DataLoader(data_arr, batch_size=8) def test_checkpoint_handler(): - tmpdir = tempfile.mkdtemp() - file_path = os.path.join(tmpdir, "model.params") - test_data = _get_test_data() + with TemporaryDirectory() as tmpdir: + model_prefix = 'test_epoch' + file_path = os.path.join(tmpdir, model_prefix) + test_data = _get_test_data() - save_best_only = False - mode = 'auto' + net = _get_test_network() + ce_loss = loss.SoftmaxCrossEntropyLoss() + acc = mx.metric.Accuracy() + est = estimator.Estimator(net, loss=ce_loss, metrics=acc) + checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir, + model_prefix=model_prefix, + monitor=acc, + save_best=True, + epoch_period=1) + est.fit(test_data, event_handlers=[checkpoint_handler], epochs=1) + assert checkpoint_handler.current_epoch == 1 + assert checkpoint_handler.current_batch == 4 + assert os.path.isfile(file_path + '-best.params') + assert os.path.isfile(file_path + '-best.states') + assert os.path.isfile(file_path + '-epoch0batch4.params') + assert os.path.isfile(file_path + '-epoch0batch4.states') + + model_prefix = 'test_batch' + file_path = os.path.join(tmpdir, model_prefix) + net = _get_test_network(nn.HybridSequential()) + net.hybridize() + est = estimator.Estimator(net, loss=ce_loss, metrics=acc) + checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir, + model_prefix=model_prefix, + epoch_period=None, + batch_period=2, + max_checkpoints=2) + est.fit(test_data, event_handlers=[checkpoint_handler], batches=10) + assert checkpoint_handler.current_batch == 10 + assert checkpoint_handler.current_epoch == 3 + assert not 
os.path.isfile(file_path + 'best.params') + assert not os.path.isfile(file_path + 'best.states') + assert not os.path.isfile(file_path + '-epoch0batch0.params') + assert not os.path.isfile(file_path + '-epoch0batch0.states') + assert os.path.isfile(file_path + '-symbol.json') + assert os.path.isfile(file_path + '-epoch1batch7.params') + assert os.path.isfile(file_path + '-epoch1batch7.states') + assert os.path.isfile(file_path + '-epoch2batch9.params') + assert os.path.isfile(file_path + '-epoch2batch9.states') + +def test_resume_checkpoint(): + with TemporaryDirectory() as tmpdir: + model_prefix = 'test_net' + file_path = os.path.join(tmpdir, model_prefix) + test_data = _get_test_data() - net = _get_test_network() - ce_loss = loss.SoftmaxCrossEntropyLoss() - ce_loss_metric = mx.metric.Loss(ce_loss.name) - acc = mx.metric.Accuracy() - est = estimator.Estimator(net, loss=ce_loss, metrics=acc) - checkpoint_handler = [event_handler.CheckpointHandler(file_path, - monitor=acc, - save_best_only=save_best_only, - mode=mode)] - est.fit(test_data, event_handlers=checkpoint_handler, epochs=1) - assert os.path.isfile(file_path) - os.remove(file_path) + net = _get_test_network() + ce_loss = loss.SoftmaxCrossEntropyLoss() + acc = mx.metric.Accuracy() + est = estimator.Estimator(net, loss=ce_loss, metrics=acc) + checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir, + model_prefix=model_prefix, + monitor=acc, + max_checkpoints=1) + est.fit(test_data, event_handlers=[checkpoint_handler], epochs=2) + assert os.path.isfile(file_path + '-epoch1batch8.params') + assert os.path.isfile(file_path + '-epoch1batch8.states') + checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir, + model_prefix=model_prefix, + monitor=acc, + max_checkpoints=1, + resume_from_checkpoint=True) + est.fit(test_data, event_handlers=[checkpoint_handler], epochs=5) + # should only continue to train 3 epochs and last checkpoint file is epoch4 + assert est.max_epoch == 3 + assert os.path.isfile(file_path + '-epoch4batch20.states') def test_early_stopping(): test_data = _get_test_data() - mode = 'max' - patience = 0 - net = _get_test_network() ce_loss = loss.SoftmaxCrossEntropyLoss() acc = mx.metric.Accuracy() est = estimator.Estimator(net, loss=ce_loss, metrics=acc) - early_stopping = [event_handler.EarlyStoppingHandler(monitor=acc, - patience=patience, - mode=mode)] - est.fit(test_data, event_handlers=early_stopping, epochs=3) + early_stopping = event_handler.EarlyStoppingHandler(monitor=acc, + patience=0, + mode='min') + est.fit(test_data, event_handlers=[early_stopping], epochs=5) + assert early_stopping.current_epoch == 2 + assert early_stopping.stopped_epoch == 1 - mode = 'auto' - patience = 2 - early_stopping = [event_handler.EarlyStoppingHandler(monitor=acc, - patience=patience, - mode=mode)] - est.fit(test_data, event_handlers=early_stopping, epochs=1) + early_stopping = event_handler.EarlyStoppingHandler(monitor=acc, + patience=2, + mode='auto') + est.fit(test_data, event_handlers=[early_stopping], epochs=1) + assert early_stopping.current_epoch == 1 def test_logging(): @@ -96,9 +144,55 @@ def test_logging(): acc = mx.metric.Accuracy() est = estimator.Estimator(net, loss=ce_loss, metrics=acc) train_metrics, val_metrics = est.prepare_loss_and_metrics() - logging_handler = [event_handler.LoggingHandler(file_name=file_name, - file_location=tmpdir, - train_metrics=train_metrics, - val_metrics=val_metrics)] - est.fit(test_data, event_handlers=logging_handler, epochs=1) + logging_handler = 
event_handler.LoggingHandler(file_name=file_name, + file_location=tmpdir, + train_metrics=train_metrics, + val_metrics=val_metrics) + est.fit(test_data, event_handlers=[logging_handler], epochs=3) + assert logging_handler.batch_index == 0 + assert logging_handler.current_epoch == 3 assert os.path.isfile(output_dir) + + +def test_custom_handler(): + class CustomStopHandler(event_handler.TrainBegin, + event_handler.BatchEnd, + event_handler.EpochEnd): + def __init__(self, batch_stop=None, epoch_stop=None): + self.batch_stop = batch_stop + self.epoch_stop = epoch_stop + self.num_batch = 0 + self.num_epoch = 0 + self.stop_training = False + + def train_begin(self, estimator, *args, **kwargs): + self.num_batch = 0 + self.num_epoch = 0 + + def batch_end(self, estimator, *args, **kwargs): + self.num_batch += 1 + if self.num_batch == self.batch_stop: + self.stop_training = True + return self.stop_training + + def epoch_end(self, estimator, *args, **kwargs): + self.num_epoch += 1 + if self.num_epoch == self.epoch_stop: + self.stop_training = True + return self.stop_training + + # total data size is 32, batch size is 8 + # 4 batch per epoch + test_data = _get_test_data() + net = _get_test_network() + ce_loss = loss.SoftmaxCrossEntropyLoss() + acc = mx.metric.Accuracy() + est = estimator.Estimator(net, loss=ce_loss, metrics=acc) + custom_handler = CustomStopHandler(3, 2) + est.fit(test_data, event_handlers=[custom_handler], epochs=3) + assert custom_handler.num_batch == 3 + assert custom_handler.num_epoch == 1 + custom_handler = CustomStopHandler(100, 5) + est.fit(test_data, event_handlers=[custom_handler], epochs=10) + assert custom_handler.num_batch == 5 * 4 + assert custom_handler.num_epoch == 5
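

Usage sketch (editor's addition, not part of the patch): the snippet below shows how the API introduced in this diff is meant to be driven end to end, covering fit() with either epochs or batches, the implicit StoppingHandler, and the reworked CheckpointHandler (save_best, max_checkpoints, resume_from_checkpoint), EarlyStoppingHandler and LoggingHandler. The toy network, the synthetic data, the './checkpoints' directory and the 'toy' prefix are placeholder assumptions chosen to mirror the unit tests above; only the Estimator and event-handler calls come from the diff itself.

    # Editor's sketch, not part of the patch. Placeholder model/data mirror the tests:
    # 32 samples with batch size 8 gives 4 batches per epoch.
    import mxnet as mx
    from mxnet.gluon import nn, loss
    from mxnet.gluon.contrib.estimator import estimator, event_handler

    # toy model and synthetic data
    net = nn.Sequential()
    net.add(nn.Dense(128, activation='relu'),
            nn.Dense(64, activation='relu'),
            nn.Dense(10))
    data = mx.nd.random.uniform(shape=(32, 100))
    label = mx.nd.zeros((32, 1))
    dataset = mx.gluon.data.ArrayDataset(data, label)
    train_data = mx.gluon.data.DataLoader(dataset, batch_size=8)
    val_data = mx.gluon.data.DataLoader(dataset, batch_size=8)

    ce_loss = loss.SoftmaxCrossEntropyLoss()
    acc = mx.metric.Accuracy()
    est = estimator.Estimator(net=net, loss=ce_loss, metrics=acc)

    # default handlers reuse these exact metric objects, so custom handlers should too
    train_metrics, val_metrics = est.prepare_loss_and_metrics()

    checkpoint = event_handler.CheckpointHandler(model_dir='./checkpoints',
                                                 model_prefix='toy',
                                                 monitor=val_metrics[0],
                                                 save_best=True,
                                                 epoch_period=1,
                                                 max_checkpoints=3)
    early_stop = event_handler.EarlyStoppingHandler(monitor=val_metrics[0],
                                                    patience=2,
                                                    mode='max')
    log = event_handler.LoggingHandler(verbose=event_handler.LoggingHandler.LOG_PER_BATCH,
                                       train_metrics=train_metrics,
                                       val_metrics=val_metrics)

    # epoch-based training: pass `epochs` and leave `batches` as None
    est.fit(train_data=train_data, val_data=val_data,
            event_handlers=[checkpoint, early_stop, log], epochs=5)

    # batch-based training: pass `batches` instead; the StoppingHandler that fit()
    # installs by default ends training after exactly that many batches
    est.fit(train_data=train_data, val_data=val_data, batches=20)

    # resuming: a CheckpointHandler with resume_from_checkpoint=True loads the newest
    # '<prefix>-epoch{E}batch{B}' params/states files from model_dir, then trains only
    # the remaining epochs out of the requested total
    resume = event_handler.CheckpointHandler(model_dir='./checkpoints',
                                             model_prefix='toy',
                                             resume_from_checkpoint=True)
    est.fit(train_data=train_data, val_data=val_data,
            event_handlers=[resume], epochs=8)

Note on the design this sketch assumes: fit() requires exactly one of epochs or batches, any custom handler should hold references to the metric objects returned by prepare_loss_and_metrics() so the estimator's consistency check passes, and checkpoint files follow the '{prefix}-epoch{E}batch{B}.params/.states' naming that resume_from_checkpoint parses.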