From d80464826db1bd308da3c3b440da1c93a166e179 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Sun, 5 May 2019 17:46:01 -0700 Subject: [PATCH 1/9] address comments --- .../gluon/contrib/estimator/estimator.py | 84 ++++++++++--------- .../gluon/contrib/estimator/event_handler.py | 70 +++++++++++----- tests/python/unittest/test_gluon_estimator.py | 8 +- .../unittest/test_gluon_event_handler.py | 13 ++- 4 files changed, 115 insertions(+), 60 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index f43f17520654..d7671cc8e4ce 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -21,7 +21,6 @@ import copy import warnings -import weakref from .event_handler import MetricHandler, ValidationHandler, LoggingHandler from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd @@ -40,6 +39,8 @@ class Estimator(object): Parameters ---------- + net : Block + The model used for training loss : gluon.loss.Loss or list of gluon.loss.Loss Loss(objective functions) to calculate during training metrics : EvalMetric or list of EvalMetric @@ -70,7 +71,7 @@ def __init__(self, net, def _check_loss(self, loss): if isinstance(loss, gluon.loss.Loss): loss = [loss] - elif isinstance(loss, list) or all([isinstance(l, gluon.loss.Loss) for l in loss]): + elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]): loss = loss else: raise ValueError("loss must be a Loss or a list of Loss, " @@ -122,19 +123,23 @@ def _check_context(self, context): def _initialize(self, initializer): # initialize the network - if initializer: - if self._is_initialized(): - # if already initialized, re-init with user specified initializer - warnings.warn("Network already initialized, re-initializing with %s. " - "You don't need to pass initializer if you already " - "initialized your net." % type(initializer).__name__) - self.net.initialize(init=initializer, ctx=self.context, force_reinit=True) + if not self._is_initialized(): + # net is partially or not initialized, + # initialize with user specified initializer + # if initializer is None, default initializer will be used + # do not re-init layers already initialized + if initializer: + self.net.initialize(init=initializer, ctx=self.context) else: - # initialize with user specified initializer - self.net.initialize(init=initializer, ctx=self.context, force_reinit=False) - else: - if not self._is_initialized(): self.net.initialize(ctx=self.context) + elif initializer: + # net is fully initialized, and user passed not None initializer + # do not force reinitialize, give warning + warnings.warn("Network already fully initialized, skipping initialization. " + "You don't need to pass initializer if you already " + "initialized your net. " + "You can use net.initialize(init=your_initializer, force_reinit=True)" + "to force re-initialize.") def _check_trainer(self, trainer): # handle trainer @@ -157,11 +162,11 @@ def _is_initialized(self): return False return True - def _get_data_and_label(self, batch, ctx): + def _get_data_and_label(self, batch, ctx, batch_axis=0): data = batch[0] label = batch[1] - data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=0) - label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0) + data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=batch_axis) + label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=batch_axis) return data, label def prepare_loss_and_metrics(self): @@ -190,26 +195,29 @@ def prepare_loss_and_metrics(self): def evaluate(self, val_data, - val_metrics): + val_metrics, + batch_axis=0): """Evaluate model on validation data Parameters ---------- val_data : DataLoader - validation data with data and labels + validation data loader with data and labels + batch_axis : int, default 0 + batch axis to split the validation data into devices val_metrics : EvalMetric or list of EvalMetrics metrics to update validation result """ + if not isinstance(val_data, gluon.data.DataLoader): + raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " + "can transform your DataIter or any NDArray into Gluon DataLoader. " + "Refer to gluon.data.dataloader") for metric in val_metrics: metric.reset() for _, batch in enumerate(val_data): - if not isinstance(val_data, gluon.data.DataLoader): - raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " - "can transform your DataIter or any NDArray into Gluon DataLoader. " - "Refer to gluon.data.dataloader") - data, label = self._get_data_and_label(batch, self.context) + data, label = self._get_data_and_label(batch, self.context, batch_axis) pred = [self.net(x) for x in data] loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)] # update metrics @@ -222,7 +230,8 @@ def evaluate(self, def fit(self, train_data, val_data=None, epochs=1, - event_handlers=None): + event_handlers=None, + batch_axis=0): """Trains the model on a given dataset for a specified number of epochs. Also, the batch size is inferred from the DataLoader's batch_size. @@ -230,20 +239,21 @@ def fit(self, train_data, Parameters ---------- train_data : DataLoader - training data with data and labels + training data loader with data and labels val_data : DataLoader - validation data with data and labels + validation data loader with data and labels epochs : int, default 1 number of epochs to iterate on the training data. - batch_size : int - number of samples per gradient update. - default will be 32 per device event_handlers : EventHandler or list of EventHandler list of EventHandlers to apply during training - batch_fn : function - custom batch function to extract data and label - from a data batch and load into contexts(devices) + batch_axis : int, default 0 + batch axis to split the validation data into devices """ + if not isinstance(train_data, gluon.data.DataLoader): + raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " + "can transform your DataIter or any NDArray into Gluon DataLoader. " + "Refer to gluon.data.dataloader") + self.max_epochs = epochs # provide default handlers @@ -252,8 +262,8 @@ def fit(self, train_data, train_begin, epoch_begin, batch_begin, \ batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers) - # only pass a weak reference to all event handlers - estimator_ref = weakref.proxy(self) + # pass a reference to all event handlers + estimator_ref = self # training begin for handler in train_begin: handler.train_begin(estimator_ref) @@ -264,11 +274,7 @@ def fit(self, train_data, handler.epoch_begin(estimator_ref) for i, batch in enumerate(train_data): - if not isinstance(train_data, gluon.data.DataLoader): - raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " - "can transform your DataIter or any NDArray into Gluon DataLoader. " - "Refer to gluon.data.dataloader") - data, label = self._get_data_and_label(batch, self.context) + data, label = self._get_data_and_label(batch, self.context, batch_axis) batch_size = batch[0].shape[0] diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index d8c3c6eaa6aa..3960fe72e6dd 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -161,6 +161,8 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat file name to save the logs file_location : str file location to save the logs + filemode : str, default 'a' + logging file mode, default using append mode verbose : int, default LOG_VERBOSITY_PER_EPOCH Limit the granularity of metrics displayed during training process verbose=LOG_VERBOSITY_PER_EPOCH: display metrics every epoch @@ -176,6 +178,7 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat def __init__(self, file_name=None, file_location=None, + filemode='a', verbose=LOG_VERBOSITY_PER_EPOCH, train_metrics=None, val_metrics=None): @@ -194,7 +197,7 @@ def __init__(self, file_name=None, if file_name or file_location: file_name = file_name or 'estimator_log' file_location = file_location or './' - file_handler = logging.FileHandler(os.path.join(file_location, file_name)) + file_handler = logging.FileHandler(os.path.join(file_location, file_name), mode=filemode) self.logger.addHandler(file_handler) self.train_metrics = train_metrics or [] self.val_metrics = val_metrics or [] @@ -249,6 +252,8 @@ def batch_end(self, estimator, *args, **kwargs): def epoch_begin(self, estimator, *args, **kwargs): if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH: self.epoch_start = time.time() + self.logger.info("[Epoch %d] begin, current learning rate: %.4f", + self.current_epoch, estimator.trainer.learning_rate) def epoch_end(self, estimator, *args, **kwargs): if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH: @@ -280,9 +285,13 @@ class CheckpointHandler(BatchEnd, EpochEnd): if True, only save the parameters if monitored value improved mode: str, default 'auto' one of {auto, min, max}, if `save_best_only=True`, the comparison to make - and determine if the monitored value has improved - period: int, default 1 - intervals between saving the network + and determine if the monitored value has improved. if 'auto' mode, checkpoint + handler will try to use min or max based on the monitored metric name + epoch period: int, default 1 + epoch intervals between saving the network + batch period: int, default None + batch intervals between saving the network, + by default don't save any checkpoint based on number of batches """ def __init__(self, @@ -308,7 +317,9 @@ def __init__(self, if mode not in ['auto', 'min', 'max']: warnings.warn('ModelCheckpoint mode %s is unknown, ' - 'fallback to auto mode.' % (mode), + 'fallback to auto mode. CheckpointHandler will use' + 'max mode for f1 and accuracy metric comparison and ' + 'use min mode other wise' % (mode), RuntimeWarning) mode = 'auto' @@ -319,27 +330,38 @@ def __init__(self, self.monitor_op = np.greater self.best = -np.Inf else: - # use greater for accuracy and less otherwise - if 'acc' in self.monitor.get()[0].lower(): + # use greater for accuracy and f1 and less otherwise + if 'acc' or 'f1' in self.monitor.get()[0].lower(): + self.logger.info("`greater` operator will be used to determine " + "if %s has improved, please use `min` for mode " + "if you want otherwise", self.monitor.get()[0]) self.monitor_op = np.greater self.best = -np.Inf else: + self.logger.info("`less` operator will be used to determine " + "if %s has improved, please use `max` for mode " + "if you want otherwise", self.monitor.get()[0]) self.monitor_op = np.less self.best = np.Inf def batch_end(self, estimator, *args, **kwargs): - self._save_checkpoint(estimator.net, "Batch", self.num_batches) - self.num_batches += 1 + if self.batch_period: + self._save_checkpoint(estimator.net, "Batch", self.batch_period, self.num_batches) + self.num_batches += 1 def epoch_end(self, estimator, *args, **kwargs): - self._save_checkpoint(estimator.net, "Epoch", self.num_epochs) - self.num_epochs += 1 - - def _save_checkpoint(self, net, period_name, period_value): + if self.epoch_period: + self._save_checkpoint(estimator.net, "Epoch", self.epoch_period, self.num_epochs) + self.num_epochs += 1 + + def _save_checkpoint(self, net, period_name, period_value, num_of_periods): + # period name can be batch or epoch + # period value determine how often a checkpoint is saved + # num_of_periods records the number of batch or epoch # add extension for weights if '.params' not in self.filepath: self.filepath += '.params' - if self.num_epochs % self.epoch_period == 0: + if num_of_periods % period_value == 0: if self.save_best_only: monitor_name, monitor_value = self.monitor.get() # check if monitor exists in train stats @@ -381,8 +403,9 @@ class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd): patience: int, default 0 number of epochs to wait for improvement before terminate training mode: str, default 'auto' - one of {auto, min, max}, the comparison to make - and determine if the monitored value has improved + one of {auto, min, max}, if `save_best_only=True`, the comparison to make + and determine if the monitored value has improved. if 'auto' mode, checkpoint + handler will try to use min or max based on the monitored metric name baseline: float baseline value to compare the monitored value with """ @@ -409,8 +432,11 @@ def __init__(self, self.logger = logging.getLogger(__name__) if mode not in ['auto', 'min', 'max']: - warnings.warn(RuntimeWarning('EarlyStopping mode %s is unknown, ' - 'fallback to auto mode.', mode)) + warnings.warn('EarlyStopping mode %s is unknown, ' + 'fallback to auto mode. CheckpointHandler will use' + 'max mode for f1 and accuracy metric comparison and ' + 'use min mode other wise' % (mode), + RuntimeWarning) mode = 'auto' if mode == 'min': @@ -418,9 +444,15 @@ def __init__(self, elif mode == 'max': self.monitor_op = np.greater else: - if 'acc' in self.monitor.get()[0].lower(): + if 'acc' or 'f1' in self.monitor.get()[0].lower(): + self.logger.info("`greater` operator is used to determine " + "if %s has improved, please use `min` for mode " + "if you want otherwise", self.monitor.get()[0]) self.monitor_op = np.greater else: + self.logger.info("`less` operator is used to determine " + "if %s has improved, please use `max` for mode " + "if you want otherwise", self.monitor.get()[0]) self.monitor_op = np.less if self.monitor_op == np.greater: diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index b25baa255165..315c5140fc39 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -139,7 +139,13 @@ def test_initializer(): initializer=mx.init.MSRAPrelu(), trainer=trainer, context=ctx) - assert 'Network already initialized' in str(w[-1].message) + assert 'Network already fully initialized' in str(w[-1].message) + # net partially initialized, fine tuning use case + net = gluon.model_zoo.vision.resnet18_v1(pretrained=True, ctx=ctx) + net.output = gluon.nn.Dense(10) #last layer not initialized + est = Estimator(net, loss=loss, metrics=acc, context=ctx) + dataset = gluon.data.ArrayDataset(mx.nd.zeros((10, 3, 224, 224)), mx.nd.zeros((10, 10))) + train_data = gluon.data.DataLoader(dataset=dataset, batch_size=5) est.fit(train_data=train_data, epochs=num_epochs) diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index cdb4264e18a0..163ca7d3e92e 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -45,7 +45,7 @@ def test_checkpoint_handler(): file_path = os.path.join(tmpdir, "model.params") test_data = _get_test_data() - save_best_only = False + save_best_only = True mode = 'auto' net = _get_test_network() @@ -61,6 +61,17 @@ def test_checkpoint_handler(): assert os.path.isfile(file_path) os.remove(file_path) + checkpoint_handler = [event_handler.CheckpointHandler(file_path, + monitor=acc, + save_best_only=save_best_only, + mode=mode, + epoch_period=None, + batch_period=1)] + est.fit(test_data, event_handlers=checkpoint_handler, epochs=2) + assert os.path.isfile(file_path) + os.remove(file_path) + + def test_early_stopping(): test_data = _get_test_data() From 8e557dd3ddd57bb89c8994f76a92e2c4fd041110 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Tue, 7 May 2019 15:29:13 -0700 Subject: [PATCH 2/9] update checkpoint --- .../gluon/contrib/estimator/event_handler.py | 176 +++++++++++------- .../unittest/test_gluon_event_handler.py | 62 +++--- 2 files changed, 145 insertions(+), 93 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 3960fe72e6dd..c09c4afefeeb 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -258,7 +258,7 @@ def epoch_begin(self, estimator, *args, **kwargs): def epoch_end(self, estimator, *args, **kwargs): if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH: epoch_time = time.time() - self.epoch_start - msg = '\n[Epoch %d] finished in %.3fs: ' % (self.current_epoch, epoch_time) + msg = '[Epoch %d] finished in %.3fs: ' % (self.current_epoch, epoch_time) for monitor in self.train_metrics + self.val_metrics: name, value = monitor.get() msg += '%s : %.4f ' % (name, value) @@ -268,125 +268,171 @@ def epoch_end(self, estimator, *args, **kwargs): class CheckpointHandler(BatchEnd, EpochEnd): - """Save the model after every epoch. + """Save the model after user define period - :py:class:`CheckpointHandler` save the network parameters every epoch + :py:class:`CheckpointHandler` saves the network architecture after first batch if the model + can be hybridized(), saves model parameters and trainer states after user defined period, + default saves every epoch. Parameters ---------- - filepath : str - file name to save the parameters, it can contain directories, - for example: ./saved_model/resnet.params + model_dir : str + file directory to save all the model related files including model architecture, + model parameters, and trainer states. + model_prefix : str default 'model' + prefix to add for all checkpoint file names monitor: EvalMetric - the metrics to monitor + the metrics to monitor and determine if model has improved verbose: int, default 0 verbosity mode - save_best_only: bool - if True, only save the parameters if monitored value improved + save_best: bool + if True, save the model parameters and trainer states with the best monitored value mode: str, default 'auto' one of {auto, min, max}, if `save_best_only=True`, the comparison to make and determine if the monitored value has improved. if 'auto' mode, checkpoint handler will try to use min or max based on the monitored metric name - epoch period: int, default 1 + epoch_period: int, default 1 epoch intervals between saving the network - batch period: int, default None + batch_period: int, default None batch intervals between saving the network, by default don't save any checkpoint based on number of batches + max_checkpoints : int, default 5 + maximum number of checkpoint files to keep in the model_dir, older checkpoints + will be removed """ def __init__(self, - filepath, + model_dir, + model_prefix='model', monitor=None, verbose=0, - save_best_only=False, + save_best=False, mode='auto', epoch_period=1, - batch_period=None): + batch_period=None, + max_checkpoints=5): self.monitor = monitor self.verbose = verbose - self.filepath = filepath - self.save_best_only = save_best_only - if self.save_best_only and not isinstance(self.monitor, EvalMetric): + if not os.path.exists(model_dir): + os.makedirs(model_dir) + self.model_dir = model_dir + self.model_prefix = model_prefix + self.save_best = save_best + if self.save_best and not isinstance(self.monitor, EvalMetric): raise ValueError("To save best model only, please provide one of the metric objects as monitor, " "You can get these objects using estimator.prepare_loss_and_metric()") self.epoch_period = epoch_period self.batch_period = batch_period self.num_batches = 0 self.num_epochs = 0 + self.max_checkpoints = max_checkpoints + self.saved_checkpoints = [] self.logger = logging.getLogger(__name__) - - if mode not in ['auto', 'min', 'max']: - warnings.warn('ModelCheckpoint mode %s is unknown, ' - 'fallback to auto mode. CheckpointHandler will use' - 'max mode for f1 and accuracy metric comparison and ' - 'use min mode other wise' % (mode), - RuntimeWarning) - mode = 'auto' - - if mode == 'min': - self.monitor_op = np.less - self.best = np.Inf - elif mode == 'max': - self.monitor_op = np.greater - self.best = -np.Inf - else: - # use greater for accuracy and f1 and less otherwise - if 'acc' or 'f1' in self.monitor.get()[0].lower(): - self.logger.info("`greater` operator will be used to determine " - "if %s has improved, please use `min` for mode " - "if you want otherwise", self.monitor.get()[0]) + if self.save_best: + if mode not in ['auto', 'min', 'max']: + warnings.warn('ModelCheckpoint mode %s is unknown, ' + 'fallback to auto mode. CheckpointHandler will use' + 'max mode for f1 and accuracy metric comparison and ' + 'use min mode other wise' % (mode), + RuntimeWarning) + mode = 'auto' + + if mode == 'min': + self.monitor_op = np.less + self.best = np.Inf + elif mode == 'max': self.monitor_op = np.greater self.best = -np.Inf else: - self.logger.info("`less` operator will be used to determine " - "if %s has improved, please use `max` for mode " - "if you want otherwise", self.monitor.get()[0]) - self.monitor_op = np.less - self.best = np.Inf + # use greater for accuracy and f1 and less otherwise + if 'acc' or 'f1' in self.monitor.get()[0].lower(): + self.logger.info("`greater` operator will be used to determine " + "if %s has improved, please use `min` for mode " + "if you want otherwise", self.monitor.get()[0]) + self.monitor_op = np.greater + self.best = -np.Inf + else: + self.logger.info("`less` operator will be used to determine " + "if %s has improved, please use `max` for mode " + "if you want otherwise", self.monitor.get()[0]) + self.monitor_op = np.less + self.best = np.Inf def batch_end(self, estimator, *args, **kwargs): + # only save symbol once after first batch + if self.num_batches == 0: + self._save_symbol(estimator) if self.batch_period: - self._save_checkpoint(estimator.net, "Batch", self.batch_period, self.num_batches) - self.num_batches += 1 + self._save_checkpoint(estimator, "Batch", self.batch_period, self.num_batches) + self.num_batches += 1 def epoch_end(self, estimator, *args, **kwargs): if self.epoch_period: - self._save_checkpoint(estimator.net, "Epoch", self.epoch_period, self.num_epochs) - self.num_epochs += 1 + self._save_checkpoint(estimator, "Epoch", self.epoch_period, self.num_epochs) + self.num_epochs += 1 - def _save_checkpoint(self, net, period_name, period_value, num_of_periods): + def _save_checkpoint(self, estimator, period_name, period_value, num_of_periods): # period name can be batch or epoch # period value determine how often a checkpoint is saved # num_of_periods records the number of batch or epoch # add extension for weights - if '.params' not in self.filepath: - self.filepath += '.params' if num_of_periods % period_value == 0: - if self.save_best_only: + if self.verbose > 0: + self.logger.info('[%s %d] saving model to %s', period_name, num_of_periods, self.model_dir) + prefix = "%s-%s%d" % (self.model_prefix, period_name.lower(), num_of_periods) + self._save_params_and_trainer(estimator, prefix) + + if self.save_best: monitor_name, monitor_value = self.monitor.get() # check if monitor exists in train stats if np.isnan(monitor_value): - warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects' - 'as monitor, you can use estimator.prepare_loss_and_metrics to' + warnings.warn(RuntimeWarning('SKipping save best because %s is not updated, make sure you ' + 'pass one of the metric objects as monitor, ' + 'you can use estimator.prepare_loss_and_metrics to' 'create all metric objects', monitor_name)) - net.save_parameters(self.filepath) else: if self.monitor_op(monitor_value, self.best): if self.verbose > 0: - self.logger.info('\n[%s %d] %s improved from %0.5f to %0.5f,' - ' saving model to %s', - period_name, period_value, monitor_name, - self.best, monitor_value, self.filepath) + self.logger.info('[%s %d] %s improved from %0.5f to %0.5f,' + ' updating best model to %s', + period_name, num_of_periods, monitor_name, + self.best, monitor_value, self.model_dir) + prefix = self.model_prefix + '-best' + self._save_params_and_trainer(estimator, prefix) self.best = monitor_value - net.save_parameters(self.filepath) else: if self.verbose > 0: - self.logger.info('\n[%s %d] %s did not improve from %0.5f, skipping save model', - period_name, period_value, monitor_name, self.best) - else: - if self.verbose > 0: - logging.info('\n%s %d: saving model to %s', period_name, period_value, self.filepath) - net.save_parameters(self.filepath) + self.logger.info('[%s %d] %s did not improve from %0.5f, ' + 'skipping updating best model', + period_name, num_of_periods, monitor_name, + self.best) + + + def _save_symbol(self, estimator): + symbol_file = os.path.join(self.model_dir, self.model_prefix + '-symbol.json') + if hasattr(estimator.net, '_cached_graph'): + sym = estimator.net._cached_graph[1] + sym.save(symbol_file) + else: + self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock" + "to construct your model, can call net.hybridize() before passing to" + "Estimator in order to save model architecture as %s.", symbol_file) + + def _save_params_and_trainer(self, estimator, file_prefix): + param_file = os.path.join(self.model_dir, file_prefix + '.params') + trainer_file = os.path.join(self.model_dir, file_prefix + '.states') + estimator.net.save_parameters(param_file) + estimator.trainer.save_states(trainer_file) + + # only count checkpoints with epoch or batch number in file name + if 'best' not in file_prefix: + self.saved_checkpoints.append(file_prefix) + # remove old checkpoint when max number of checkpoints reached + if len(self.saved_checkpoints) > self.max_checkpoints: + prefix = self.saved_checkpoints.pop(0) + for fname in os.listdir(self.model_dir): + if fname.startswith(prefix): + os.remove(os.path.join(self.model_dir, fname)) class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd): diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index 163ca7d3e92e..f781fda43ff9 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -41,35 +41,41 @@ def _get_test_data(): def test_checkpoint_handler(): - tmpdir = tempfile.mkdtemp() - file_path = os.path.join(tmpdir, "model.params") - test_data = _get_test_data() - - save_best_only = True - mode = 'auto' + with TemporaryDirectory() as tmpdir: + model_prefix = 'test_net' + file_path = os.path.join(tmpdir, model_prefix) + test_data = _get_test_data() - net = _get_test_network() - ce_loss = loss.SoftmaxCrossEntropyLoss() - ce_loss_metric = mx.metric.Loss(ce_loss.name) - acc = mx.metric.Accuracy() - est = estimator.Estimator(net, loss=ce_loss, metrics=acc) - checkpoint_handler = [event_handler.CheckpointHandler(file_path, - monitor=acc, - save_best_only=save_best_only, - mode=mode)] - est.fit(test_data, event_handlers=checkpoint_handler, epochs=1) - assert os.path.isfile(file_path) - os.remove(file_path) - - checkpoint_handler = [event_handler.CheckpointHandler(file_path, - monitor=acc, - save_best_only=save_best_only, - mode=mode, - epoch_period=None, - batch_period=1)] - est.fit(test_data, event_handlers=checkpoint_handler, epochs=2) - assert os.path.isfile(file_path) - os.remove(file_path) + net = _get_test_network() + ce_loss = loss.SoftmaxCrossEntropyLoss() + acc = mx.metric.Accuracy() + est = estimator.Estimator(net, loss=ce_loss, metrics=acc) + checkpoint_handler = [event_handler.CheckpointHandler(model_dir=tmpdir, + model_prefix=model_prefix, + monitor=acc, + save_best=True)] + est.fit(test_data, event_handlers=checkpoint_handler, epochs=1) + assert os.path.isfile(file_path + '-best.params') + assert os.path.isfile(file_path + '-best.states') + assert os.path.isfile(file_path + '-epoch0.params') + assert os.path.isfile(file_path + '-epoch0.states') + + model_prefix = 'test_batch' + file_path = os.path.join(tmpdir, model_prefix) + checkpoint_handler = [event_handler.CheckpointHandler(model_dir=tmpdir, + model_prefix=model_prefix, + epoch_period=None, + batch_period=1, + max_checkpoints=2)] + est.fit(test_data, event_handlers=checkpoint_handler, epochs=3) + assert not os.path.isfile(file_path + 'best.params') + assert not os.path.isfile(file_path + 'best.states') + assert not os.path.isfile(file_path + '-batch0.params') + assert not os.path.isfile(file_path + '-batch0.states') + assert os.path.isfile(file_path + '-batch1.params') + assert os.path.isfile(file_path + '-batch1.states') + assert os.path.isfile(file_path + '-batch2.params') + assert os.path.isfile(file_path + '-batch2.states') From d60404ef359657f5421d9ccf907b8312665cf1eb Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Tue, 7 May 2019 15:44:00 -0700 Subject: [PATCH 3/9] test symbol save --- tests/python/unittest/test_gluon_event_handler.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index f781fda43ff9..9fbec96edfc7 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -25,8 +25,7 @@ from mxnet.gluon.contrib.estimator import estimator, event_handler -def _get_test_network(): - net = nn.Sequential() +def _get_test_network(net=nn.Sequential()): net.add(nn.Dense(128, activation='relu', in_units=100, flatten=False), nn.Dense(64, activation='relu', in_units=128), nn.Dense(10, activation='relu', in_units=64)) @@ -62,6 +61,9 @@ def test_checkpoint_handler(): model_prefix = 'test_batch' file_path = os.path.join(tmpdir, model_prefix) + net = _get_test_network(nn.HybridSequential()) + net.hybridize() + est = estimator.Estimator(net, loss=ce_loss, metrics=acc) checkpoint_handler = [event_handler.CheckpointHandler(model_dir=tmpdir, model_prefix=model_prefix, epoch_period=None, @@ -72,6 +74,7 @@ def test_checkpoint_handler(): assert not os.path.isfile(file_path + 'best.states') assert not os.path.isfile(file_path + '-batch0.params') assert not os.path.isfile(file_path + '-batch0.states') + assert os.path.isfile(file_path + '-symbol.json') assert os.path.isfile(file_path + '-batch1.params') assert os.path.isfile(file_path + '-batch1.states') assert os.path.isfile(file_path + '-batch2.params') From 70854d1a74e8d566e5ef24e3ea58dd616b68d069 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Wed, 8 May 2019 14:28:21 -0700 Subject: [PATCH 4/9] address comments --- .../gluon/contrib/estimator/estimator.py | 21 ++- .../gluon/contrib/estimator/event_handler.py | 48 ++++--- .../unittest/test_gluon_event_handler.py | 126 ++++++++++++------ 3 files changed, 132 insertions(+), 63 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index d7671cc8e4ce..017786439575 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -203,10 +203,10 @@ def evaluate(self, ---------- val_data : DataLoader validation data loader with data and labels - batch_axis : int, default 0 - batch axis to split the validation data into devices val_metrics : EvalMetric or list of EvalMetrics metrics to update validation result + batch_axis : int, default 0 + batch axis to split the validation data into devices """ if not isinstance(val_data, gluon.data.DataLoader): raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " @@ -291,15 +291,22 @@ def fit(self, train_data, self.trainer.step(batch_size) # batch end + + batch_end_result = [] for handler in batch_end: - if handler.batch_end(estimator_ref, batch=batch, - pred=pred, label=label, loss=loss): - break + batch_end_result.append(handler.batch_end(estimator_ref, batch=batch, + pred=pred, label=label, loss=loss)) + # if any handler signaled to stop + if any(batch_end_result): + break # epoch end + epoch_end_result = [] for handler in epoch_end: - if handler.epoch_end(estimator_ref): - break + epoch_end_result.append(handler.epoch_end(estimator_ref)) + # if any handler signaled to stop + if any(epoch_end_result): + break # train end for handler in train_end: diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index c09c4afefeeb..34e17fe9770e 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -94,7 +94,7 @@ def batch_end(self, estimator, *args, **kwargs): metric.update(label, pred) -class ValidationHandler(BatchEnd, EpochEnd): +class ValidationHandler(TrainBegin, BatchEnd, EpochEnd): """"Validation Handler that evaluate model on validation dataset :py:class:`ValidationHandler` takes validation dataset, an evaluation function, @@ -135,18 +135,23 @@ def __init__(self, # validation metrics need to be calculated before other callbacks can access them self.priority = -np.Inf + def train_begin(self, estimator, *args, **kwargs): + # reset epoch and batch counter + self.num_batches = 0 + self.num_epochs = 0 + def batch_end(self, estimator, *args, **kwargs): + self.num_batches += 1 if self.batch_period and self.num_batches % self.batch_period == 0: self.eval_fn(val_data=self.val_data, val_metrics=self.val_metrics) - self.num_batches += 1 def epoch_end(self, estimator, *args, **kwargs): + self.num_epochs += 1 if self.num_epochs % self.epoch_period == 0: self.eval_fn(val_data=self.val_data, val_metrics=self.val_metrics) - self.num_epochs += 1 class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd): @@ -217,6 +222,10 @@ def train_begin(self, estimator, *args, **kwargs): "with current learning rate %.4f ", optimizer, lr) self.logger.info("Train for %d epochs.", estimator.max_epochs) + # reset all counters + self.current_epoch = 0 + self.batch_index = 0 + self.processed_samples = 0 def train_end(self, estimator, *args, **kwargs): train_time = time.time() - self.train_start @@ -247,7 +256,7 @@ def batch_end(self, estimator, *args, **kwargs): name, value = metric.get() msg += '%s : %.4f ' % (name, value) self.logger.info(msg) - self.batch_index += 1 + self.batch_index += 1 def epoch_begin(self, estimator, *args, **kwargs): if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH: @@ -263,15 +272,15 @@ def epoch_end(self, estimator, *args, **kwargs): name, value = monitor.get() msg += '%s : %.4f ' % (name, value) self.logger.info(msg) - self.current_epoch += 1 - self.batch_index = 0 + self.current_epoch += 1 + self.batch_index = 0 -class CheckpointHandler(BatchEnd, EpochEnd): +class CheckpointHandler(TrainBegin, BatchEnd, EpochEnd): """Save the model after user define period :py:class:`CheckpointHandler` saves the network architecture after first batch if the model - can be hybridized(), saves model parameters and trainer states after user defined period, + can be fully hybridized, saves model parameters and trainer states after user defined period, default saves every epoch. Parameters @@ -295,7 +304,7 @@ class CheckpointHandler(BatchEnd, EpochEnd): epoch intervals between saving the network batch_period: int, default None batch intervals between saving the network, - by default don't save any checkpoint based on number of batches + by default, checkpoints are not saved based on the number of batches max_checkpoints : int, default 5 maximum number of checkpoint files to keep in the model_dir, older checkpoints will be removed @@ -350,13 +359,19 @@ def __init__(self, "if %s has improved, please use `min` for mode " "if you want otherwise", self.monitor.get()[0]) self.monitor_op = np.greater - self.best = -np.Inf else: self.logger.info("`less` operator will be used to determine " "if %s has improved, please use `max` for mode " "if you want otherwise", self.monitor.get()[0]) self.monitor_op = np.less - self.best = np.Inf + + def train_begin(self, estimator, *args, **kwargs): + # reset all counters + self.num_epochs = 0 + self.num_batches = 0 + if self.save_best: + self.best = np.Inf if self.monitor_op == np.less else -np.Inf + def batch_end(self, estimator, *args, **kwargs): # only save symbol once after first batch @@ -386,7 +401,7 @@ def _save_checkpoint(self, estimator, period_name, period_value, num_of_periods) monitor_name, monitor_value = self.monitor.get() # check if monitor exists in train stats if np.isnan(monitor_value): - warnings.warn(RuntimeWarning('SKipping save best because %s is not updated, make sure you ' + warnings.warn(RuntimeWarning('Skipping save best because %s is not updated, make sure you ' 'pass one of the metric objects as monitor, ' 'you can use estimator.prepare_loss_and_metrics to' 'create all metric objects', monitor_name)) @@ -440,8 +455,6 @@ class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd): Parameters ---------- - estimator : Estimator - The :py:class:`Estimator` to get training statistics monitor: EvalMetric the metrics to monitor min_delta: float, default 0 @@ -473,7 +486,7 @@ def __init__(self, self.min_delta = min_delta self.wait = 0 self.stopped_epoch = 0 - self.num_epochs = 0 + self.current_epoch = 0 self.stop_training = False self.logger = logging.getLogger(__name__) @@ -509,6 +522,8 @@ def __init__(self, def train_begin(self, estimator, *args, **kwargs): self.wait = 0 self.stopped_epoch = 0 + self.current_epoch = 0 + self.stop_training = False if self.baseline is not None: self.best = self.baseline else: @@ -527,8 +542,9 @@ def epoch_end(self, estimator, *args, **kwargs): else: self.wait += 1 if self.wait >= self.patience: - self.stopped_epoch = self.num_epochs + self.stopped_epoch = self.current_epoch self.stop_training = True + self.current_epoch += 1 return self.stop_training def train_end(self, estimator, *args, **kwargs): diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index 9fbec96edfc7..32f766a1d16f 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -16,7 +16,6 @@ # under the License. import os -import tempfile import mxnet as mx from common import TemporaryDirectory @@ -26,17 +25,17 @@ def _get_test_network(net=nn.Sequential()): - net.add(nn.Dense(128, activation='relu', in_units=100, flatten=False), - nn.Dense(64, activation='relu', in_units=128), - nn.Dense(10, activation='relu', in_units=64)) + net.add(nn.Dense(128, activation='relu', flatten=False), + nn.Dense(64, activation='relu'), + nn.Dense(10, activation='relu')) return net def _get_test_data(): data = nd.ones((32, 100)) - label = nd.random.randint(0, 10, (32, 1)) + label = nd.zeros((32, 1)) data_arr = mx.gluon.data.dataset.ArrayDataset(data, label) - return mx.gluon.data.DataLoader(data_arr, batch_size=32) + return mx.gluon.data.DataLoader(data_arr, batch_size=8) def test_checkpoint_handler(): @@ -49,11 +48,13 @@ def test_checkpoint_handler(): ce_loss = loss.SoftmaxCrossEntropyLoss() acc = mx.metric.Accuracy() est = estimator.Estimator(net, loss=ce_loss, metrics=acc) - checkpoint_handler = [event_handler.CheckpointHandler(model_dir=tmpdir, - model_prefix=model_prefix, - monitor=acc, - save_best=True)] - est.fit(test_data, event_handlers=checkpoint_handler, epochs=1) + checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir, + model_prefix=model_prefix, + monitor=acc, + save_best=True) + est.fit(test_data, event_handlers=[checkpoint_handler], epochs=1) + assert checkpoint_handler.num_epochs == 1 + assert checkpoint_handler.num_batches == 4 assert os.path.isfile(file_path + '-best.params') assert os.path.isfile(file_path + '-best.states') assert os.path.isfile(file_path + '-epoch0.params') @@ -64,45 +65,44 @@ def test_checkpoint_handler(): net = _get_test_network(nn.HybridSequential()) net.hybridize() est = estimator.Estimator(net, loss=ce_loss, metrics=acc) - checkpoint_handler = [event_handler.CheckpointHandler(model_dir=tmpdir, - model_prefix=model_prefix, - epoch_period=None, - batch_period=1, - max_checkpoints=2)] - est.fit(test_data, event_handlers=checkpoint_handler, epochs=3) + checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir, + model_prefix=model_prefix, + epoch_period=None, + batch_period=1, + max_checkpoints=2) + est.fit(test_data, event_handlers=[checkpoint_handler], epochs=2) + assert checkpoint_handler.num_epochs == 2 + assert checkpoint_handler.num_batches == 8 assert not os.path.isfile(file_path + 'best.params') assert not os.path.isfile(file_path + 'best.states') assert not os.path.isfile(file_path + '-batch0.params') assert not os.path.isfile(file_path + '-batch0.states') assert os.path.isfile(file_path + '-symbol.json') - assert os.path.isfile(file_path + '-batch1.params') - assert os.path.isfile(file_path + '-batch1.states') - assert os.path.isfile(file_path + '-batch2.params') - assert os.path.isfile(file_path + '-batch2.states') - + assert os.path.isfile(file_path + '-batch6.params') + assert os.path.isfile(file_path + '-batch6.states') + assert os.path.isfile(file_path + '-batch7.params') + assert os.path.isfile(file_path + '-batch7.states') def test_early_stopping(): test_data = _get_test_data() - mode = 'max' - patience = 0 - net = _get_test_network() ce_loss = loss.SoftmaxCrossEntropyLoss() acc = mx.metric.Accuracy() est = estimator.Estimator(net, loss=ce_loss, metrics=acc) - early_stopping = [event_handler.EarlyStoppingHandler(monitor=acc, - patience=patience, - mode=mode)] - est.fit(test_data, event_handlers=early_stopping, epochs=3) + early_stopping = event_handler.EarlyStoppingHandler(monitor=acc, + patience=0, + mode='min') + est.fit(test_data, event_handlers=[early_stopping], epochs=5) + assert early_stopping.current_epoch == 2 + assert early_stopping.stopped_epoch == 1 - mode = 'auto' - patience = 2 - early_stopping = [event_handler.EarlyStoppingHandler(monitor=acc, - patience=patience, - mode=mode)] - est.fit(test_data, event_handlers=early_stopping, epochs=1) + early_stopping = event_handler.EarlyStoppingHandler(monitor=acc, + patience=2, + mode='auto') + est.fit(test_data, event_handlers=[early_stopping], epochs=1) + assert early_stopping.current_epoch == 1 def test_logging(): @@ -116,9 +116,55 @@ def test_logging(): acc = mx.metric.Accuracy() est = estimator.Estimator(net, loss=ce_loss, metrics=acc) train_metrics, val_metrics = est.prepare_loss_and_metrics() - logging_handler = [event_handler.LoggingHandler(file_name=file_name, - file_location=tmpdir, - train_metrics=train_metrics, - val_metrics=val_metrics)] - est.fit(test_data, event_handlers=logging_handler, epochs=1) + logging_handler = event_handler.LoggingHandler(file_name=file_name, + file_location=tmpdir, + train_metrics=train_metrics, + val_metrics=val_metrics) + est.fit(test_data, event_handlers=[logging_handler], epochs=3) + assert logging_handler.batch_index == 0 + assert logging_handler.current_epoch == 3 assert os.path.isfile(output_dir) + + +def test_custom_handler(): + class CustomStopHandler(event_handler.TrainBegin, + event_handler.BatchEnd, + event_handler.EpochEnd): + def __init__(self, batch_stop=None, epoch_stop=None): + self.batch_stop = batch_stop + self.epoch_stop = epoch_stop + self.num_batch = 0 + self.num_epoch = 0 + self.stop_training = False + + def train_begin(self, estimator, *args, **kwargs): + self.num_batch = 0 + self.num_epoch = 0 + + def batch_end(self, estimator, *args, **kwargs): + self.num_batch += 1 + if self.num_batch == self.batch_stop: + self.stop_training = True + return self.stop_training + + def epoch_end(self, estimator, *args, **kwargs): + self.num_epoch += 1 + if self.num_epoch == self.epoch_stop: + self.stop_training = True + return self.stop_training + + # total data size is 32, batch size is 8 + # 4 batch per epoch + test_data = _get_test_data() + net = _get_test_network() + ce_loss = loss.SoftmaxCrossEntropyLoss() + acc = mx.metric.Accuracy() + est = estimator.Estimator(net, loss=ce_loss, metrics=acc) + custom_handler = CustomStopHandler(3, 2) + est.fit(test_data, event_handlers=[custom_handler], epochs=3) + assert custom_handler.num_batch == 3 + assert custom_handler.num_epoch == 1 + custom_handler = CustomStopHandler(100, 5) + est.fit(test_data, event_handlers=[custom_handler], epochs=10) + assert custom_handler.num_batch == 5 * 4 + assert custom_handler.num_epoch == 5 From 80617c190f2bd2e99d739e4acfbc957845aa9e4e Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Wed, 8 May 2019 16:58:42 -0700 Subject: [PATCH 5/9] add resume --- .../gluon/contrib/estimator/estimator.py | 2 +- .../gluon/contrib/estimator/event_handler.py | 58 +++++++++++++++++-- .../unittest/test_gluon_event_handler.py | 27 +++++++++ 3 files changed, 81 insertions(+), 6 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 017786439575..e3f20435fc66 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -268,7 +268,7 @@ def fit(self, train_data, for handler in train_begin: handler.train_begin(estimator_ref) - for epoch in range(epochs): + for epoch in range(self.max_epochs): # epoch begin for handler in epoch_begin: handler.epoch_begin(estimator_ref) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 34e17fe9770e..dcf42bf238fc 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -153,7 +153,6 @@ def epoch_end(self, estimator, *args, **kwargs): val_metrics=self.val_metrics) - class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd): """Basic Logging Handler that applies to every Gluon estimator by default. @@ -308,6 +307,8 @@ class CheckpointHandler(TrainBegin, BatchEnd, EpochEnd): max_checkpoints : int, default 5 maximum number of checkpoint files to keep in the model_dir, older checkpoints will be removed + resume_from : str, default None + one of {batch, epoch}, select which type of checkpoints to resume from """ def __init__(self, @@ -319,7 +320,8 @@ def __init__(self, mode='auto', epoch_period=1, batch_period=None, - max_checkpoints=5): + max_checkpoints=5, + resume_type=None): self.monitor = monitor self.verbose = verbose if not os.path.exists(model_dir): @@ -335,6 +337,9 @@ def __init__(self, self.num_batches = 0 self.num_epochs = 0 self.max_checkpoints = max_checkpoints + if resume_type and not resume_type in ['batch', 'epoch']: + raise ValueError("Unknown resume type, please specify as `batch` or `epoch`.") + self.resume_type = resume_type self.saved_checkpoints = [] self.logger = logging.getLogger(__name__) if self.save_best: @@ -371,7 +376,8 @@ def train_begin(self, estimator, *args, **kwargs): self.num_batches = 0 if self.save_best: self.best = np.Inf if self.monitor_op == np.less else -np.Inf - + if self.resume_type: + self._resume_from_checkpoint(estimator) def batch_end(self, estimator, *args, **kwargs): # only save symbol once after first batch @@ -394,7 +400,12 @@ def _save_checkpoint(self, estimator, period_name, period_value, num_of_periods) if num_of_periods % period_value == 0: if self.verbose > 0: self.logger.info('[%s %d] saving model to %s', period_name, num_of_periods, self.model_dir) - prefix = "%s-%s%d" % (self.model_prefix, period_name.lower(), num_of_periods) + # if resumed from checkpoint, increment checkpoint number + if self.resume_type == period_name.lower(): + saved_period = num_of_periods + self.trained_period + 1 + else: + saved_period = num_of_periods + prefix = "%s-%s%d" % (self.model_prefix, period_name.lower(), saved_period) self._save_params_and_trainer(estimator, prefix) if self.save_best: @@ -422,7 +433,6 @@ def _save_checkpoint(self, estimator, period_name, period_value, num_of_periods) period_name, num_of_periods, monitor_name, self.best) - def _save_symbol(self, estimator): symbol_file = os.path.join(self.model_dir, self.model_prefix + '-symbol.json') if hasattr(estimator.net, '_cached_graph'): @@ -449,6 +459,44 @@ def _save_params_and_trainer(self, estimator, file_prefix): if fname.startswith(prefix): os.remove(os.path.join(self.model_dir, fname)) + def _resume_from_checkpoint(self, estimator): + self.trained_period = -1 + for fname in os.listdir(self.model_dir): + if fname.startswith(self.model_prefix): + if self.resume_type in fname and '.params' in fname: + self.saved_checkpoints.append(fname[:fname.find('.params')]) + try: + # find trained number of batch or epoch + period = int(fname[fname.find(self.resume_type) + 5: fname.find('.params')]) + if period > self.trained_period: + self.trained_period = period + except ValueError: + raise ValueError("Error parsing checkpoint files, please check your checkpoints " + "have the format {model_name}-{%s}{%s_number}.params, " + "there should also be a .states file for each .params file ", + self.resume_type, self.resume_type) + if self.trained_period != -1: + # change maximum number of epoch to train is resumed from epoch checkpoint + if self.resume_type == 'epoch': + if self.trained_period >= estimator.max_epochs - 1: + raise ValueError("Found checkpoint with maximum number of epoch %d reached, please specify" + "resume_type=None (default value) if you wan to train from scratch.", + self.trained_period + 1) + estimator.max_epochs = estimator.max_epochs - self.trained_period - 1 + param_file = "%s-%s%d.params" % (self.model_prefix, self.resume_type, self.trained_period) + param_file = os.path.join(self.model_dir, param_file) + trainer_file = "%s-%s%d.states" % (self.model_prefix, self.resume_type, self.trained_period) + trainer_file = os.path.join(self.model_dir, trainer_file) + assert os.path.exists(param_file), "Failed to load checkpoint, %s does not exist" % param_file + assert os.path.exists(trainer_file), "Failed to load checkpoint, %s does not exist" % trainer_file + estimator.net.load_parameters(param_file, ctx=estimator.context) + estimator.trainer.load_states(trainer_file) + self.logger.warning("Checkpoint resumed from %s %d, continue training for %d epochs.", + self.resume_type, self.trained_period, estimator.max_epochs) + else: + self.logger.info("No checkpoint found, training from scratch for %d epochs.", + estimator.max_epochs) + class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd): """Early stop training if monitored value is not improving diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index 32f766a1d16f..12fc6ba133f5 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -83,6 +83,33 @@ def test_checkpoint_handler(): assert os.path.isfile(file_path + '-batch7.params') assert os.path.isfile(file_path + '-batch7.states') +def test_resume_checkpoint(): + with TemporaryDirectory() as tmpdir: + model_prefix = 'test_net' + file_path = os.path.join(tmpdir, model_prefix) + test_data = _get_test_data() + + net = _get_test_network() + ce_loss = loss.SoftmaxCrossEntropyLoss() + acc = mx.metric.Accuracy() + est = estimator.Estimator(net, loss=ce_loss, metrics=acc) + checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir, + model_prefix=model_prefix, + monitor=acc, + max_checkpoints=1) + est.fit(test_data, event_handlers=[checkpoint_handler], epochs=2) + assert os.path.isfile(file_path + '-epoch1.params') + assert os.path.isfile(file_path + '-epoch1.states') + checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir, + model_prefix=model_prefix, + monitor=acc, + max_checkpoints=1, + resume_type='epoch') + est.fit(test_data, event_handlers=[checkpoint_handler], epochs=5) + # should only continue to train 3 epochs and last checkpoint file is epoch4 + assert est.max_epochs == 3 + assert os.path.isfile(file_path + '-epoch4.states') + def test_early_stopping(): test_data = _get_test_data() From ae38c87348fae41f6098ff669f7740f82095d692 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Thu, 9 May 2019 20:53:36 -0700 Subject: [PATCH 6/9] update doc and resume checkpoint --- .../gluon/contrib/estimator/estimator.py | 48 ++- .../gluon/contrib/estimator/event_handler.py | 352 +++++++++++------- tests/python/unittest/test_gluon_estimator.py | 5 +- .../unittest/test_gluon_event_handler.py | 43 +-- 4 files changed, 278 insertions(+), 170 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index e3f20435fc66..33aba14b9b2f 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -22,7 +22,7 @@ import copy import warnings -from .event_handler import MetricHandler, ValidationHandler, LoggingHandler +from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd from .... import gluon, autograd from ....context import Context, cpu, gpu, num_gpus @@ -46,11 +46,11 @@ class Estimator(object): metrics : EvalMetric or list of EvalMetric Metrics for evaluating models initializer : Initializer - initializer to initialize the network + Initializer to initialize the network trainer : Trainer Trainer to apply optimizer on network parameters context : Context or list of Context - device(s) to run the training on + Device(s) to run the training on """ def __init__(self, net, @@ -202,11 +202,11 @@ def evaluate(self, Parameters ---------- val_data : DataLoader - validation data loader with data and labels + Validation data loader with data and labels val_metrics : EvalMetric or list of EvalMetrics - metrics to update validation result + Metrics to update validation result batch_axis : int, default 0 - batch axis to split the validation data into devices + Batch axis to split the validation data into devices """ if not isinstance(val_data, gluon.data.DataLoader): raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " @@ -229,8 +229,9 @@ def evaluate(self, def fit(self, train_data, val_data=None, - epochs=1, + epochs=None, event_handlers=None, + batches=None, batch_axis=0): """Trains the model on a given dataset for a specified number of epochs. Also, the batch size is inferred from the @@ -239,22 +240,34 @@ def fit(self, train_data, Parameters ---------- train_data : DataLoader - training data loader with data and labels + Training data loader with data and labels val_data : DataLoader - validation data loader with data and labels - epochs : int, default 1 - number of epochs to iterate on the training data. + Validation data loader with data and labels + epochs : int, default None + Number of epochs to iterate on the training data. + You can only specify one and only one of epochs or batches. event_handlers : EventHandler or list of EventHandler - list of EventHandlers to apply during training + List of EventHandlers to apply during training + batches : int, default None + Number of batches to iterate on the training data. + You can only specify one and only one of epochs or batches batch_axis : int, default 0 - batch axis to split the validation data into devices + Batch axis to split the validation data into devices """ if not isinstance(train_data, gluon.data.DataLoader): raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " "can transform your DataIter or any NDArray into Gluon DataLoader. " "Refer to gluon.data.dataloader") - self.max_epochs = epochs + # must specify one and only one of epochs or batches + if (not epochs) == (not batches): + raise ValueError( + "Fit only support exactly one type of iteration, " + "train by number of epochs or number of batches." + "Please specify one and only one of: epochs or batches.") + + self.max_epoch = epochs + self.max_batch = batches # provide default handlers event_handlers = self._prepare_default_handlers(val_data, event_handlers) @@ -268,7 +281,7 @@ def fit(self, train_data, for handler in train_begin: handler.train_begin(estimator_ref) - for epoch in range(self.max_epochs): + while True: # epoch begin for handler in epoch_begin: handler.epoch_begin(estimator_ref) @@ -317,6 +330,9 @@ def _prepare_default_handlers(self, val_data, event_handlers): default_handlers = [] train_metrics, val_metrics = self.prepare_loss_and_metrics() + # no need to add to default handler check as StoppingHandler does not use metrics + event_handlers.append(StoppingHandler(self.max_epoch, self.max_batch)) + if not any(isinstance(handler, MetricHandler) for handler in event_handlers): event_handlers.append(MetricHandler(train_metrics=train_metrics)) default_handlers.append("MetricHandler") @@ -332,7 +348,7 @@ def _prepare_default_handlers(self, val_data, event_handlers): default_handlers.append("LoggingHandler") # if there is a mix of user defined event handlers and default event handlers - # they should have the save set of loss and metrics + # they should have the same set of loss and metrics if default_handlers: msg = "You are training with the following default event handlers: %s. " \ "They use loss and metrics from estimator.prepare_loss_and_metrics(). " \ diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index dcf42bf238fc..81140bed6f44 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -59,17 +59,56 @@ def batch_end(self, estimator, *args, **kwargs): return False +class StoppingHandler(TrainBegin, BatchEnd, EpochEnd): + """Stop conditions to stop training + Stop training if maximum number of batches or epochs + reached. + + Parameters + ---------- + max_batch : int, default None + Number of maximum batches to train. + max_epoch : int, default None + Number of maximum epochs to train + """ + + def __init__(self, max_epoch=None, max_batch=None): + self.max_epoch = max_epoch + self.max_batch = max_batch + self.current_batch = 0 + self.current_epoch = 0 + self.stop_training = False + + def train_begin(self, estimator, *args, **kwargs): + self.max_epoch = estimator.max_epoch + self.max_batch = estimator.max_batch + self.current_batch = 0 + self.current_epoch = 0 + + def batch_end(self, estimator, *args, **kwargs): + self.current_batch += 1 + if self.current_batch == self.max_batch: + self.stop_training = True + return self.stop_training + + def epoch_end(self, estimator, *args, **kwargs): + self.current_epoch += 1 + if self.current_epoch == self.max_epoch: + self.stop_training = True + return self.stop_training + + class MetricHandler(EpochBegin, BatchEnd): """Metric Handler that update metric values at batch end :py:class:`MetricHandler` takes model predictions and true labels - and update the metrics, it also update metric wrapper for loss with loss values + and update the metrics, it also update metric wrapper for loss with loss values. Validation loss and metrics will be handled by :py:class:`ValidationHandler` Parameters ---------- train_metrics : List of EvalMetrics - training metrics to be updated at batch end + Training metrics to be updated at batch end """ def __init__(self, train_metrics): @@ -104,18 +143,18 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd): Parameters ---------- val_data : DataLoader - validation data set to run evaluation + Validation data set to run evaluation eval_fn : function - a function defines how to run evaluation and + A function defines how to run evaluation and calculate loss and metrics val_metrics : List of EvalMetrics - validation metrics to be updated + Validation metrics to be updated epoch_period : int, default 1 - how often to run validation at epoch end, by default - validate every epoch + How often to run validation at epoch end, by default + :py:class:`ValidationHandler` validate every epoch batch_period : int, default None - how often to run validation at batch end, by default - does not validate at batch end + How often to run validation at batch end, by default + :py:class:`ValidationHandler` does not validate at batch end """ def __init__(self, @@ -129,26 +168,26 @@ def __init__(self, self.epoch_period = epoch_period self.batch_period = batch_period self.val_metrics = val_metrics - self.num_batches = 0 - self.num_epochs = 0 + self.current_batch = 0 + self.current_epoch = 0 # order to be called among all callbacks # validation metrics need to be calculated before other callbacks can access them self.priority = -np.Inf def train_begin(self, estimator, *args, **kwargs): # reset epoch and batch counter - self.num_batches = 0 - self.num_epochs = 0 + self.current_batch = 0 + self.current_epoch = 0 def batch_end(self, estimator, *args, **kwargs): - self.num_batches += 1 - if self.batch_period and self.num_batches % self.batch_period == 0: + self.current_batch += 1 + if self.batch_period and self.current_batch % self.batch_period == 0: self.eval_fn(val_data=self.val_data, val_metrics=self.val_metrics) def epoch_end(self, estimator, *args, **kwargs): - self.num_epochs += 1 - if self.num_epochs % self.epoch_period == 0: + self.current_epoch += 1 + if self.epoch_period and self.current_epoch % self.epoch_period == 0: self.eval_fn(val_data=self.val_data, val_metrics=self.val_metrics) @@ -162,19 +201,19 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat Parameters ---------- file_name : str - file name to save the logs + File name to save the logs file_location : str - file location to save the logs + File location to save the logs filemode : str, default 'a' - logging file mode, default using append mode + Logging file mode, default using append mode verbose : int, default LOG_VERBOSITY_PER_EPOCH Limit the granularity of metrics displayed during training process verbose=LOG_VERBOSITY_PER_EPOCH: display metrics every epoch verbose=LOG_VERBOSITY_PER_BATCH: display metrics every batch train_metrics : list of EvalMetrics - training metrics to be logged, logged at batch end, epoch end, train end + Training metrics to be logged, logged at batch end, epoch end, train end val_metrics : list of EvalMetrics - validation metrics to be logged, logged at epoch end, train end + Validation metrics to be logged, logged at epoch end, train end """ LOG_VERBOSITY_PER_EPOCH = 1 @@ -191,18 +230,18 @@ def __init__(self, file_name=None, self.logger.setLevel(logging.INFO) stream_handler = logging.StreamHandler() self.logger.addHandler(stream_handler) - if verbose not in [self.LOG_VERBOSITY_PER_EPOCH, self.LOG_VERBOSITY_PER_BATCH]: - raise ValueError("verbose level must be either LOG_VERBOSITY_PER_EPOCH or " - "LOG_VERBOSITY_PER_BATCH, received %s. " - "E.g: LoggingHandler(verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH)" - % verbose) - self.verbose = verbose # save logger to file only if file name or location is specified if file_name or file_location: file_name = file_name or 'estimator_log' file_location = file_location or './' file_handler = logging.FileHandler(os.path.join(file_location, file_name), mode=filemode) self.logger.addHandler(file_handler) + if verbose not in [self.LOG_VERBOSITY_PER_EPOCH, self.LOG_VERBOSITY_PER_BATCH]: + raise ValueError("verbose level must be either LOG_VERBOSITY_PER_EPOCH or " + "LOG_VERBOSITY_PER_BATCH, received %s. " + "E.g: LoggingHandler(verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH)" + % verbose) + self.verbose = verbose self.train_metrics = train_metrics or [] self.val_metrics = val_metrics or [] self.batch_index = 0 @@ -220,7 +259,10 @@ def train_begin(self, estimator, *args, **kwargs): self.logger.info("Training begin: using optimizer %s " "with current learning rate %.4f ", optimizer, lr) - self.logger.info("Train for %d epochs.", estimator.max_epochs) + if estimator.max_epoch: + self.logger.info("Train for %d epochs.", estimator.max_epoch) + else: + self.logger.info("Train for %d batches.", estimator.max_batch) # reset all counters self.current_epoch = 0 self.batch_index = 0 @@ -234,7 +276,9 @@ def train_end(self, estimator, *args, **kwargs): name, value = metric.get() msg += '%s : %.4f ' % (name, value) self.logger.info(msg) - for handler in self.logger.handlers: + # make a copy of handler list and remove one by one + # as removing handler will edit the handler list + for handler in self.logger.handlers[:]: handler.close() self.logger.removeHandler(handler) logging.shutdown() @@ -285,30 +329,32 @@ class CheckpointHandler(TrainBegin, BatchEnd, EpochEnd): Parameters ---------- model_dir : str - file directory to save all the model related files including model architecture, + File directory to save all the model related files including model architecture, model parameters, and trainer states. model_prefix : str default 'model' - prefix to add for all checkpoint file names + Prefix to add for all checkpoint file names monitor: EvalMetric - the metrics to monitor and determine if model has improved + The metrics to monitor and determine if model has improved verbose: int, default 0 - verbosity mode + Verbosity mode, 1 means inform user every time a checkpoint is saved save_best: bool - if True, save the model parameters and trainer states with the best monitored value + If True, save the model parameters and trainer states with the best monitored value mode: str, default 'auto' - one of {auto, min, max}, if `save_best_only=True`, the comparison to make + One of {auto, min, max}, if `save_best_only=True`, the comparison to make and determine if the monitored value has improved. if 'auto' mode, checkpoint handler will try to use min or max based on the monitored metric name epoch_period: int, default 1 - epoch intervals between saving the network + Epoch intervals between saving the network batch_period: int, default None - batch intervals between saving the network, + Batch intervals between saving the network, by default, checkpoints are not saved based on the number of batches max_checkpoints : int, default 5 - maximum number of checkpoint files to keep in the model_dir, older checkpoints - will be removed - resume_from : str, default None - one of {batch, epoch}, select which type of checkpoints to resume from + Maximum number of checkpoint files to keep in the model_dir, older checkpoints + will be removed. Does not count best checkpoint + resume_from_checkpoint : bool, default False + Whether to resume training from checkpoint in model_dir. If True and checkpoints + found, :py:class:`CheckpointHandler` will load net parameters and trainer states, + and train the remaining of epochs and batches. """ def __init__(self, @@ -321,7 +367,7 @@ def __init__(self, epoch_period=1, batch_period=None, max_checkpoints=5, - resume_type=None): + resume_from_checkpoint=False): self.monitor = monitor self.verbose = verbose if not os.path.exists(model_dir): @@ -334,12 +380,10 @@ def __init__(self, "You can get these objects using estimator.prepare_loss_and_metric()") self.epoch_period = epoch_period self.batch_period = batch_period - self.num_batches = 0 - self.num_epochs = 0 + self.current_batch = 0 + self.current_epoch = 0 self.max_checkpoints = max_checkpoints - if resume_type and not resume_type in ['batch', 'epoch']: - raise ValueError("Unknown resume type, please specify as `batch` or `epoch`.") - self.resume_type = resume_type + self.resume_from_checkpoint = resume_from_checkpoint self.saved_checkpoints = [] self.logger = logging.getLogger(__name__) if self.save_best: @@ -372,66 +416,79 @@ def __init__(self, def train_begin(self, estimator, *args, **kwargs): # reset all counters - self.num_epochs = 0 - self.num_batches = 0 + self.current_epoch = 0 + self.current_batch = 0 if self.save_best: self.best = np.Inf if self.monitor_op == np.less else -np.Inf - if self.resume_type: + if self.resume_from_checkpoint: + error_msg = "To use resume from checkpoint, you must only specify " \ + "the same type of period you used for training." \ + "For example, if you are training based on number of epochs," \ + "you must save only based on epochs, and set batch_period to None." + if estimator.max_batch: + assert self.batch_period, error_msg + assert not self.epoch_period, error_msg + if estimator.max_epoch: + assert self.epoch_period, error_msg + assert not self.batch_period, error_msg + self._resume_from_checkpoint(estimator) def batch_end(self, estimator, *args, **kwargs): # only save symbol once after first batch - if self.num_batches == 0: + if self.current_batch == 0: self._save_symbol(estimator) - if self.batch_period: - self._save_checkpoint(estimator, "Batch", self.batch_period, self.num_batches) - self.num_batches += 1 + if self.batch_period and (self.current_batch + 1) % self.batch_period == 0: + self._save_checkpoint(estimator) + self.current_batch += 1 def epoch_end(self, estimator, *args, **kwargs): - if self.epoch_period: - self._save_checkpoint(estimator, "Epoch", self.epoch_period, self.num_epochs) - self.num_epochs += 1 - - def _save_checkpoint(self, estimator, period_name, period_value, num_of_periods): - # period name can be batch or epoch - # period value determine how often a checkpoint is saved - # num_of_periods records the number of batch or epoch - # add extension for weights - if num_of_periods % period_value == 0: - if self.verbose > 0: - self.logger.info('[%s %d] saving model to %s', period_name, num_of_periods, self.model_dir) - # if resumed from checkpoint, increment checkpoint number - if self.resume_type == period_name.lower(): - saved_period = num_of_periods + self.trained_period + 1 + if self.epoch_period and (self.current_epoch + 1) % self.epoch_period == 0: + self._save_checkpoint(estimator) + self.current_epoch += 1 + + def _save_checkpoint(self, estimator): + if self.verbose > 0: + self.logger.info('[Epoch %d][Batch %d] saving model to %s', + self.current_epoch, self.current_batch, self.model_dir) + # if resumed from checkpoint, increment checkpoint number + if self.resume_from_checkpoint: + save_epoch_number = self.current_epoch + self.trained_epoch + 1 + if estimator.max_epoch: + # checkpoint saved at epoch end, batch number already incremented + save_batch_number = self.current_batch + self.trained_batch + else: + save_batch_number = self.current_batch + self.trained_batch + 1 + else: + save_epoch_number = self.current_epoch + save_batch_number = self.current_batch + prefix = "%s-epoch%dbatch%d" % (self.model_prefix, save_epoch_number, save_batch_number) + self._save_params_and_trainer(estimator, prefix) + + if self.save_best: + monitor_name, monitor_value = self.monitor.get() + # check if monitor exists in train stats + if np.isnan(monitor_value): + warnings.warn(RuntimeWarning('Skipping save best because %s is not updated, make sure you ' + 'pass one of the metric objects as monitor, ' + 'you can use estimator.prepare_loss_and_metrics to' + 'create all metric objects', monitor_name)) else: - saved_period = num_of_periods - prefix = "%s-%s%d" % (self.model_prefix, period_name.lower(), saved_period) - self._save_params_and_trainer(estimator, prefix) - - if self.save_best: - monitor_name, monitor_value = self.monitor.get() - # check if monitor exists in train stats - if np.isnan(monitor_value): - warnings.warn(RuntimeWarning('Skipping save best because %s is not updated, make sure you ' - 'pass one of the metric objects as monitor, ' - 'you can use estimator.prepare_loss_and_metrics to' - 'create all metric objects', monitor_name)) + if self.monitor_op(monitor_value, self.best): + if self.verbose > 0: + self.logger.info('[Epoch %d][Batch %d] %s improved from %0.5f to %0.5f,' + ' updating best model to %s', + self.current_epoch, self.current_batch, monitor_name, + self.best, monitor_value, self.model_dir) + prefix = self.model_prefix + '-best' + self._save_params_and_trainer(estimator, prefix) + self.best = monitor_value else: - if self.monitor_op(monitor_value, self.best): - if self.verbose > 0: - self.logger.info('[%s %d] %s improved from %0.5f to %0.5f,' - ' updating best model to %s', - period_name, num_of_periods, monitor_name, - self.best, monitor_value, self.model_dir) - prefix = self.model_prefix + '-best' - self._save_params_and_trainer(estimator, prefix) - self.best = monitor_value - else: - if self.verbose > 0: - self.logger.info('[%s %d] %s did not improve from %0.5f, ' - 'skipping updating best model', - period_name, num_of_periods, monitor_name, - self.best) + if self.verbose > 0: + self.logger.info('[Epoch %d][Batch %d] %s did not improve from %0.5f, ' + 'skipping updating best model', + self.current_epoch, self.current_batch, monitor_name, + self.best) def _save_symbol(self, estimator): symbol_file = os.path.join(self.model_dir, self.model_prefix + '-symbol.json') @@ -460,42 +517,75 @@ def _save_params_and_trainer(self, estimator, file_prefix): os.remove(os.path.join(self.model_dir, fname)) def _resume_from_checkpoint(self, estimator): - self.trained_period = -1 - for fname in os.listdir(self.model_dir): - if fname.startswith(self.model_prefix): - if self.resume_type in fname and '.params' in fname: - self.saved_checkpoints.append(fname[:fname.find('.params')]) - try: - # find trained number of batch or epoch - period = int(fname[fname.find(self.resume_type) + 5: fname.find('.params')]) - if period > self.trained_period: - self.trained_period = period - except ValueError: - raise ValueError("Error parsing checkpoint files, please check your checkpoints " - "have the format {model_name}-{%s}{%s_number}.params, " - "there should also be a .states file for each .params file ", - self.resume_type, self.resume_type) - if self.trained_period != -1: - # change maximum number of epoch to train is resumed from epoch checkpoint - if self.resume_type == 'epoch': - if self.trained_period >= estimator.max_epochs - 1: + prefix = self.model_prefix + '-epoch' + self.trained_epoch = self._find_max_iteration( + dir=self.model_dir, + prefix=prefix, + start='epoch', + end='batch', + saved_checkpoints=self.saved_checkpoints) + prefix += str(self.trained_epoch) + self.trained_batch = self._find_max_iteration( + dir=self.model_dir, + prefix=prefix, + start='batch', + end='.params') + + if self.trained_epoch == -1: + msg = "No checkpoint found, training from scratch for " + if estimator.max_batch: + msg += "%d batches" % estimator.max_batch + else: + msg += "%d epochs" % estimator.max_epoch + self.logger.info(msg) + else: + msg = "Checkpoint resumed from epoch %d batch %d, continue to train for " % ( + self.trained_epoch, self.trained_batch) + # change maximum number of epoch or batch to train if resumed from epoch checkpoint + if estimator.max_epoch: + if self.trained_epoch >= estimator.max_epoch - 1: raise ValueError("Found checkpoint with maximum number of epoch %d reached, please specify" - "resume_type=None (default value) if you wan to train from scratch.", - self.trained_period + 1) - estimator.max_epochs = estimator.max_epochs - self.trained_period - 1 - param_file = "%s-%s%d.params" % (self.model_prefix, self.resume_type, self.trained_period) + "resume_from_checkpoint=False (default value) if you wan to train from scratch.", + estimator.max_epoch) + estimator.max_epoch = estimator.max_epoch - self.trained_epoch - 1 + msg += "%d epochs " % estimator.max_epoch + if estimator.max_batch: + if self.trained_batch >= estimator.max_batch - 1: + raise ValueError("Found checkpoint with maximum number of batch %d reached, please specify" + "resume_from_checkpoint=False (default value) if you wan to train from scratch.", + self.trained_batch) + estimator.max_batch = estimator.max_batch - self.trained_batch - 1 + msg += "%d batches " % estimator.max_batch + # load checkpoint + param_file = "%s-epoch%dbatch%d.params" % (self.model_prefix, self.trained_epoch, self.trained_batch) param_file = os.path.join(self.model_dir, param_file) - trainer_file = "%s-%s%d.states" % (self.model_prefix, self.resume_type, self.trained_period) + trainer_file = "%s-epoch%dbatch%d.states" % (self.model_prefix, self.trained_epoch, self.trained_batch) trainer_file = os.path.join(self.model_dir, trainer_file) assert os.path.exists(param_file), "Failed to load checkpoint, %s does not exist" % param_file assert os.path.exists(trainer_file), "Failed to load checkpoint, %s does not exist" % trainer_file estimator.net.load_parameters(param_file, ctx=estimator.context) estimator.trainer.load_states(trainer_file) - self.logger.warning("Checkpoint resumed from %s %d, continue training for %d epochs.", - self.resume_type, self.trained_period, estimator.max_epochs) - else: - self.logger.info("No checkpoint found, training from scratch for %d epochs.", - estimator.max_epochs) + self.logger.warning(msg) + + def _find_max_iteration(self, dir, prefix, start, end, saved_checkpoints=None): + error_msg = "Error parsing checkpoint file, please check your " \ + "checkpoints have the format: " \ + "{model_name}-epoch{epoch_number}batch{batch_number}.params, " \ + "there should also be a .states file for each .params file " + max_iter = -1 + for fname in os.listdir(dir): + if fname.startswith(prefix) and '.params' in fname: + if saved_checkpoints: + # save prefix of existing checkpoints + saved_checkpoints.append(fname[:fname.find('.params')]) + try: + # find trained number of epoch + iter = int(fname[fname.find(start) + len(start): fname.find(end)]) + if iter > max_iter: + max_iter = iter + except ValueError: + raise ValueError(error_msg) + return max_iter class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd): @@ -504,17 +594,17 @@ class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd): Parameters ---------- monitor: EvalMetric - the metrics to monitor + The metric to monitor, and stop training if this metric does not improve min_delta: float, default 0 - minimal change in monitored value to be considered as an improvement + Minimal change in monitored value to be considered as an improvement patience: int, default 0 - number of epochs to wait for improvement before terminate training + Number of epochs to wait for improvement before terminate training mode: str, default 'auto' - one of {auto, min, max}, if `save_best_only=True`, the comparison to make + One of {auto, min, max}, if `save_best_only=True`, the comparison to make and determine if the monitored value has improved. if 'auto' mode, checkpoint handler will try to use min or max based on the monitored metric name baseline: float - baseline value to compare the monitored value with + Baseline value to compare the monitored value with """ def __init__(self, diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index 315c5140fc39..a6822cb39a33 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -363,8 +363,9 @@ def test_default_handlers(): est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) # test handler order + train_metrics, val_metrics = est.prepare_loss_and_metrics() early_stopping = EarlyStoppingHandler(monitor=val_metrics[0]) handlers = est._prepare_default_handlers(val_data=None, event_handlers=[early_stopping]) - assert len(handlers) == 3 + assert len(handlers) == 4 assert isinstance(handlers[0], MetricHandler) - assert isinstance(handlers[2], LoggingHandler) + assert isinstance(handlers[3], LoggingHandler) diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index 12fc6ba133f5..7ea5ff3f4b62 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -40,7 +40,7 @@ def _get_test_data(): def test_checkpoint_handler(): with TemporaryDirectory() as tmpdir: - model_prefix = 'test_net' + model_prefix = 'test_epoch' file_path = os.path.join(tmpdir, model_prefix) test_data = _get_test_data() @@ -51,14 +51,15 @@ def test_checkpoint_handler(): checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir, model_prefix=model_prefix, monitor=acc, - save_best=True) + save_best=True, + epoch_period=1) est.fit(test_data, event_handlers=[checkpoint_handler], epochs=1) - assert checkpoint_handler.num_epochs == 1 - assert checkpoint_handler.num_batches == 4 + assert checkpoint_handler.current_epoch == 1 + assert checkpoint_handler.current_batch == 4 assert os.path.isfile(file_path + '-best.params') assert os.path.isfile(file_path + '-best.states') - assert os.path.isfile(file_path + '-epoch0.params') - assert os.path.isfile(file_path + '-epoch0.states') + assert os.path.isfile(file_path + '-epoch0batch4.params') + assert os.path.isfile(file_path + '-epoch0batch4.states') model_prefix = 'test_batch' file_path = os.path.join(tmpdir, model_prefix) @@ -68,20 +69,20 @@ def test_checkpoint_handler(): checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir, model_prefix=model_prefix, epoch_period=None, - batch_period=1, + batch_period=2, max_checkpoints=2) - est.fit(test_data, event_handlers=[checkpoint_handler], epochs=2) - assert checkpoint_handler.num_epochs == 2 - assert checkpoint_handler.num_batches == 8 + est.fit(test_data, event_handlers=[checkpoint_handler], batches=10) + assert checkpoint_handler.current_batch == 10 + assert checkpoint_handler.current_epoch == 3 assert not os.path.isfile(file_path + 'best.params') assert not os.path.isfile(file_path + 'best.states') - assert not os.path.isfile(file_path + '-batch0.params') - assert not os.path.isfile(file_path + '-batch0.states') + assert not os.path.isfile(file_path + '-epoch0batch0.params') + assert not os.path.isfile(file_path + '-epoch0batch0.states') assert os.path.isfile(file_path + '-symbol.json') - assert os.path.isfile(file_path + '-batch6.params') - assert os.path.isfile(file_path + '-batch6.states') - assert os.path.isfile(file_path + '-batch7.params') - assert os.path.isfile(file_path + '-batch7.states') + assert os.path.isfile(file_path + '-epoch1batch7.params') + assert os.path.isfile(file_path + '-epoch1batch7.states') + assert os.path.isfile(file_path + '-epoch2batch9.params') + assert os.path.isfile(file_path + '-epoch2batch9.states') def test_resume_checkpoint(): with TemporaryDirectory() as tmpdir: @@ -98,17 +99,17 @@ def test_resume_checkpoint(): monitor=acc, max_checkpoints=1) est.fit(test_data, event_handlers=[checkpoint_handler], epochs=2) - assert os.path.isfile(file_path + '-epoch1.params') - assert os.path.isfile(file_path + '-epoch1.states') + assert os.path.isfile(file_path + '-epoch1batch8.params') + assert os.path.isfile(file_path + '-epoch1batch8.states') checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir, model_prefix=model_prefix, monitor=acc, max_checkpoints=1, - resume_type='epoch') + resume_from_checkpoint=True) est.fit(test_data, event_handlers=[checkpoint_handler], epochs=5) # should only continue to train 3 epochs and last checkpoint file is epoch4 - assert est.max_epochs == 3 - assert os.path.isfile(file_path + '-epoch4.states') + assert est.max_epoch == 3 + assert os.path.isfile(file_path + '-epoch4batch20.states') def test_early_stopping(): From 88c00b2b6c0871e7b422476e0597a1626991a4a0 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Fri, 10 May 2019 02:37:51 -0700 Subject: [PATCH 7/9] update docs --- .../gluon/contrib/estimator/estimator.py | 47 ++--- .../gluon/contrib/estimator/event_handler.py | 160 ++++++++++-------- tests/python/unittest/test_gluon_estimator.py | 8 +- 3 files changed, 116 insertions(+), 99 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 33aba14b9b2f..da1a3915caec 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -40,17 +40,17 @@ class Estimator(object): Parameters ---------- net : Block - The model used for training + The model used for training. loss : gluon.loss.Loss or list of gluon.loss.Loss - Loss(objective functions) to calculate during training + Loss(objective functions) to calculate during training. metrics : EvalMetric or list of EvalMetric - Metrics for evaluating models + Metrics for evaluating models. initializer : Initializer - Initializer to initialize the network + Initializer to initialize the network. trainer : Trainer - Trainer to apply optimizer on network parameters + Trainer to apply optimizer on network parameters. context : Context or list of Context - Device(s) to run the training on + Device(s) to run the training on. """ def __init__(self, net, @@ -188,8 +188,8 @@ def prepare_loss_and_metrics(self): self.train_metrics.append(Loss(loss.name.rstrip('1234567890'))) for metric in self.train_metrics: val_metric = copy.deepcopy(metric) - metric.name = "Train " + metric.name - val_metric.name = "Validation " + val_metric.name + metric.name = "train " + metric.name + val_metric.name = "validation " + val_metric.name self.val_metrics.append(val_metric) return self.train_metrics, self.val_metrics @@ -202,11 +202,11 @@ def evaluate(self, Parameters ---------- val_data : DataLoader - Validation data loader with data and labels + Validation data loader with data and labels. val_metrics : EvalMetric or list of EvalMetrics - Metrics to update validation result + Metrics to update validation result. batch_axis : int, default 0 - Batch axis to split the validation data into devices + Batch axis to split the validation data into devices. """ if not isinstance(val_data, gluon.data.DataLoader): raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " @@ -233,26 +233,26 @@ def fit(self, train_data, event_handlers=None, batches=None, batch_axis=0): - """Trains the model on a given dataset for a specified - number of epochs. Also, the batch size is inferred from the - DataLoader's batch_size. + """Trains the model with a given :py:class:`DataLoader` for a specified + number of epochs or batches. The batch size is inferred from the + data loader's batch_size. Parameters ---------- train_data : DataLoader - Training data loader with data and labels - val_data : DataLoader - Validation data loader with data and labels + Training data loader with data and labels. + val_data : DataLoader, default None + Validation data loader with data and labels. epochs : int, default None Number of epochs to iterate on the training data. - You can only specify one and only one of epochs or batches. + You can only specify one and only one type of iteration(epochs or batches). event_handlers : EventHandler or list of EventHandler - List of EventHandlers to apply during training + List of :py:class:`EventHandlers` to apply during training. batches : int, default None Number of batches to iterate on the training data. - You can only specify one and only one of epochs or batches + You can only specify one and only one type of iteration(epochs or batches). batch_axis : int, default 0 - Batch axis to split the validation data into devices + Batch axis to split the training data into devices. """ if not isinstance(train_data, gluon.data.DataLoader): raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " @@ -355,6 +355,7 @@ def _prepare_default_handlers(self, val_data, event_handlers): "Please use the same set of metrics for all your other handlers." % \ ", ".join(default_handlers) warnings.warn(msg) + # check if all handlers has the same set of references to loss and metrics references = [] for handler in event_handlers: for attribute in dir(handler): @@ -364,8 +365,10 @@ def _prepare_default_handlers(self, val_data, event_handlers): references += reference else: references.append(reference) + # remove None metric references + references = set([ref for ref in references if ref]) for metric in references: - if metric and metric not in train_metrics + val_metrics: + if metric not in train_metrics + val_metrics: msg = "We have added following default handlers for you: %s and used " \ "estimator.prepare_loss_and_metrics() to pass metrics to " \ "those handlers. Please use the same set of metrics " \ diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 81140bed6f44..ce5890e0bcae 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -66,10 +66,11 @@ class StoppingHandler(TrainBegin, BatchEnd, EpochEnd): Parameters ---------- + max_epoch : int, default None + Number of maximum epochs to train. max_batch : int, default None Number of maximum batches to train. - max_epoch : int, default None - Number of maximum epochs to train + """ def __init__(self, max_epoch=None, max_batch=None): @@ -108,7 +109,7 @@ class MetricHandler(EpochBegin, BatchEnd): Parameters ---------- train_metrics : List of EvalMetrics - Training metrics to be updated at batch end + Training metrics to be updated at batch end. """ def __init__(self, train_metrics): @@ -143,18 +144,18 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd): Parameters ---------- val_data : DataLoader - Validation data set to run evaluation + Validation data set to run evaluation. eval_fn : function A function defines how to run evaluation and - calculate loss and metrics + calculate loss and metrics. val_metrics : List of EvalMetrics - Validation metrics to be updated + Validation metrics to be updated. epoch_period : int, default 1 How often to run validation at epoch end, by default - :py:class:`ValidationHandler` validate every epoch + :py:class:`ValidationHandler` validate every epoch. batch_period : int, default None How often to run validation at batch end, by default - :py:class:`ValidationHandler` does not validate at batch end + :py:class:`ValidationHandler` does not validate at batch end. """ def __init__(self, @@ -173,6 +174,7 @@ def __init__(self, # order to be called among all callbacks # validation metrics need to be calculated before other callbacks can access them self.priority = -np.Inf + self.logger = logging.getLogger(__name__) def train_begin(self, estimator, *args, **kwargs): # reset epoch and batch counter @@ -184,6 +186,12 @@ def batch_end(self, estimator, *args, **kwargs): if self.batch_period and self.current_batch % self.batch_period == 0: self.eval_fn(val_data=self.val_data, val_metrics=self.val_metrics) + msg = '[Epoch %d] ValidationHandler: %d batches reached, ' \ + % (self.current_epoch, self.current_batch) + for monitor in self.val_metrics: + name, value = monitor.get() + msg += '%s: %.4f, ' % (name, value) + self.logger.info(msg.rstrip(',')) def epoch_end(self, estimator, *args, **kwargs): self.current_epoch += 1 @@ -201,28 +209,28 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat Parameters ---------- file_name : str - File name to save the logs + File name to save the logs. file_location : str - File location to save the logs + File location to save the logs. filemode : str, default 'a' - Logging file mode, default using append mode - verbose : int, default LOG_VERBOSITY_PER_EPOCH - Limit the granularity of metrics displayed during training process - verbose=LOG_VERBOSITY_PER_EPOCH: display metrics every epoch - verbose=LOG_VERBOSITY_PER_BATCH: display metrics every batch + Logging file mode, default using append mode. + verbose : int, default LOG_PER_EPOCH + Limit the granularity of metrics displayed during training process. + verbose=LOG_PER_EPOCH: display metrics every epoch + verbose=LOG_PER_BATCH: display metrics every batch train_metrics : list of EvalMetrics - Training metrics to be logged, logged at batch end, epoch end, train end + Training metrics to be logged, logged at batch end, epoch end, train end. val_metrics : list of EvalMetrics - Validation metrics to be logged, logged at epoch end, train end + Validation metrics to be logged, logged at epoch end, train end. """ - LOG_VERBOSITY_PER_EPOCH = 1 - LOG_VERBOSITY_PER_BATCH = 2 + LOG_PER_EPOCH = 1 + LOG_PER_BATCH = 2 def __init__(self, file_name=None, file_location=None, filemode='a', - verbose=LOG_VERBOSITY_PER_EPOCH, + verbose=LOG_PER_EPOCH, train_metrics=None, val_metrics=None): super(LoggingHandler, self).__init__() @@ -236,10 +244,10 @@ def __init__(self, file_name=None, file_location = file_location or './' file_handler = logging.FileHandler(os.path.join(file_location, file_name), mode=filemode) self.logger.addHandler(file_handler) - if verbose not in [self.LOG_VERBOSITY_PER_EPOCH, self.LOG_VERBOSITY_PER_BATCH]: - raise ValueError("verbose level must be either LOG_VERBOSITY_PER_EPOCH or " - "LOG_VERBOSITY_PER_BATCH, received %s. " - "E.g: LoggingHandler(verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH)" + if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]: + raise ValueError("verbose level must be either LOG_PER_EPOCH or " + "LOG_PER_BATCH, received %s. " + "E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)" % verbose) self.verbose = verbose self.train_metrics = train_metrics or [] @@ -270,12 +278,12 @@ def train_begin(self, estimator, *args, **kwargs): def train_end(self, estimator, *args, **kwargs): train_time = time.time() - self.train_start - msg = 'Train finished using total %ds with %d epochs.' % (train_time, self.current_epoch) + msg = 'Train finished using total %ds with %d epochs. ' % (train_time, self.current_epoch) # log every result in train stats including train/validation loss & metrics for metric in self.train_metrics + self.val_metrics: name, value = metric.get() - msg += '%s : %.4f ' % (name, value) - self.logger.info(msg) + msg += '%s: %.4f, ' % (name, value) + self.logger.info(msg.rstrip(', ')) # make a copy of handler list and remove one by one # as removing handler will edit the handler list for handler in self.logger.handlers[:]: @@ -284,37 +292,37 @@ def train_end(self, estimator, *args, **kwargs): logging.shutdown() def batch_begin(self, estimator, *args, **kwargs): - if self.verbose == self.LOG_VERBOSITY_PER_BATCH: + if self.verbose == self.LOG_PER_BATCH: self.batch_start = time.time() def batch_end(self, estimator, *args, **kwargs): - if self.verbose == self.LOG_VERBOSITY_PER_BATCH: + if self.verbose == self.LOG_PER_BATCH: batch_time = time.time() - self.batch_start - msg = '[Epoch %d] [Batch %d] ' % (self.current_epoch, self.batch_index) + msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index) self.processed_samples += kwargs['batch'][0].shape[0] msg += '[Samples %s] ' % (self.processed_samples) msg += 'time/batch: %.3fs ' % batch_time for metric in self.train_metrics: # only log current training loss & metric after each batch name, value = metric.get() - msg += '%s : %.4f ' % (name, value) - self.logger.info(msg) + msg += '%s: %.4f, ' % (name, value) + self.logger.info(msg.rstrip(', ')) self.batch_index += 1 def epoch_begin(self, estimator, *args, **kwargs): - if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH: + if self.verbose >= self.LOG_PER_EPOCH: self.epoch_start = time.time() - self.logger.info("[Epoch %d] begin, current learning rate: %.4f", + self.logger.info("[Epoch %d] Begin, current learning rate: %.4f", self.current_epoch, estimator.trainer.learning_rate) def epoch_end(self, estimator, *args, **kwargs): - if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH: + if self.verbose >= self.LOG_PER_EPOCH: epoch_time = time.time() - self.epoch_start - msg = '[Epoch %d] finished in %.3fs: ' % (self.current_epoch, epoch_time) + msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time) for monitor in self.train_metrics + self.val_metrics: name, value = monitor.get() - msg += '%s : %.4f ' % (name, value) - self.logger.info(msg) + msg += '%s: %.4f, ' % (name, value) + self.logger.info(msg.rstrip(', ')) self.current_epoch += 1 self.batch_index = 0 @@ -332,25 +340,28 @@ class CheckpointHandler(TrainBegin, BatchEnd, EpochEnd): File directory to save all the model related files including model architecture, model parameters, and trainer states. model_prefix : str default 'model' - Prefix to add for all checkpoint file names - monitor: EvalMetric + Prefix to add for all checkpoint file names. + monitor: EvalMetric, default None The metrics to monitor and determine if model has improved verbose: int, default 0 Verbosity mode, 1 means inform user every time a checkpoint is saved - save_best: bool - If True, save the model parameters and trainer states with the best monitored value + save_best: bool, default False + If True, monitor must not be None, :py:class:`CheckpointHandler` will save the + model parameters and trainer states with the best monitored value. mode: str, default 'auto' - One of {auto, min, max}, if `save_best_only=True`, the comparison to make - and determine if the monitored value has improved. if 'auto' mode, checkpoint - handler will try to use min or max based on the monitored metric name + One of {auto, min, max}, if `save_best=True`, the comparison to make + and determine if the monitored value has improved. if 'auto' mode, + :py:class:`CheckpointHandler` will try to use min or max based on + the monitored metric name. epoch_period: int, default 1 - Epoch intervals between saving the network + Epoch intervals between saving the network. By default, checkpoints are + saved every epoch. batch_period: int, default None - Batch intervals between saving the network, - by default, checkpoints are not saved based on the number of batches + Batch intervals between saving the network. + By default, checkpoints are not saved based on the number of batches. max_checkpoints : int, default 5 Maximum number of checkpoint files to keep in the model_dir, older checkpoints - will be removed. Does not count best checkpoint + will be removed. Best checkpoint file is not counted. resume_from_checkpoint : bool, default False Whether to resume training from checkpoint in model_dir. If True and checkpoints found, :py:class:`CheckpointHandler` will load net parameters and trainer states, @@ -448,9 +459,6 @@ def epoch_end(self, estimator, *args, **kwargs): self.current_epoch += 1 def _save_checkpoint(self, estimator): - if self.verbose > 0: - self.logger.info('[Epoch %d][Batch %d] saving model to %s', - self.current_epoch, self.current_batch, self.model_dir) # if resumed from checkpoint, increment checkpoint number if self.resume_from_checkpoint: save_epoch_number = self.current_epoch + self.trained_epoch + 1 @@ -464,6 +472,10 @@ def _save_checkpoint(self, estimator): save_batch_number = self.current_batch prefix = "%s-epoch%dbatch%d" % (self.model_prefix, save_epoch_number, save_batch_number) self._save_params_and_trainer(estimator, prefix) + if self.verbose > 0: + self.logger.info('[Epoch %d] CheckpointHandler: trained total %d batches, ' + 'saving model at %s with prefix: %s', + self.current_epoch, self.current_batch + 1, self.model_dir, prefix) if self.save_best: monitor_name, monitor_value = self.monitor.get() @@ -475,19 +487,21 @@ def _save_checkpoint(self, estimator): 'create all metric objects', monitor_name)) else: if self.monitor_op(monitor_value, self.best): - if self.verbose > 0: - self.logger.info('[Epoch %d][Batch %d] %s improved from %0.5f to %0.5f,' - ' updating best model to %s', - self.current_epoch, self.current_batch, monitor_name, - self.best, monitor_value, self.model_dir) prefix = self.model_prefix + '-best' self._save_params_and_trainer(estimator, prefix) self.best = monitor_value + if self.verbose > 0: + self.logger.info('[Epoch %d] CheckpointHandler: ' + '%s improved from %0.5f to %0.5f, ' + 'updating best model at %s with prefix: %s', + self.current_epoch, monitor_name, + self.best, monitor_value, self.model_dir, prefix) else: if self.verbose > 0: - self.logger.info('[Epoch %d][Batch %d] %s did not improve from %0.5f, ' + self.logger.info('[Epoch %d] CheckpointHandler: ' + '%s did not improve from %0.5f, ' 'skipping updating best model', - self.current_epoch, self.current_batch, monitor_name, + self.current_batch, monitor_name, self.best) def _save_symbol(self, estimator): @@ -532,28 +546,28 @@ def _resume_from_checkpoint(self, estimator): end='.params') if self.trained_epoch == -1: - msg = "No checkpoint found, training from scratch for " + msg = "CheckpointHandler: No checkpoint found, training from scratch for " if estimator.max_batch: msg += "%d batches" % estimator.max_batch else: msg += "%d epochs" % estimator.max_epoch self.logger.info(msg) else: - msg = "Checkpoint resumed from epoch %d batch %d, continue to train for " % ( - self.trained_epoch, self.trained_batch) + msg = "CheckpointHandler: Checkpoint resumed from epoch %d batch %d, " \ + "continue to train for " % (self.trained_epoch, self.trained_batch) # change maximum number of epoch or batch to train if resumed from epoch checkpoint if estimator.max_epoch: if self.trained_epoch >= estimator.max_epoch - 1: - raise ValueError("Found checkpoint with maximum number of epoch %d reached, please specify" - "resume_from_checkpoint=False (default value) if you wan to train from scratch.", - estimator.max_epoch) + raise ValueError("Found checkpoint with maximum number of epoch %d reached, please specify " + "resume_from_checkpoint=False (default value) if you wan to train from scratch." + % estimator.max_epoch) estimator.max_epoch = estimator.max_epoch - self.trained_epoch - 1 msg += "%d epochs " % estimator.max_epoch if estimator.max_batch: if self.trained_batch >= estimator.max_batch - 1: raise ValueError("Found checkpoint with maximum number of batch %d reached, please specify" - "resume_from_checkpoint=False (default value) if you wan to train from scratch.", - self.trained_batch) + "resume_from_checkpoint=False (default value) if you wan to train from scratch." + % self.trained_batch) estimator.max_batch = estimator.max_batch - self.trained_batch - 1 msg += "%d batches " % estimator.max_batch # load checkpoint @@ -594,17 +608,17 @@ class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd): Parameters ---------- monitor: EvalMetric - The metric to monitor, and stop training if this metric does not improve + The metric to monitor, and stop training if this metric does not improve. min_delta: float, default 0 - Minimal change in monitored value to be considered as an improvement + Minimal change in monitored value to be considered as an improvement. patience: int, default 0 - Number of epochs to wait for improvement before terminate training + Number of epochs to wait for improvement before terminate training. mode: str, default 'auto' One of {auto, min, max}, if `save_best_only=True`, the comparison to make and determine if the monitored value has improved. if 'auto' mode, checkpoint - handler will try to use min or max based on the monitored metric name + handler will try to use min or max based on the monitored metric name. baseline: float - Baseline value to compare the monitored value with + Baseline value to compare the monitored value with. """ def __init__(self, @@ -687,5 +701,5 @@ def epoch_end(self, estimator, *args, **kwargs): def train_end(self, estimator, *args, **kwargs): if self.stopped_epoch > 0: - self.logger.info('Epoch %d: early stopping due to %s not improving', + self.logger.info('[Epoch %d] EarlyStoppingHanlder: early stopping due to %s not improving', self.stopped_epoch, self.monitor.get()[0]) diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index a6822cb39a33..d2e8c082aa08 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -341,6 +341,7 @@ def test_default_handlers(): assert 'You are training with the' in str(w[-1].message) # handler with prepared loss and metrics + # use mix of default and user defined handlers train_metrics, val_metrics = est.prepare_loss_and_metrics() logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics) with warnings.catch_warnings(record=True) as w: @@ -350,15 +351,14 @@ def test_default_handlers(): assert 'MetricHandler' in str(w[-1].message) # handler with all user defined metrics - val_metrics = [mx.metric.RMSE("val acc")] + # use mix of default and user defined handlers metric = MetricHandler(train_metrics=[train_acc]) - logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics) + logging = LoggingHandler(train_metrics=[train_acc], val_metrics=[mx.metric.RMSE("val acc")]) est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[metric, logging]) # handler with mixed metrics, some handler use metrics prepared by estimator # some handler use metrics user prepared - val_metrics = [mx.metric.RMSE("val acc")] - logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics) + logging = LoggingHandler(train_metrics=train_metrics, val_metrics=[mx.metric.RMSE("val acc")]) with assert_raises(ValueError): est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) From 1c4c853a5ded1a9673744f10c2a928c734ab2e35 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Mon, 13 May 2019 12:39:37 -0700 Subject: [PATCH 8/9] trigger ci From e0c37ca2bf70391a7fc2e17acf02a6ed791d4c47 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Mon, 13 May 2019 15:02:08 -0700 Subject: [PATCH 9/9] trigger ci