From ac8b4a91a0ac086ad99dfa28d3c356215129e7d4 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Wed, 3 Apr 2019 14:18:13 -0700 Subject: [PATCH] [MXNet-1340][Fit API]Update train stats (#14494) * add train history * update history * update test * avoid calling empty methods * remove train history object * fix pylint * add unit test * fix test * update categorize handlers --- python/mxnet/gluon/estimator/estimator.py | 147 +++++++------ python/mxnet/gluon/estimator/event_handler.py | 102 +++++---- python/mxnet/gluon/trainer.py | 7 + tests/python/unittest/test_gluon_estimator.py | 193 +++++++++++------- .../unittest/test_gluon_event_handler.py | 12 +- 5 files changed, 280 insertions(+), 181 deletions(-) diff --git a/python/mxnet/gluon/estimator/estimator.py b/python/mxnet/gluon/estimator/estimator.py index e759fa75e290..c5da0c0e5071 100644 --- a/python/mxnet/gluon/estimator/estimator.py +++ b/python/mxnet/gluon/estimator/estimator.py @@ -22,7 +22,7 @@ import copy import warnings -from .event_handler import LoggingHandler +from .event_handler import EventHandler, LoggingHandler from ... import gluon, autograd from ...context import Context, cpu, gpu, num_gpus from ...io import DataIter @@ -39,27 +39,26 @@ class Estimator(object): Parameters ---------- - loss : Loss or list of Loss + loss : gluon.loss.Loss or list of gluon.loss.Loss Loss(objective functions) to calculate during training metrics : EvalMetric or list of EvalMetric Metrics for evaluating models initializer : Initializer initializer to initialize the network - trainers : Trainer or list of Trainer - Trainers to apply optimizers on network parameters + trainer : Trainer + Trainer to apply optimizer on network parameters context : Context or list of Context devices to run the training on """ def __init__(self, net, - loss=None, + loss, metrics=None, initializer=None, - trainers=None, + trainer=None, context=None): self.net = net - self.stop_training = False if isinstance(loss, gluon.loss.Loss): self.loss = [loss] @@ -86,27 +85,14 @@ def __init__(self, net, # store training statistics self.train_stats = {} - self.train_stats['epochs'] = [] - self.train_stats['learning_rate'] = [] - # current step of the epoch - self.train_stats['step'] = '' - for metric in self.train_metrics: - # record a history of metrics over each epoch - self.train_stats['train_' + metric.name] = [] - # only record the latest metric numbers after each batch - self.train_stats['batch_' + metric.name] = 0. - for metric in self.val_metrics: - self.train_stats['val_' + metric.name] = [] + + # separate train and validation self.train_loss_metrics = [] self.val_loss_metrics = [] # using the metric wrapper for loss to record loss value for l in self.loss: self.train_loss_metrics.append(Loss(l.name)) self.val_loss_metrics.append(Loss(l.name)) - self.train_stats['train_' + l.name] = [] - self.train_stats['val_' + l.name] = [] - # only record the latest loss numbers after each batch - self.train_stats['batch_' + l.name] = 0. # handle context if isinstance(context, Context): @@ -127,7 +113,6 @@ def __init__(self, net, raise ValueError("context must be a Context or a list of Context, " "refer to mxnet.Context:{}".format(context)) - # initialize the network self.initializer = initializer if self.initializer: @@ -135,7 +120,7 @@ def __init__(self, net, # if already initialized, re-init with user specified initializer warnings.warn("Network already initialized, re-initializing with %s. " "You don't need to pass initializer if you already " - "initialized your net."% type(self.initializer).__name__) + "initialized your net." % type(self.initializer).__name__) self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=True) else: # initialize with user specified initializer @@ -144,16 +129,17 @@ def __init__(self, net, if not self._is_initialized(): self.net.initialize(ctx=self.context) - # handle trainers - if isinstance(trainers, gluon.Trainer): - self.trainers = [trainers] - elif not trainers: + # handle trainer + if not trainer: warnings.warn("No trainer specified, default SGD optimizer " "with learning rate 0.001 is used.") - self.trainers = [gluon.Trainer(self.net.collect_params(), - 'sgd', {'learning_rate': 0.001})] + self.trainer = gluon.Trainer(self.net.collect_params(), + 'sgd', {'learning_rate': 0.001}) + elif not isinstance(trainer, gluon.Trainer): + raise ValueError("Trainer must be a Gluon Trainer instance, refer to " + "gluon.Trainer:{}".format(trainer)) else: - raise ValueError("Invalid trainer specified, please provide a valid gluon.Trainer") + self.trainer = trainer def _is_initialized(self): param_dict = self.net.collect_params() @@ -212,8 +198,12 @@ def evaluate(self, # update metrics for metric in self.val_metrics: metric.update(label, pred) + name, value = metric.get() + self.train_stats['val_' + name] = value for loss, loss_metric, in zip(losses, self.val_loss_metrics): loss_metric.update(0, [l for l in loss]) + name, value = loss_metric.get() + self.train_stats['val_' + name] = value def fit(self, train_data, val_data=None, @@ -241,27 +231,38 @@ def fit(self, train_data, from a data batch and load into contexts(devices) """ - - self.epochs = epochs + self.max_epoch = epochs if not batch_size: - batch_size = 32 * len(self.context) + self.batch_size = 32 * len(self.context) + else: + self.batch_size = batch_size + self.stop_training = False + self.samples = None + self.batch_idx = 0 event_handlers = event_handlers or [] # provide default logging handler if not event_handlers or \ not any(isinstance(handler, LoggingHandler) for handler in event_handlers): - event_handlers.append(LoggingHandler(self)) + event_handlers.append(LoggingHandler()) - # training begin + train_begin, epoch_begin, batch_begin, \ + batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers) + + # passing estimator to event handlers so they can access estimator information + # when a event is triggered for handler in event_handlers: + handler.estimator = self + + # training begin + for handler in train_begin: handler.train_begin() - for epoch in range(epochs): + for epoch in range(self.max_epoch): # epoch begin - self.train_stats['epochs'].append(epoch) - self.train_stats['learning_rate'].append(self.trainers[0].learning_rate) + self.current_epoch = epoch - for handler in event_handlers: + for handler in epoch_begin: handler.epoch_begin() for metric in self.train_metrics + self.train_loss_metrics: @@ -282,7 +283,7 @@ def fit(self, train_data, data, label = batch_fn(batch, self.context) # batch begin - for handler in event_handlers: + for handler in batch_begin: handler.batch_begin() with autograd.record(): @@ -298,42 +299,64 @@ def fit(self, train_data, # update train metrics for metric in self.train_metrics: metric.update(label, pred) - self.train_stats['batch_' + metric.name] = metric.get()[1] + # get metric name and current value and update train stats + name, value = metric.get() + self.train_stats['train_' + name] = value + + # update loss for loss, loss_metric, in zip(losses, self.train_loss_metrics): loss_metric.update(0, [l for l in loss]) - self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1] - - try: - completed_samples = len(train_data._dataset) if i == len(train_data._dataset) - 1 \ - else batch_size * (i + 1) - # We need to check if this is the last batch in the current epoch and select - # the value to print appropriately - self.train_stats['step'] = "{}/{}".format(completed_samples, len(train_data._dataset)) - except AttributeError: - self.train_stats['step'] = i + name, value = loss_metric.get() + self.train_stats['train_' + name] = value - for trainer in self.trainers: - trainer.step(batch_size) + self.batch_idx = i + # record trained samples v.s. total samples if using Gluon DataLoader + if isinstance(train_data, gluon.data.DataLoader): + self.samples = "{}/{}".format(self.batch_size * (i + 1), len(train_data._dataset)) + self.trainer.step(self.batch_size) # batch end - for handler in event_handlers: + for handler in batch_end: handler.batch_end() if val_data: self.evaluate(val_data, batch_fn) - for metric in self.train_metrics + self.train_loss_metrics: - self.train_stats['train_' + metric.name].append(metric.get()[1]) - for metric in self.val_metrics + self.val_loss_metrics: - self.train_stats['val_' + metric.name].append(metric.get()[1]) - # epoch end - for handler in event_handlers: + for handler in epoch_end: handler.epoch_end() if self.stop_training: break # train end - for handler in event_handlers: + for handler in train_end: handler.train_end() + + def _categorize_handlers(self, event_handlers): + """ + categorize handlers into 6 event lists to avoid calling empty methods + for example, only event handlers with train_begin method + implemented will be called at train begin + """ + + train_begin = [] + epoch_begin = [] + batch_begin = [] + batch_end = [] + epoch_end = [] + train_end = [] + for handler in event_handlers: + if not handler.__class__.train_begin == EventHandler.train_begin: + train_begin.append(handler) + if not handler.__class__.epoch_begin == EventHandler.epoch_begin: + epoch_begin.append(handler) + if not handler.__class__.batch_begin == EventHandler.batch_begin: + batch_begin.append(handler) + if not handler.__class__.batch_end == EventHandler.batch_end: + batch_end.append(handler) + if not handler.__class__.epoch_end == EventHandler.epoch_end: + epoch_end.append(handler) + if not handler.__class__.train_end == EventHandler.train_end: + train_end.append(handler) + return train_begin, epoch_begin, batch_begin, batch_end, epoch_end, train_end diff --git a/python/mxnet/gluon/estimator/event_handler.py b/python/mxnet/gluon/estimator/event_handler.py index c59644e8f726..781007464954 100644 --- a/python/mxnet/gluon/estimator/event_handler.py +++ b/python/mxnet/gluon/estimator/event_handler.py @@ -40,7 +40,16 @@ class EventHandler(object): estimator : Estimator The :py:class:`Estimator` to get training statistics """ - def __init__(self, estimator): + + def __init__(self): + self._estimator = None + + @property + def estimator(self): + return self._estimator + + @estimator.setter + def estimator(self, estimator): self._estimator = estimator def train_begin(self): @@ -78,8 +87,8 @@ class LoggingHandler(EventHandler): file location to save the logs """ - def __init__(self, estimator, file_name=None, file_location=None, ): - super(LoggingHandler, self).__init__(estimator) + def __init__(self, file_name=None, file_location=None): + super(LoggingHandler, self).__init__() self.logger = logging.getLogger(__name__) self.logger.setLevel(logging.INFO) stream_handler = logging.StreamHandler() @@ -92,22 +101,37 @@ def __init__(self, estimator, file_name=None, file_location=None, ): self.logger.addHandler(file_handler) def train_begin(self): - pass + self.train_start = time.time() + self.logger.info("Training begin: using optimizer %s " + "with current learning rate %.4f ", + self.estimator.trainer.optimizer.__class__.__name__, + self.estimator.trainer.learning_rate) + self.logger.info("Train for %d epochs.", self.estimator.max_epoch) def train_end(self): - pass + train_time = time.time() - self.train_start + epoch = self.estimator.current_epoch + msg = 'Train finished using total %ds at epoch %d. ' % (train_time, epoch) + # log every result in train stats including train/validation loss & metrics + for key in self.estimator.train_stats: + msg += '%s : %.4f ' % (key, self.estimator.train_stats[key]) + self.logger.info(msg) def batch_begin(self): self.batch_start = time.time() def batch_end(self): batch_time = time.time() - self.batch_start - epoch = self._estimator.train_stats['epochs'][-1] - step = self._estimator.train_stats['step'] - msg = '[Epoch %d] [Step %s] time/step: %.3fs ' % (epoch, step, batch_time) - for key in self._estimator.train_stats.keys(): - if key.startswith('batch_'): - msg += key[6:] + ': ' + '%.4f ' % self._estimator.train_stats[key] + epoch = self.estimator.current_epoch + batch = self.estimator.batch_idx + msg = '[Epoch %d] [Batch %d] ' % (epoch, batch) + if self.estimator.samples: + msg += '[Samples %s] ' % (self.estimator.samples) + msg += 'time/batch: %.3fs ' % batch_time + for key in self.estimator.train_stats: + # only log current training loss & metric after each batch + if key.startswith('train_'): + msg += key + ': ' + '%.4f ' % self.estimator.train_stats[key] self.logger.info(msg) def epoch_begin(self): @@ -115,11 +139,11 @@ def epoch_begin(self): def epoch_end(self): epoch_time = time.time() - self.epoch_start - epoch = self._estimator.train_stats['epochs'][-1] + epoch = self.estimator.current_epoch msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time) - for key in self._estimator.train_stats.keys(): - if key.startswith('train_') or key.startswith('val_'): - msg += key + ': ' + '%.4f ' % self._estimator.train_stats[key][epoch] + # log every result in train stats including train/validation loss & metrics + for key in self.estimator.train_stats: + msg += '%s : %.4f ' % (key, self.estimator.train_stats[key]) self.logger.info(msg) @@ -148,14 +172,14 @@ class CheckpointHandler(EventHandler): intervals between saving the network """ - def __init__(self, estimator, + def __init__(self, filepath, - monitor='val_loss', + monitor='val_accuracy', verbose=0, save_best_only=False, mode='auto', period=1): - super(CheckpointHandler, self).__init__(estimator) + super(CheckpointHandler, self).__init__() self.monitor = monitor self.verbose = verbose self.filepath = filepath @@ -186,7 +210,7 @@ def __init__(self, estimator, self.best = np.Inf def epoch_end(self, ): - epoch = self._estimator.train_stats['epochs'][-1] + epoch = self.estimator.current_epoch # add extension for weights if '.params' not in self.filepath: self.filepath += '.params' @@ -194,20 +218,21 @@ def epoch_end(self, ): if self.epochs_since_last_save >= self.period: self.epochs_since_last_save = 0 if self.save_best_only: - # check if monitor exists in train_stats - if self.monitor not in self._estimator.train_stats: - warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure' - 'you are passing one of the metric names as monitor', self.monitor)) - self._estimator.net.save_parameters(self.filepath) + # check if monitor exists in train stats + if self.monitor not in self.estimator.train_stats: + warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure the monitor value' + 'starts with `train_ `or `val_` and contains loss/metric name, ', + 'for example val_accuracy', self.monitor)) + self.estimator.net.save_parameters(self.filepath) else: - current = self._estimator.train_stats[self.monitor][-1] + current = self.estimator.train_stats[self.monitor] if self.monitor_op(current, self.best): if self.verbose > 0: self.logger.info('\n[Epoch %d] %s improved from %0.5f to %0.5f,' ' saving model to %s', epoch, self.monitor, self.best, current, self.filepath) self.best = current - self._estimator.net.save_parameters(self.filepath) + self.estimator.net.save_parameters(self.filepath) else: if self.verbose > 0: self.logger.info('\n[Epoch %d] %s did not improve from %0.5f, skipping save model', @@ -215,7 +240,7 @@ def epoch_end(self, ): else: if self.verbose > 0: logging.info('\nEpoch %d: saving model to %s', epoch, self.filepath) - self._estimator.net.save_parameters(self.filepath) + self.estimator.net.save_parameters(self.filepath) class EarlyStoppingHandler(EventHandler): @@ -238,15 +263,14 @@ class EarlyStoppingHandler(EventHandler): baseline value to compare the monitored value with """ - def __init__(self, estimator, - monitor='val_loss', + def __init__(self, + monitor='val_accuracy', min_delta=0, patience=0, mode='auto', baseline=None): - super(EarlyStoppingHandler, self).__init__(estimator) + super(EarlyStoppingHandler, self).__init__() - self._estimator = estimator self.monitor = monitor self.baseline = baseline self.patience = patience @@ -284,15 +308,13 @@ def train_begin(self): self.best = np.Inf if self.monitor_op == np.less else -np.Inf def epoch_end(self): - epoch = self._estimator.train_stats['epochs'][-1] - if self.monitor not in self._estimator.train_stats: - warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure' - 'you are passing one of the metric names as monitor', self.monitor)) + epoch = self.estimator.current_epoch + if self.monitor not in self.estimator.train_stats: + warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure the monitor value' + 'starts with `train_ `or `val_` and contains loss/metric name, ', + 'for example val_accuracy', self.monitor)) else: - current = self._estimator.train_stats[self.monitor][-1] - if current is None: - return - + current = self.estimator.train_stats[self.monitor] if self.monitor_op(current - self.min_delta, self.best): self.best = current self.wait = 0 @@ -300,7 +322,7 @@ def epoch_end(self): self.wait += 1 if self.wait >= self.patience: self.stopped_epoch = epoch - self._estimator.stop_training = True + self.estimator.stop_training = True def train_end(self): if self.stopped_epoch > 0: diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 6935c2752e1a..0939490a8307 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -255,6 +255,13 @@ def learning_rate(self): else: return self._optimizer.learning_rate + @property + def optimizer(self): + if isinstance(self._optimizer, opt.Optimizer): + return self._optimizer + else: + raise UserWarning("Optimizer has not been initialized yet") + def set_learning_rate(self, lr): """Sets a new learning rate of the optimizer. diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index 85e61ceb364d..25a410e93479 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -17,14 +17,15 @@ ''' Unit tests for Gluon Estimator ''' -import unittest import sys +import unittest import warnings -from nose.tools import assert_raises + import mxnet as mx from mxnet import gluon from mxnet.gluon import nn -from mxnet.gluon.estimator import estimator +from mxnet.gluon.estimator import Estimator, EventHandler +from nose.tools import assert_raises def get_model(): @@ -43,11 +44,11 @@ def test_fit(): acc = mx.metric.Accuracy() net.initialize(ctx=ctx) trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) - est = estimator.Estimator(net=net, - loss=loss, - metrics=acc, - trainers=trainer, - context=ctx) + est = Estimator(net=net, + loss=loss, + metrics=acc, + trainer=trainer, + context=ctx) in_data = mx.nd.random.uniform(shape=(10, 3)) out_data = mx.nd.random.uniform(shape=(10, 4)) # Input dataloader @@ -80,11 +81,11 @@ def test_validation(): acc = mx.metric.Accuracy() net.initialize(ctx=ctx) trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) - est = estimator.Estimator(net=net, - loss=loss, - metrics=acc, - trainers=trainer, - context=ctx) + est = Estimator(net=net, + loss=loss, + metrics=acc, + trainer=trainer, + context=ctx) in_data = mx.nd.random.uniform(shape=(10, 3)) out_data = mx.nd.random.uniform(shape=(10, 4)) # Input dataloader @@ -125,10 +126,10 @@ def test_initializer(): loss = gluon.loss.L2Loss() acc = mx.metric.Accuracy() # no initializer - est = estimator.Estimator(net=net, - loss=loss, - metrics=acc, - context=ctx) + est = Estimator(net=net, + loss=loss, + metrics=acc, + context=ctx) est.fit(train_data=train_data, epochs=num_epochs, batch_size=batch_size) @@ -139,12 +140,12 @@ def test_initializer(): trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) # catch reinit warning with warnings.catch_warnings(record=True) as w: - est = estimator.Estimator(net=net, - loss=loss, - metrics=acc, - initializer=mx.init.MSRAPrelu(), - trainers=trainer, - context=ctx) + est = Estimator(net=net, + loss=loss, + metrics=acc, + initializer=mx.init.MSRAPrelu(), + trainer=trainer, + context=ctx) assert 'Network already initialized' in str(w[-1].message) est.fit(train_data=train_data, epochs=num_epochs, @@ -167,10 +168,10 @@ def test_trainer(): net.initialize(ctx=ctx) # input no trainer with warnings.catch_warnings(record=True) as w: - est = estimator.Estimator(net=net, - loss=loss, - metrics=acc, - context=ctx) + est = Estimator(net=net, + loss=loss, + metrics=acc, + context=ctx) assert 'No trainer specified' in str(w[-1].message) est.fit(train_data=train_data, epochs=num_epochs, @@ -179,11 +180,11 @@ def test_trainer(): # input invalid trainer trainer = 'sgd' with assert_raises(ValueError): - est = estimator.Estimator(net=net, - loss=loss, - metrics=acc, - trainers=trainer, - context=ctx) + est = Estimator(net=net, + loss=loss, + metrics=acc, + trainer=trainer, + context=ctx) def test_metric(): @@ -200,59 +201,54 @@ def test_metric(): net.initialize(ctx=ctx) trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) # input no metric - est = estimator.Estimator(net=net, - loss=loss, - trainers=trainer, - context=ctx) + est = Estimator(net=net, + loss=loss, + trainer=trainer, + context=ctx) est.fit(train_data=train_data, epochs=num_epochs, batch_size=batch_size) # input list of metrics metrics = [mx.metric.Accuracy(), mx.metric.Accuracy()] - est = estimator.Estimator(net=net, - loss=loss, - metrics=metrics, - trainers=trainer, - context=ctx) + est = Estimator(net=net, + loss=loss, + metrics=metrics, + trainer=trainer, + context=ctx) est.fit(train_data=train_data, epochs=num_epochs, batch_size=batch_size) # input invalid metric with assert_raises(ValueError): - est = estimator.Estimator(net=net, - loss=loss, - metrics='acc', - trainers=trainer, - context=ctx) + est = Estimator(net=net, + loss=loss, + metrics='acc', + trainer=trainer, + context=ctx) # test default metric loss = gluon.loss.SoftmaxCrossEntropyLoss() - est = estimator.Estimator(net=net, - loss=loss, - trainers=trainer, - context=ctx) + est = Estimator(net=net, + loss=loss, + trainer=trainer, + context=ctx) assert isinstance(est.train_metrics[0], mx.metric.Accuracy) def test_loss(): - ''' test with no loss, invalid loss ''' + ''' test with invalid loss ''' net = get_model() ctx = mx.cpu() acc = mx.metric.Accuracy() net.initialize(ctx=ctx) trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) - # input no loss - with assert_raises(ValueError): - est = estimator.Estimator(net=net, - trainers=trainer, - metrics=acc, - context=ctx) # input invalid loss with assert_raises(ValueError): - est = estimator.Estimator(net=net, - loss='mse', - metrics=acc, - trainers=trainer, - context=ctx) + est = Estimator(net=net, + loss='mse', + metrics=acc, + trainer=trainer, + context=ctx) + def test_context(): ''' test with no context, list of context, invalid context ''' @@ -260,18 +256,69 @@ def test_context(): loss = gluon.loss.L2Loss() metrics = mx.metric.Accuracy() # input no context - est = estimator.Estimator(net=net, - loss=loss, - metrics=metrics) + est = Estimator(net=net, + loss=loss, + metrics=metrics) # input list of context ctx = [mx.gpu(0), mx.gpu(1)] - est = estimator.Estimator(net=net, - loss=loss, - metrics=metrics, - context=ctx) + est = Estimator(net=net, + loss=loss, + metrics=metrics, + context=ctx) # input invalid context with assert_raises(ValueError): - est = estimator.Estimator(net=net, - loss=loss, - metrics=metrics, - context='cpu') + est = Estimator(net=net, + loss=loss, + metrics=metrics, + context='cpu') + + +def test_categorize_handlers(): + class CustomHandler1(EventHandler): + def __init__(self): + super(CustomHandler1, self).__init__() + + def train_begin(self): + print("custom train begin") + + class CustomHandler2(EventHandler): + def __init__(self): + super(CustomHandler2, self).__init__() + + def epoch_begin(self): + print("custom epoch begin") + + def batch_begin(self): + print("custom batch begin") + + def train_end(self): + print("custom train end") + + class CustomHandler3(EventHandler): + def __init__(self): + super(CustomHandler3, self).__init__() + + def epoch_begin(self): + print("custom epoch begin") + + def batch_begin(self): + print("custom batch begin") + + def batch_end(self): + print("custom batch end") + + def train_end(self): + print("custom train end") + + net = nn.Sequential() + net.add(nn.Dense(10)) + loss = gluon.loss.SoftmaxCrossEntropyLoss() + est = Estimator(net, loss=loss) + event_handlers = [CustomHandler1(), CustomHandler2(), CustomHandler3()] + train_begin, epoch_begin, batch_begin, \ + batch_end, epoch_end, train_end = est._categorize_handlers(event_handlers) + assert len(train_begin) == 1 + assert len(epoch_begin) == 2 + assert len(batch_begin) == 2 + assert len(batch_end) == 1 + assert len(train_end) == 2 diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index a551594d6430..ccbcb54b226b 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -45,7 +45,7 @@ def test_checkpoint_handler(): ce_loss = loss.SoftmaxCrossEntropyLoss() acc = mx.metric.Accuracy() est = estimator.Estimator(net, loss=ce_loss, metrics=acc) - checkpoint_handler = [event_handler.CheckpointHandler(est, file_path, + checkpoint_handler = [event_handler.CheckpointHandler(file_path, save_best_only=save_best_only, mode=mode)] est.fit(test_data, event_handlers=checkpoint_handler, epochs=1) @@ -63,15 +63,15 @@ def test_early_stopping(): ce_loss = loss.SoftmaxCrossEntropyLoss() acc = mx.metric.Accuracy() est = estimator.Estimator(net, loss=ce_loss, metrics=acc) - early_stopping = [event_handler.EarlyStoppingHandler(est, monitor, + early_stopping = [event_handler.EarlyStoppingHandler(monitor, patience=patience, - mode=mode)] - est.fit(test_data, event_handlers=early_stopping, epochs=1) + mode=mode)] + est.fit(test_data, event_handlers=early_stopping, epochs=3) mode = 'auto' monitor = 'train_accuracy' patience = 2 - early_stopping = [event_handler.EarlyStoppingHandler(est, monitor, + early_stopping = [event_handler.EarlyStoppingHandler(monitor, patience=patience, mode=mode)] est.fit(test_data, event_handlers=early_stopping, epochs=1) @@ -86,7 +86,7 @@ def test_logging(): ce_loss = loss.SoftmaxCrossEntropyLoss() acc = mx.metric.Accuracy() est = estimator.Estimator(net, loss=ce_loss, metrics=acc) - logging_handler = [event_handler.LoggingHandler(est, file_name=file_name, file_location=tmpdir)] + logging_handler = [event_handler.LoggingHandler(file_name=file_name, file_location=tmpdir)] est.fit(test_data, event_handlers=logging_handler, epochs=1) assert os.path.isfile(output_dir) os.remove(output_dir) \ No newline at end of file