From 2b2f85da175976d9d4964abd208ce5fc5911e574 Mon Sep 17 00:00:00 2001
From: Lai Wei
Date: Wed, 31 Jul 2019 10:01:21 -0700
Subject: [PATCH] fix test (#13)

---
 python/mxnet/gluon/contrib/estimator/estimator.py     | 4 ++--
 python/mxnet/gluon/contrib/estimator/event_handler.py | 9 +++++----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
index d077bcd4fdeb..b6142e100d96 100644
--- a/python/mxnet/gluon/contrib/estimator/estimator.py
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -153,7 +153,7 @@ def _check_trainer(self, trainer):
             warnings.warn("No trainer specified, default SGD optimizer "
                           "with learning rate 0.001 is used.")
             trainer = Trainer(self.net.collect_params(),
-                            'sgd', {'learning_rate': 0.001})
+                              'sgd', {'learning_rate': 0.001})
         elif not isinstance(trainer, Trainer):
             raise ValueError("Trainer must be a Gluon Trainer instance, refer to "
                              "gluon.Trainer:{}".format(trainer))
@@ -363,7 +363,7 @@ def _prepare_default_handlers(self, val_data, event_handlers):
 
         # if there is a mix of user defined event handlers and default event handlers
         # they should have the same set of loss and metrics
-        if default_handlers and len(event_handlers) > len(default_handlers):
+        if default_handlers and len(event_handlers) != len(default_handlers):
             msg = "You are training with the following default event handlers: %s. " \
                   "They use loss and metrics from estimator.prepare_loss_and_metrics(). " \
                   "Please use the same set of metrics for all your other handlers." % \
diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py
index 0c8dbd9a4e1d..da2c84455e35 100644
--- a/python/mxnet/gluon/contrib/estimator/event_handler.py
+++ b/python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -29,10 +29,11 @@
 from ....metric import EvalMetric
 from ....metric import Loss as metric_loss
 
-__all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd','BatchBegin', 'BatchEnd',
+__all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd',
            'StoppingHandler', 'MetricHandler', 'ValidationHandler', 'LoggingHandler',
            'CheckpointHandler', 'EarlyStoppingHandler']
 
+
 class TrainBegin(object):
     def train_begin(self, estimator, *args, **kwargs):
         pass
@@ -434,7 +435,7 @@ def train_begin(self, estimator, *args, **kwargs):
         self.current_epoch = 0
         self.current_batch = 0
         if self.save_best:
-            self.best = np.Inf if self.monitor_op == np.less else -np.Inf # pylint: disable=comparison-with-callable
+            self.best = np.Inf if self.monitor_op == np.less else -np.Inf  # pylint: disable=comparison-with-callable
         if self.resume_from_checkpoint:
             error_msg = "To use resume from checkpoint, you must only specify " \
                         "the same type of period you used for training." \
@@ -670,7 +671,7 @@ def __init__(self,
                                "if you want otherwise", self.monitor.get()[0])
             self.monitor_op = np.less
 
-        if self.monitor_op == np.greater: # pylint: disable=comparison-with-callable
+        if self.monitor_op == np.greater:  # pylint: disable=comparison-with-callable
             self.min_delta *= 1
         else:
             self.min_delta *= -1
@@ -683,7 +684,7 @@ def train_begin(self, estimator, *args, **kwargs):
         if self.baseline is not None:
             self.best = self.baseline
         else:
-            self.best = np.Inf if self.monitor_op == np.less else -np.Inf # pylint: disable=comparison-with-callable
+            self.best = np.Inf if self.monitor_op == np.less else -np.Inf  # pylint: disable=comparison-with-callable
 
     def epoch_end(self, estimator, *args, **kwargs):
         monitor_name, monitor_value = self.monitor.get()