
Commit

fix test (apache#13)

roywei committed Jul 31, 2019
1 parent a69b406 commit 2b2f85d
Showing 2 changed files with 7 additions and 6 deletions.
4 changes: 2 additions & 2 deletions python/mxnet/gluon/contrib/estimator/estimator.py
@@ -153,7 +153,7 @@ def _check_trainer(self, trainer):
             warnings.warn("No trainer specified, default SGD optimizer "
                           "with learning rate 0.001 is used.")
             trainer = Trainer(self.net.collect_params(),
-                            'sgd', {'learning_rate': 0.001})
+                              'sgd', {'learning_rate': 0.001})
         elif not isinstance(trainer, Trainer):
             raise ValueError("Trainer must be a Gluon Trainer instance, refer to "
                              "gluon.Trainer:{}".format(trainer))
@@ -363,7 +363,7 @@ def _prepare_default_handlers(self, val_data, event_handlers):
 
         # if there is a mix of user defined event handlers and default event handlers
         # they should have the same set of loss and metrics
-        if default_handlers and len(event_handlers) > len(default_handlers):
+        if default_handlers and len(event_handlers) != len(default_handlers):
             msg = "You are training with the following default event handlers: %s. " \
                   "They use loss and metrics from estimator.prepare_loss_and_metrics(). " \
                   "Please use the same set of metrics for all your other handlers." % \
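For context on the comparison fix above, a minimal standalone sketch of why "!=" is the right test; the handler names and the check_handlers helper here are hypothetical, not the estimator's actual API:

    # Sketch: the warning should fire on ANY size mismatch between the
    # handler list and the defaults, not only when the user supplied
    # extra handlers.
    default_handlers = ["MetricHandler", "LoggingHandler"]

    def check_handlers(event_handlers, default_handlers):
        # The old test, len(event_handlers) > len(default_handlers),
        # silently passed when the user list was shorter than the defaults.
        if default_handlers and len(event_handlers) != len(default_handlers):
            print("Warning: mixed user/default handlers; use the same "
                  "set of loss and metrics for all handlers.")

    check_handlers(["MetricHandler"], default_handlers)  # shorter list: now warns too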
9 changes: 5 additions & 4 deletions python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -29,10 +29,11 @@
 from ....metric import EvalMetric
 from ....metric import Loss as metric_loss
 
-__all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd','BatchBegin', 'BatchEnd',
+__all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd',
            'StoppingHandler', 'MetricHandler', 'ValidationHandler',
            'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler']
 
+
 class TrainBegin(object):
     def train_begin(self, estimator, *args, **kwargs):
         pass
@@ -434,7 +435,7 @@ def train_begin(self, estimator, *args, **kwargs):
         self.current_epoch = 0
         self.current_batch = 0
         if self.save_best:
-            self.best = np.Inf if self.monitor_op == np.less else -np.Inf # pylint: disable=comparison-with-callable
+            self.best = np.Inf if self.monitor_op == np.less else -np.Inf  # pylint: disable=comparison-with-callable
         if self.resume_from_checkpoint:
             error_msg = "To use resume from checkpoint, you must only specify " \
                         "the same type of period you used for training." \
@@ -670,7 +671,7 @@ def __init__(self,
                           "if you want otherwise", self.monitor.get()[0])
             self.monitor_op = np.less
 
-        if self.monitor_op == np.greater: # pylint: disable=comparison-with-callable
+        if self.monitor_op == np.greater:  # pylint: disable=comparison-with-callable
             self.min_delta *= 1
         else:
             self.min_delta *= -1
@@ -683,7 +684,7 @@ def train_begin(self, estimator, *args, **kwargs):
         if self.baseline is not None:
             self.best = self.baseline
         else:
-            self.best = np.Inf if self.monitor_op == np.less else -np.Inf # pylint: disable=comparison-with-callable
+            self.best = np.Inf if self.monitor_op == np.less else -np.Inf  # pylint: disable=comparison-with-callable
 
     def epoch_end(self, estimator, *args, **kwargs):
         monitor_name, monitor_value = self.monitor.get()
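The pylint-annotated lines above all implement the same min/max monitoring convention. A small self-contained sketch (hypothetical values; the handler plumbing is elided) of how monitor_op, best, and min_delta interact:

    import numpy as np

    # Direction of improvement: np.less for metrics like loss (smaller is
    # better), np.greater for metrics like accuracy (larger is better).
    monitor_op = np.less

    # Initialize 'best' to the worst value for the chosen direction, as in
    # the train_begin hunks above.
    best = np.Inf if monitor_op == np.less else -np.Inf

    # Flip min_delta's sign when minimizing so a single comparison covers
    # both directions, mirroring the 'self.min_delta *= -1' branch.
    min_delta = 0.01
    min_delta *= 1 if monitor_op == np.greater else -1

    val_loss = 0.42  # pretend epoch-end metric value
    if monitor_op(val_loss - min_delta, best):
        best = val_loss  # improved by at least min_delta; record new best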
