Fix doc gen (apache#12)
* fix warning

* fix test

* fix

* fix

* fix print
roywei committed Jul 31, 2019
1 parent 55c54e5 · commit a69b406
Showing 4 changed files with 38 additions and 19 deletions.
docs/tutorials/gluon/fit_api_tutorial.md (22 changes: 15 additions & 7 deletions)

````diff
@@ -137,8 +137,12 @@ est = estimator.Estimator(net=resnet_18_v1,
                           trainer=trainer,
                           context=ctx)
 
-# Magic line
-est.fit(train_data=train_data_loader,
+# ignore warnings for nightly test on CI only
+import warnings
+with warnings.catch_warnings():
+    warnings.simplefilter("ignore")
+    # Magic line
+    est.fit(train_data=train_data_loader,
         epochs=num_epochs)
 ```
 
@@ -224,11 +228,15 @@ checkpoint_handler = CheckpointHandler(model_dir='./',
                                        save_best=True)  # Save the best model in terms of
 # Let's instantiate another handler which we defined above
 loss_record_handler = LossRecordHandler()
-# Magic line
-est.fit(train_data=train_data_loader,
-        val_data=val_data_loader,
-        epochs=num_epochs,
-        event_handlers=[checkpoint_handler, loss_record_handler])  # Add the event handlers
+# ignore warnings for nightly test on CI only
+import warnings
+with warnings.catch_warnings():
+    warnings.simplefilter("ignore")
+    # Magic line
+    est.fit(train_data=train_data_loader,
+            val_data=val_data_loader,
+            epochs=num_epochs,
+            event_handlers=[checkpoint_handler, loss_record_handler])  # Add the event handlers
 ```
 
 Training begin: using optimizer SGD with current learning rate 0.0400 <!--notebook-skip-line-->
````
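The tutorial change wraps `est.fit` in a `warnings.catch_warnings()` context so warnings are silenced only for the nightly CI run. `catch_warnings` saves the global filter state on entry and restores it on exit, so the `ignore` filter cannot leak past the `with` block. A minimal standalone sketch of the pattern (the `noisy` function is illustrative):

```python
import warnings

def noisy():
    warnings.warn("deprecated API", DeprecationWarning)
    return 42

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence all warnings inside this block
    result = noisy()                 # warning suppressed

noisy()  # outside the block, the previously active filters apply again
```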
python/mxnet/gluon/contrib/estimator/estimator.py (26 changes: 17 additions & 9 deletions)

```diff
@@ -334,28 +334,36 @@ def fit(self, train_data,
     def _prepare_default_handlers(self, val_data, event_handlers):
         event_handlers = event_handlers or []
         default_handlers = []
-        train_metrics, val_metrics = self.prepare_loss_and_metrics()
+        self.prepare_loss_and_metrics()
 
         # no need to add to default handler check as StoppingHandler does not use metrics
         event_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))
         default_handlers.append("StoppingHandler")
 
         if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
-            event_handlers.append(MetricHandler(train_metrics=train_metrics))
+            event_handlers.append(MetricHandler(train_metrics=self.train_metrics))
             default_handlers.append("MetricHandler")
 
-        if val_data and not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
-            event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
-                                                    val_metrics=val_metrics))
-            default_handlers.append("ValidationHandler")
+        if not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
+            # no validation handler
+            if val_data:
+                # add default validation handler if validation data found
+                event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
+                                                        val_metrics=self.val_metrics))
+                default_handlers.append("ValidationHandler")
+                val_metrics = self.val_metrics
+            else:
+                # set validation metrics to None if no validation data and no validation handler
+                val_metrics = []
 
         if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
-            event_handlers.append(LoggingHandler(train_metrics=train_metrics,
+            event_handlers.append(LoggingHandler(train_metrics=self.train_metrics,
                                                  val_metrics=val_metrics))
             default_handlers.append("LoggingHandler")
 
         # if there is a mix of user defined event handlers and default event handlers
         # they should have the same set of loss and metrics
-        if default_handlers:
+        if default_handlers and len(event_handlers) > len(default_handlers):
             msg = "You are training with the following default event handlers: %s. " \
                   "They use loss and metrics from estimator.prepare_loss_and_metrics(). " \
                   "Please use the same set of metrics for all your other handlers." % \
```
```diff
@@ -374,7 +382,7 @@ def _prepare_default_handlers(self, val_data, event_handlers):
         # remove None metric references
         references = set([ref for ref in references if ref])
         for metric in references:
-            if metric not in train_metrics + val_metrics:
+            if metric not in self.train_metrics + self.val_metrics:
                 msg = "We have added following default handlers for you: %s and used " \
                       "estimator.prepare_loss_and_metrics() to pass metrics to " \
                       "those handlers. Please use the same set of metrics " \
```
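This hunk points the consistency check at the estimator's own `self.train_metrics` and `self.val_metrics` rather than the now-removed locals. The surrounding code gathers every metric a handler references, drops `None` entries, and verifies each one is owned by the estimator; a simplified standalone sketch of that validation (class and function names are illustrative):

```python
class DummyHandler:
    """Stand-in for an event handler that references metrics."""
    def __init__(self, train_metrics=None, val_metrics=None):
        self.train_metrics = train_metrics or []
        self.val_metrics = val_metrics or []

def check_metric_references(handlers, train_metrics, val_metrics):
    # gather every metric object any handler points at
    references = []
    for handler in handlers:
        references += getattr(handler, 'train_metrics', [])
        references += getattr(handler, 'val_metrics', [])
    # drop None entries, then verify each reference is one the estimator owns
    for metric in set(ref for ref in references if ref):
        if metric not in train_metrics + val_metrics:
            raise ValueError("handler references an unknown metric: %r" % metric)

acc = object()  # stand-in for an EvalMetric instance
check_metric_references([DummyHandler(train_metrics=[acc])], [acc], [])  # passes
```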
python/mxnet/gluon/contrib/estimator/event_handler.py (7 changes: 4 additions & 3 deletions)

```diff
@@ -29,7 +29,8 @@
 from ....metric import EvalMetric
 from ....metric import Loss as metric_loss
 
-__all__ = ['StoppingHandler', 'MetricHandler', 'ValidationHandler',
+__all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd','BatchBegin', 'BatchEnd',
+           'StoppingHandler', 'MetricHandler', 'ValidationHandler',
            'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler']
 
 class TrainBegin(object):
```
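Adding the base classes to `__all__` matters because `from module import *` exports only the names that `__all__` lists; before this change, `TrainBegin` and the other base classes were invisible to star imports, which the updated test below relies on. A quick illustration with a hypothetical module:

```python
# handlers.py (hypothetical module)
__all__ = ['TrainBegin', 'StoppingHandler']

class TrainBegin:        # exported: listed in __all__
    pass

class StoppingHandler:   # exported: listed in __all__
    pass

class _Internal:         # not exported by "from handlers import *"
    pass
```

After `from handlers import *`, only `TrainBegin` and `StoppingHandler` are bound in the importing namespace; `_Internal` must be imported explicitly.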
```diff
@@ -513,8 +514,8 @@ def _save_symbol(self, estimator):
             sym = estimator.net._cached_graph[1]
             sym.save(symbol_file)
         else:
-            self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock"
-                             "to construct your model, can call net.hybridize() before passing to"
+            self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock "
+                             "to construct your model, can call net.hybridize() before passing to "
                              "Estimator in order to save model architecture as %s.", symbol_file)
 
     def _save_params_and_trainer(self, estimator, file_prefix):
```
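This is the "fix print" from the commit message: the log message is built from adjacent string literals, which Python concatenates at compile time with no separator, so the original ran words together ("HybridBlockto construct"). A minimal demonstration:

```python
# Adjacent string literals are joined with no separator inserted.
broken = ("please use HybridBlock"
          "to construct your model")
fixed = ("please use HybridBlock "
         "to construct your model")

print(broken)  # please use HybridBlockto construct your model
print(fixed)   # please use HybridBlock to construct your model
```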
tests/python/unittest/test_gluon_estimator.py (2 changes: 2 additions & 0 deletions)

```diff
@@ -19,11 +19,13 @@
 
 import sys
 import unittest
+import warnings
 
 import mxnet as mx
 from mxnet import gluon
 from mxnet.gluon import nn
 from mxnet.gluon.contrib.estimator import *
+from mxnet.gluon.contrib.estimator.event_handler import *
 from nose.tools import assert_raises
 
 
```
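The test module now imports `warnings` and star-imports the handler module so the newly exported base classes resolve. A common reason a test needs `warnings` is to assert that a code path emits one, as in this hedged sketch (not necessarily what this test file does):

```python
import warnings

def test_emits_warning():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning raised
        warnings.warn("mixing handlers", UserWarning)
    assert len(caught) == 1
    assert "mixing handlers" in str(caught[0].message)

test_emits_warning()
```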
