[MXNET-1396][Fit-API] Update default handler logic (apache#14765)
* move to nightly for binaries

* update default handler

* fix pylint

* trigger ci

* trigger ci
roywei authored and haohuw committed Jun 23, 2019
1 parent 74be350 commit 9cde412
Showing 7 changed files with 168 additions and 95 deletions.
ci/docker/runtime_functions.sh (8 changes: 1 addition & 7 deletions)
@@ -1350,18 +1350,12 @@ nightly_scala_demo_test_cpu() {
     bash bin/run_im.sh
 }
 
-nightly_estimator_gpu() {
+nightly_estimator() {
     set -ex
     cd /work/mxnet/tests/nightly/estimator
     export PYTHONPATH=/work/mxnet/python/
     python test_estimator_cnn.py --type gpu
     python test_sentiment_rnn.py --type gpu
-}
-
-nightly_estimator_cpu() {
-    set -ex
-    cd /work/mxnet/tests/nightly/estimator
-    export PYTHONPATH=/work/mxnet/python/
     python test_estimator_cnn.py --type cpu
     python test_sentiment_rnn.py --type cpu
 }
python/mxnet/gluon/contrib/estimator/estimator.py (69 changes: 53 additions & 16 deletions)
@@ -157,6 +157,8 @@ def prepare_loss_and_metrics(self):
         Based on loss functions and training metrics in estimator
         Create metric wrappers to record loss values,
         Create copies of train loss/metric objects to record validation values
+
+        Returns train_metrics and val_metrics
         """
         if any(not hasattr(self, attribute) for attribute in
                ['train_metrics', 'val_metrics']):
@@ -165,8 +167,7 @@ def prepare_loss_and_metrics(self):
             self.train_metrics = [Accuracy()]
             self.val_metrics = []
             for loss in self.loss:
-                self.train_metrics.append(Loss("Train " + ''.join([i for i in loss.name if not i.isdigit()])))
-                self.val_metrics.append(Loss("Validation " + ''.join([i for i in loss.name if not i.isdigit()])))
+                self.train_metrics.append(Loss(''.join([i for i in loss.name if not i.isdigit()])))
             for metric in self.train_metrics:
                 val_metric = copy.deepcopy(metric)
                 metric.name = "Train " + metric.name
@@ -231,21 +232,9 @@ def fit(self, train_data,
             from a data batch and load into contexts(devices)
         """
         self.max_epochs = epochs
-        event_handlers = event_handlers or []
-        # provide default logging handler
-        if not event_handlers:
-            train_metrics, val_metrics = self.prepare_loss_and_metrics()
-            event_handlers.append(MetricHandler(train_metrics=train_metrics))
-            if val_data:
-                event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
-                                                        val_metrics=val_metrics))
-            event_handlers.append(LoggingHandler(train_metrics=train_metrics,
-                                                 val_metrics=val_metrics))
-            warnings.warn("No Event Handler specified, default %s are used. "
-                          "Please look at gluon.contrib.estimator.event_handler for more detail." %
-                          ", ".join([handler.__class__.__name__ for handler in event_handlers]))
 
-        event_handlers.sort(key=lambda handler: getattr(handler, 'rank', 0), reverse=True)
+        # provide default handlers
+        event_handlers = self._prepare_default_handlers(val_data, event_handlers)
 
         train_begin, epoch_begin, batch_begin, \
             batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers)
@@ -297,6 +286,54 @@ def fit(self, train_data,
             for handler in train_end:
                 handler.train_end(estimator_ref)
 
+    def _prepare_default_handlers(self, val_data, event_handlers):
+        event_handlers = event_handlers or []
+        default_handlers = []
+        train_metrics, val_metrics = self.prepare_loss_and_metrics()
+
+        if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
+            event_handlers.append(MetricHandler(train_metrics=train_metrics))
+            default_handlers.append("MetricHandler")
+
+        if val_data and not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
+            event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
+                                                    val_metrics=val_metrics))
+            default_handlers.append("ValidationHandler")
+
+        if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
+            event_handlers.append(LoggingHandler(train_metrics=train_metrics,
+                                                 val_metrics=val_metrics))
+            default_handlers.append("LoggingHandler")
+
+        # if there is a mix of user-defined event handlers and default event handlers,
+        # they should share the same set of loss and metrics
+        if default_handlers:
+            msg = "You are training with the following default event handlers: %s. " \
+                  "They use loss and metrics from estimator.prepare_loss_and_metrics(). " \
+                  "Please use the same set of metrics for all your other handlers." % \
+                  ", ".join(default_handlers)
+            warnings.warn(msg)
+            references = []
+            for handler in event_handlers:
+                for attribute in dir(handler):
+                    if any(keyword in attribute for keyword in ['metric', 'monitor']):
+                        reference = getattr(handler, attribute)
+                        if isinstance(reference, list):
+                            references += reference
+                        else:
+                            references.append(reference)
+            for metric in references:
+                if metric and metric not in train_metrics + val_metrics:
+                    msg = "We have added the following default handlers for you: %s and used " \
+                          "estimator.prepare_loss_and_metrics() to pass metrics to " \
+                          "those handlers. Please use the same set of metrics " \
+                          "for all your handlers." % \
+                          ", ".join(default_handlers)
+                    raise ValueError(msg)
+
+        event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
+        return event_handlers
+
     def _categorize_handlers(self, event_handlers):
         """
         categorize handlers into 6 event lists to avoid calling empty methods
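For context, a minimal sketch of how the reworked fit() path is intended to be driven. MetricHandler, ValidationHandler, LoggingHandler and prepare_loss_and_metrics() are taken from the diff above; the Estimator constructor arguments and the toy data pipeline are assumptions for illustration, not part of the commit.

# Hedged sketch: exercises the default-handler path added by
# _prepare_default_handlers(). Handler classes and prepare_loss_and_metrics()
# come from the diff above; the Estimator constructor arguments and the toy
# data below are assumptions.
import mxnet as mx
from mxnet import gluon, init
from mxnet.gluon.contrib.estimator import Estimator
from mxnet.gluon.contrib.estimator.event_handler import LoggingHandler

X = mx.nd.random.uniform(shape=(100, 20))
y = mx.nd.random.randint(0, 10, shape=(100,)).astype('float32')
train_loader = gluon.data.DataLoader(gluon.data.ArrayDataset(X, y), batch_size=10)
val_loader = gluon.data.DataLoader(gluon.data.ArrayDataset(X, y), batch_size=10)

net = gluon.nn.Dense(10)
net.initialize(init.Xavier())
loss = gluon.loss.SoftmaxCrossEntropyLoss()
est = Estimator(net=net, loss=loss)  # assumed constructor signature

# No event handlers passed: MetricHandler, ValidationHandler (because val_data
# is given) and LoggingHandler are appended automatically, and a warning lists
# which defaults were used.
est.fit(train_data=train_loader, val_data=val_loader, epochs=1)

# Mixing a custom handler with the defaults: the custom handler must reuse the
# metric objects returned by prepare_loss_and_metrics(), otherwise
# _prepare_default_handlers() raises a ValueError.
train_metrics, val_metrics = est.prepare_loss_and_metrics()
logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics)
est.fit(train_data=train_loader, val_data=val_loader, epochs=1,
        event_handlers=[logging])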
python/mxnet/gluon/contrib/estimator/event_handler.py (2 changes: 1 addition & 1 deletion)
@@ -299,7 +299,7 @@ def __init__(self,
         self.save_best_only = save_best_only
         if self.save_best_only and not isinstance(self.monitor, EvalMetric):
             raise ValueError("To save best model only, please provide one of the metric objects as monitor, "
-                             "You can create these objects using estimator.prepare_loss_and_metric()")
+                             "You can get these objects using estimator.prepare_loss_and_metric()")
         self.epoch_period = epoch_period
         self.batch_period = batch_period
         self.num_batches = 0
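A short sketch of the constraint behind this message: with save_best_only=True, monitor must be one of the metric objects the estimator already tracks. CheckpointHandler, monitor and save_best_only appear in the diff above; the model_dir keyword and the surrounding setup are assumptions and may not match the actual API.

# Hedged sketch: save_best_only requires monitor to be an EvalMetric produced
# by the estimator. Only CheckpointHandler, monitor and save_best_only are
# taken from the diff above; model_dir is an assumed parameter name, and est,
# train_loader, val_loader refer to the sketch after the estimator.py diff.
from mxnet.gluon.contrib.estimator.event_handler import CheckpointHandler

train_metrics, val_metrics = est.prepare_loss_and_metrics()
checkpoint = CheckpointHandler(model_dir='./checkpoints',  # assumed keyword
                               monitor=val_metrics[0],     # an EvalMetric, as required
                               save_best_only=True)
est.fit(train_data=train_loader, val_data=val_loader, epochs=1,
        event_handlers=[checkpoint])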
tests/nightly/Jenkinsfile (16 changes: 0 additions & 16 deletions)
@@ -136,22 +136,6 @@ core_logic: {
           utils.docker_run('ubuntu_nightly_cpu', 'nightly_test_javascript', false)
         }
       }
-    },
-    'Gluon estimator: GPU': {
-      node(NODE_LINUX_GPU) {
-        ws('workspace/estimator-test-gpu') {
-          utils.unpack_and_init('gpu', mx_lib)
-          utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator_gpu', true)
-        }
-      }
-    },
-    'Gluon estimator: CPU': {
-      node(NODE_LINUX_CPU) {
-        ws('workspace/estimator-test-cpu') {
-          utils.unpack_and_init('cpu', mx_lib)
-          utils.docker_run('ubuntu_nightly_cpu', 'nightly_estimator_cpu', false)
-        }
-      }
     }
   }
 }
tests/nightly/JenkinsfileForBinaries (8 changes: 8 additions & 0 deletions)
@@ -141,6 +141,14 @@ core_logic: {
           utils.docker_run('ubuntu_nightly_gpu', 'nightly_tutorial_test_ubuntu_python3_gpu', true, '1500m')
         }
       }
+    },
+    'Gluon estimator: GPU': {
+      node(NODE_LINUX_GPU) {
+        ws('workspace/estimator-test-gpu') {
+          utils.unpack_and_init('gpu', mx_lib)
+          utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator', true)
+        }
+      }
     }
   }
 }