diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 3cdc407407c1..53ba07dc836a 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -227,29 +227,22 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat Parameters ---------- - verbose : int, default LOG_PER_EPOCH - Limit the granularity of metrics displayed during training process. - verbose=LOG_PER_EPOCH: display metrics every epoch - verbose=LOG_PER_BATCH: display metrics every batch + log_interval: int or str, default 'epoch' + Logging interval during training. + log_interval='epoch': display metrics every epoch + log_interval=integer k: display metrics every interval of k batches train_metrics : list of EvalMetrics Training metrics to be logged, logged at batch end, epoch end, train end. val_metrics : list of EvalMetrics Validation metrics to be logged, logged at epoch end, train end. """ - LOG_PER_EPOCH = 1 - LOG_PER_BATCH = 2 - - def __init__(self, verbose=LOG_PER_EPOCH, + def __init__(self, log_interval='epoch', train_metrics=None, val_metrics=None): super(LoggingHandler, self).__init__() - if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]: - raise ValueError("verbose level must be either LOG_PER_EPOCH or " - "LOG_PER_BATCH, received %s. " - "E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)" - % verbose) - self.verbose = verbose + if not isinstance(log_interval, int) and log_interval != 'epoch': + raise ValueError("log_interval must be either an integer or string 'epoch'") self.train_metrics = _check_metrics(train_metrics) self.val_metrics = _check_metrics(val_metrics) self.batch_index = 0 @@ -258,6 +251,7 @@ def __init__(self, verbose=LOG_PER_EPOCH, # logging handler need to be called at last to make sure all states are updated # it will also shut down logging at train end self.priority = np.Inf + self.log_interval = log_interval def train_begin(self, estimator, *args, **kwargs): self.train_start = time.time() @@ -275,6 +269,7 @@ def train_begin(self, estimator, *args, **kwargs): self.current_epoch = 0 self.batch_index = 0 self.processed_samples = 0 + self.log_interval_time = 0 def train_end(self, estimator, *args, **kwargs): train_time = time.time() - self.train_start @@ -286,31 +281,34 @@ def train_end(self, estimator, *args, **kwargs): estimator.logger.info(msg.rstrip(', ')) def batch_begin(self, estimator, *args, **kwargs): - if self.verbose == self.LOG_PER_BATCH: + if isinstance(self.log_interval, int): self.batch_start = time.time() def batch_end(self, estimator, *args, **kwargs): - if self.verbose == self.LOG_PER_BATCH: + if isinstance(self.log_interval, int): batch_time = time.time() - self.batch_start msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index) self.processed_samples += kwargs['batch'][0].shape[0] msg += '[Samples %s] ' % (self.processed_samples) - msg += 'time/batch: %.3fs ' % batch_time - for metric in self.train_metrics: - # only log current training loss & metric after each batch - name, value = metric.get() - msg += '%s: %.4f, ' % (name, value) - estimator.logger.info(msg.rstrip(', ')) + self.log_interval_time += batch_time + if self.batch_index % self.log_interval == 0: + msg += 'time/interval: %.3fs ' % self.log_interval_time + self.log_interval_time = 0 + for metric in self.train_metrics: + # only log current training loss & metric after each interval + name, value = metric.get() + msg += '%s: %.4f, ' % (name, value) + estimator.logger.info(msg.rstrip(', ')) self.batch_index += 1 def epoch_begin(self, estimator, *args, **kwargs): - if self.verbose >= self.LOG_PER_EPOCH: + if isinstance(self.log_interval, int) or self.log_interval == 'epoch': self.epoch_start = time.time() estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f", self.current_epoch, estimator.trainer.learning_rate) def epoch_end(self, estimator, *args, **kwargs): - if self.verbose >= self.LOG_PER_EPOCH: + if isinstance(self.log_interval, int) or self.log_interval == 'epoch': epoch_time = time.time() - self.epoch_start msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time) for monitor in self.train_metrics + self.val_metrics: diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index 17c75813d516..658fb88f47e5 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -17,13 +17,19 @@ import os import logging +import sys +import re import mxnet as mx from common import TemporaryDirectory from mxnet import nd from mxnet.gluon import nn, loss from mxnet.gluon.contrib.estimator import estimator, event_handler - +from mxnet.gluon.contrib.estimator.event_handler import LoggingHandler +try: + from StringIO import StringIO +except ImportError: + from io import StringIO def _get_test_network(net=nn.Sequential()): net.add(nn.Dense(128, activation='relu', flatten=False), @@ -32,9 +38,9 @@ def _get_test_network(net=nn.Sequential()): return net -def _get_test_data(): - data = nd.ones((32, 100)) - label = nd.zeros((32, 1)) +def _get_test_data(in_size=32): + data = nd.ones((in_size, 100)) + label = nd.zeros((in_size, 1)) data_arr = mx.gluon.data.dataset.ArrayDataset(data, label) return mx.gluon.data.DataLoader(data_arr, batch_size=8) @@ -200,3 +206,61 @@ def epoch_end(self, estimator, *args, **kwargs): est.fit(test_data, event_handlers=[custom_handler], epochs=10) assert custom_handler.num_batch == 5 * 4 assert custom_handler.num_epoch == 5 + +def test_logging_interval(): + ''' test different options for logging handler ''' + ''' test case #1: log interval is 1 ''' + batch_size = 8 + data_size = 100 + old_stdout = sys.stdout + sys.stdout = mystdout = StringIO() + log_interval = 1 + net = _get_test_network() + dataloader = _get_test_data(in_size=data_size) + num_epochs = 1 + ce_loss = loss.SoftmaxCrossEntropyLoss() + acc = mx.metric.Accuracy() + logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval) + est = estimator.Estimator(net=net, + loss=ce_loss, + metrics=acc) + + est.fit(train_data=dataloader, + epochs=num_epochs, + event_handlers=[logging]) + + sys.stdout = old_stdout + log_info_list = mystdout.getvalue().splitlines() + info_len = 0 + for info in log_info_list: + match = re.match( + '(\[Epoch \d+\]\[Batch \d+\]\[Samples \d+\] time\/interval: \d+.\d+s' + + ' training accuracy: \d+.\d+)', info) + if match: + info_len += 1 + + assert(info_len == int(data_size/batch_size/log_interval) + 1) + ''' test case #2: log interval is 5 ''' + old_stdout = sys.stdout + sys.stdout = mystdout = StringIO() + acc = mx.metric.Accuracy() + log_interval = 5 + logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval) + est = estimator.Estimator(net=net, + loss=ce_loss, + metrics=acc) + est.fit(train_data=dataloader, + epochs=num_epochs, + event_handlers=[logging]) + sys.stdout = old_stdout + log_info_list = mystdout.getvalue().splitlines() + info_len = 0 + for info in log_info_list: + match = re.match( + '(\[Epoch \d+\]\[Batch \d+\]\[Samples \d+\] time\/interval: \d+.\d+s' + + ' training accuracy: \d+.\d+)', info) + if match: + info_len += 1 + + assert(info_len == int(data_size/batch_size/log_interval) + 1) +