diff --git a/python/mxnet/gluon/estimator/estimator.py b/python/mxnet/gluon/estimator/estimator.py index c5da0c0e5071..529499129df8 100644 --- a/python/mxnet/gluon/estimator/estimator.py +++ b/python/mxnet/gluon/estimator/estimator.py @@ -21,11 +21,9 @@ import copy import warnings - from .event_handler import EventHandler, LoggingHandler from ... import gluon, autograd from ...context import Context, cpu, gpu, num_gpus -from ...io import DataIter from ...metric import EvalMetric, Loss, Accuracy __all__ = ['Estimator'] @@ -168,7 +166,7 @@ def evaluate(self, Parameters ---------- - val_data : DataLoader or DataIter + val_data : DataLoader validation data with data and labels batch_fn : function custom batch function to extract data and label @@ -182,13 +180,10 @@ def evaluate(self, if not batch_fn: if isinstance(val_data, gluon.data.DataLoader): data, label = self._batch_fn(batch, self.context) - elif isinstance(val_data, DataIter): - data, label = self._batch_fn(batch, self.context, is_iterator=True) else: raise ValueError("You are using a custom iteration, please also provide " "batch_fn to extract data and label. Alternatively, you " - "can provide the data as gluon.data.DataLoader or " - "mx.io.DataIter") + "can provide the data as gluon.data.DataLoader.") else: data, label = batch_fn(batch, self.context) pred = [self.net(x) for x in data] @@ -208,16 +203,17 @@ def evaluate(self, def fit(self, train_data, val_data=None, epochs=1, - batch_size=None, event_handlers=None, batch_fn=None): - """Main training loop + """Trains the model on a given dataset for a specified + number of epochs. Also, the batch size is inferred from the + DataLoader's batch_size. Parameters ---------- - train_data : DataLoader or DataIter + train_data : DataLoader training data with data and labels - val_data : DataLoader or DataIter + val_data : DataLoader validation data with data and labels epochs : int, default 1 number of epochs to iterate on the training data. 
@@ -232,12 +228,8 @@ def fit(self, train_data, """ self.max_epoch = epochs - if not batch_size: - self.batch_size = 32 * len(self.context) - else: - self.batch_size = batch_size self.stop_training = False - self.samples = None + self.processed_samples = None self.batch_idx = 0 event_handlers = event_handlers or [] @@ -245,6 +237,9 @@ def fit(self, train_data, if not event_handlers or \ not any(isinstance(handler, LoggingHandler) for handler in event_handlers): event_handlers.append(LoggingHandler()) + warnings.warn("No Event Handler specified, default `LoggingHandler()` " + "is used with verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH. " + "Please look at gluon.estimator.event_handler for more detail.") train_begin, epoch_begin, batch_begin, \ batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers) @@ -261,6 +256,8 @@ def fit(self, train_data, for epoch in range(self.max_epoch): # epoch begin self.current_epoch = epoch + # Number of samples trained after every batch + completed_samples = 0 for handler in epoch_begin: handler.epoch_begin() @@ -272,16 +269,15 @@ def fit(self, train_data, if not batch_fn: if isinstance(train_data, gluon.data.DataLoader): data, label = self._batch_fn(batch, self.context) - elif isinstance(train_data, DataIter): - data, label = self._batch_fn(batch, self.context, is_iterator=True) else: raise ValueError("You are using a custom iteration, please also provide " "batch_fn to extract data and label. Alternatively, you " - "can provide the data as gluon.data.DataLoader or " - "mx.io.DataIter") + "can provide the data as gluon.data.DataLoader") else: data, label = batch_fn(batch, self.context) + batch_size = batch[0].shape[0] + # batch begin for handler in batch_begin: handler.batch_begin() @@ -309,12 +305,15 @@ def fit(self, train_data, name, value = loss_metric.get() self.train_stats['train_' + name] = value + completed_samples += batch_size + self.batch_idx = i # record trained samples v.s. 
total samples if using Gluon DataLoader if isinstance(train_data, gluon.data.DataLoader): - self.samples = "{}/{}".format(self.batch_size * (i + 1), len(train_data._dataset)) + self.processed_samples = "{}/{}".format(completed_samples, + len(train_data._dataset)) - self.trainer.step(self.batch_size) + self.trainer.step(batch_size) # batch end for handler in batch_end: handler.batch_end() diff --git a/python/mxnet/gluon/estimator/event_handler.py b/python/mxnet/gluon/estimator/event_handler.py index 781007464954..53c0bf5bde86 100644 --- a/python/mxnet/gluon/estimator/event_handler.py +++ b/python/mxnet/gluon/estimator/event_handler.py @@ -85,14 +85,27 @@ class LoggingHandler(EventHandler): file name to save the logs file_location: str file location to save the logs + verbose: int, default LOG_VERBOSITY_PER_EPOCH + Limit the granularity of metrics displayed during training process + verbose=LOG_VERBOSITY_PER_EPOCH: display metrics every epoch + verbose=LOG_VERBOSITY_PER_BATCH: display metrics every batch """ - def __init__(self, file_name=None, file_location=None): + LOG_VERBOSITY_PER_EPOCH = 1 + LOG_VERBOSITY_PER_BATCH = 2 + + def __init__(self, file_name=None, file_location=None, verbose=LOG_VERBOSITY_PER_EPOCH): super(LoggingHandler, self).__init__() self.logger = logging.getLogger(__name__) self.logger.setLevel(logging.INFO) stream_handler = logging.StreamHandler() self.logger.addHandler(stream_handler) + if verbose not in [self.LOG_VERBOSITY_PER_EPOCH, self.LOG_VERBOSITY_PER_BATCH]: + raise ValueError("verbose level must be either LOG_VERBOSITY_PER_EPOCH or " + "LOG_VERBOSITY_PER_BATCH, received %s. 
" + "E.g: LoggingHandler(verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH)" + % verbose) + self.verbose = verbose # save logger to file only if file name or location is specified if file_name or file_location: file_name = file_name or 'estimator_log' @@ -118,33 +131,37 @@ def train_end(self): self.logger.info(msg) def batch_begin(self): - self.batch_start = time.time() + if self.verbose == self.LOG_VERBOSITY_PER_BATCH: + self.batch_start = time.time() def batch_end(self): - batch_time = time.time() - self.batch_start - epoch = self.estimator.current_epoch - batch = self.estimator.batch_idx - msg = '[Epoch %d] [Batch %d] ' % (epoch, batch) - if self.estimator.samples: - msg += '[Samples %s] ' % (self.estimator.samples) - msg += 'time/batch: %.3fs ' % batch_time - for key in self.estimator.train_stats: - # only log current training loss & metric after each batch - if key.startswith('train_'): - msg += key + ': ' + '%.4f ' % self.estimator.train_stats[key] - self.logger.info(msg) + if self.verbose == self.LOG_VERBOSITY_PER_BATCH: + batch_time = time.time() - self.batch_start + epoch = self.estimator.current_epoch + batch = self.estimator.batch_idx + msg = '[Epoch %d] [Batch %d] ' % (epoch, batch) + if self.estimator.processed_samples: + msg += '[Samples %s] ' % (self.estimator.processed_samples) + msg += 'time/batch: %.3fs ' % batch_time + for key in self.estimator.train_stats: + # only log current training loss & metric after each batch + if key.startswith('train_'): + msg += key + ': ' + '%.4f ' % self.estimator.train_stats[key] + self.logger.info(msg) def epoch_begin(self): - self.epoch_start = time.time() + if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH: + self.epoch_start = time.time() def epoch_end(self): - epoch_time = time.time() - self.epoch_start - epoch = self.estimator.current_epoch - msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time) - # log every result in train stats including train/validation loss & metrics - for key in 
self.estimator.train_stats: - msg += '%s : %.4f ' % (key, self.estimator.train_stats[key]) - self.logger.info(msg) + if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH: + epoch_time = time.time() - self.epoch_start + epoch = self.estimator.current_epoch + msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time) + # log every result in train stats including train/validation loss & metrics + for key in self.estimator.train_stats: + msg += '%s : %.4f ' % (key, self.estimator.train_stats[key]) + self.logger.info(msg) class CheckpointHandler(EventHandler): diff --git a/tests/nightly/estimator/test_estimator_cnn.py b/tests/nightly/estimator/test_estimator_cnn.py index b99e99af6406..b4311b3b3a1e 100644 --- a/tests/nightly/estimator/test_estimator_cnn.py +++ b/tests/nightly/estimator/test_estimator_cnn.py @@ -105,13 +105,12 @@ def test_estimator_cpu(): est = estimator.Estimator(net=net, loss=loss, metrics=mx.metric.Accuracy(), - trainers=trainer, + trainer=trainer, context=context) # Call fit() est.fit(train_data=train_data, val_data=val_data, - epochs=1, - batch_size=1) + epochs=1) def test_estimator_gpu(): ''' @@ -131,15 +130,14 @@ def test_estimator_gpu(): est = estimator.Estimator(net=net, loss=loss, metrics=acc, - trainers=trainer, + trainer=trainer, context=context) # Call fit() est.fit(train_data=train_data, val_data=test_data, - epochs=num_epochs, - batch_size=batch_size) + epochs=num_epochs) - assert est.train_stats['train_'+acc.name][num_epochs-1] > 0.80 + assert est.train_stats['train_'+acc.name] > 0.80 if __name__ == '__main__': parser = argparse.ArgumentParser(description='test gluon estimator') diff --git a/tests/nightly/estimator/test_sentiment_rnn.py b/tests/nightly/estimator/test_sentiment_rnn.py index 7e42831786ce..c9dcbd2c8050 100644 --- a/tests/nightly/estimator/test_sentiment_rnn.py +++ b/tests/nightly/estimator/test_sentiment_rnn.py @@ -179,10 +179,10 @@ def run(net, train_dataloader, test_dataloader, **kwargs): # Define estimator est = 
estimator.Estimator(net=net, loss=loss, metrics=acc, - trainers=trainer, context=ctx) + trainer=trainer, context=ctx) # Begin training est.fit(train_data=train_dataloader, val_data=test_dataloader, - epochs=num_epochs, batch_size=batch_size) + epochs=num_epochs) return est @@ -252,7 +252,7 @@ def test_estimator_gpu(**kwargs): est = run(net, train_dataloader, test_dataloader, **kwargs) - assert est.train_stats['train_accuracy'][num_epochs - 1] > 0.70 + assert est.train_stats['train_accuracy'] > 0.70 parser = argparse.ArgumentParser(description='test gluon estimator') diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index 25a410e93479..c86f4ff8587e 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -55,20 +55,18 @@ def test_fit(): dataset = gluon.data.dataset.ArrayDataset(in_data, out_data) train_dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size) est.fit(train_data=train_dataloader, - epochs=num_epochs, - batch_size=batch_size) + epochs=num_epochs) # Input dataiter train_dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size) - est.fit(train_data=train_dataiter, - epochs=num_epochs, - batch_size=batch_size) + with assert_raises(ValueError): + est.fit(train_data=train_dataiter, + epochs=num_epochs) # Input NDArray with assert_raises(ValueError): est.fit(train_data=[in_data, out_data], - epochs=num_epochs, - batch_size=batch_size) + epochs=num_epochs) def test_validation(): @@ -94,22 +92,20 @@ def test_validation(): val_dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size) est.fit(train_data=train_dataloader, val_data=val_dataloader, - epochs=num_epochs, - batch_size=batch_size) + epochs=num_epochs) # Input dataiter train_dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size) val_dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size) - 
est.fit(train_data=train_dataiter, - val_data=val_dataiter, - epochs=num_epochs, - batch_size=batch_size) + with assert_raises(ValueError): + est.fit(train_data=train_dataiter, + val_data=val_dataiter, + epochs=num_epochs) # Input NDArray with assert_raises(ValueError): est.fit(train_data=[in_data, out_data], val_data=[in_data, out_data], - epochs=num_epochs, - batch_size=batch_size) + epochs=num_epochs) @unittest.skipIf(sys.version_info.major < 3, 'Test on python 3') @@ -131,8 +127,7 @@ def test_initializer(): metrics=acc, context=ctx) est.fit(train_data=train_data, - epochs=num_epochs, - batch_size=batch_size) + epochs=num_epochs) # different initializer for net and estimator net = get_model() @@ -148,8 +143,7 @@ def test_initializer(): context=ctx) assert 'Network already initialized' in str(w[-1].message) est.fit(train_data=train_data, - epochs=num_epochs, - batch_size=batch_size) + epochs=num_epochs) @unittest.skipIf(sys.version_info.major < 3, 'Test on python 3') @@ -174,8 +168,7 @@ def test_trainer(): context=ctx) assert 'No trainer specified' in str(w[-1].message) est.fit(train_data=train_data, - epochs=num_epochs, - batch_size=batch_size) + epochs=num_epochs) # input invalid trainer trainer = 'sgd' @@ -206,8 +199,7 @@ def test_metric(): trainer=trainer, context=ctx) est.fit(train_data=train_data, - epochs=num_epochs, - batch_size=batch_size) + epochs=num_epochs) # input list of metrics metrics = [mx.metric.Accuracy(), mx.metric.Accuracy()] est = Estimator(net=net, @@ -216,8 +208,7 @@ def test_metric(): trainer=trainer, context=ctx) est.fit(train_data=train_data, - epochs=num_epochs, - batch_size=batch_size) + epochs=num_epochs) # input invalid metric with assert_raises(ValueError): est = Estimator(net=net, @@ -260,7 +251,9 @@ def test_context(): loss=loss, metrics=metrics) # input list of context - ctx = [mx.gpu(0), mx.gpu(1)] + gpus = mx.context.num_gpus() + ctx = [mx.gpu(i) for i in range(gpus)] if gpus > 0 else [mx.cpu()] + net = get_model() est = 
Estimator(net=net, loss=loss, metrics=metrics, diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index ccbcb54b226b..023b04691a39 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -30,7 +30,10 @@ def _get_test_network(): return net def _get_test_data(): - return mx.io.NDArrayIter(data=nd.ones((32, 100)), label=nd.random.randint(0, 10, (32, 1))) + data = nd.ones((32, 100)) + label = nd.random.randint(0, 10, (32, 1)) + data_arr = mx.gluon.data.dataset.ArrayDataset(data, label) + return mx.gluon.data.DataLoader(data_arr, batch_size=32) def test_checkpoint_handler():