
[MXNET-1396][Fit-API] Update default handler logic #14765

Merged: 5 commits, Apr 24, 2019
Changes from all commits
8 changes: 1 addition & 7 deletions ci/docker/runtime_functions.sh
@@ -1296,18 +1296,12 @@ nightly_scala_demo_test_cpu() {
bash bin/run_im.sh
}

-nightly_estimator_gpu() {
+nightly_estimator() {
set -ex
cd /work/mxnet/tests/nightly/estimator
export PYTHONPATH=/work/mxnet/python/
python test_estimator_cnn.py --type gpu
python test_sentiment_rnn.py --type gpu
}

-nightly_estimator_cpu() {
-    set -ex
-    cd /work/mxnet/tests/nightly/estimator
-    export PYTHONPATH=/work/mxnet/python/
-    python test_estimator_cnn.py --type cpu
-    python test_sentiment_rnn.py --type cpu
-}
69 changes: 53 additions & 16 deletions python/mxnet/gluon/contrib/estimator/estimator.py
@@ -157,6 +157,8 @@ def prepare_loss_and_metrics(self):
Based on loss functions and training metrics in estimator
Create metric wrappers to record loss values,
Create copies of train loss/metric objects to record validation values
+Returns train_metrics and val_metrics
+
"""
if any(not hasattr(self, attribute) for attribute in
['train_metrics', 'val_metrics']):
@@ -165,8 +167,7 @@ def prepare_loss_and_metrics(self):
self.train_metrics = [Accuracy()]
self.val_metrics = []
for loss in self.loss:
self.train_metrics.append(Loss("Train " + ''.join([i for i in loss.name if not i.isdigit()])))
self.val_metrics.append(Loss("Validation " + ''.join([i for i in loss.name if not i.isdigit()])))
self.train_metrics.append(Loss(''.join([i for i in loss.name if not i.isdigit()])))
for metric in self.train_metrics:
val_metric = copy.deepcopy(metric)
metric.name = "Train " + metric.name
@@ -231,21 +232,9 @@ def fit(self, train_data,
from a data batch and load into contexts(devices)
"""
self.max_epochs = epochs
-event_handlers = event_handlers or []
-# provide default logging handler
-if not event_handlers:
-    train_metrics, val_metrics = self.prepare_loss_and_metrics()
-    event_handlers.append(MetricHandler(train_metrics=train_metrics))
-    if val_data:
-        event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
-                                                val_metrics=val_metrics))
-    event_handlers.append(LoggingHandler(train_metrics=train_metrics,
-                                         val_metrics=val_metrics))
-    warnings.warn("No Event Handler specified, default %s are used. "
-                  "Please look at gluon.contrib.estimator.event_handler for more detail." %
-                  ", ".join([handler.__class__.__name__ for handler in event_handlers]))

-event_handlers.sort(key=lambda handler: getattr(handler, 'rank', 0), reverse=True)
+# provide default handlers
+event_handlers = self._prepare_default_handlers(val_data, event_handlers)

train_begin, epoch_begin, batch_begin, \
batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers)
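A hedged usage sketch of the new default-handler path; est is the estimator from the sketch above, and train_loader/val_loader stand for gluon DataLoader objects that are not part of this diff:

# No event_handlers passed: _prepare_default_handlers() appends MetricHandler,
# ValidationHandler (since val_data is given) and LoggingHandler, then warns.
est.fit(train_data=train_loader,
        val_data=val_loader,
        epochs=2)
# UserWarning: You are training with the following default event handlers:
# MetricHandler, ValidationHandler, LoggingHandler. They use loss and metrics
# from estimator.prepare_loss_and_metrics(). Please use the same set of metrics
# for all your other handlers.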
@@ -297,6 +286,54 @@ def fit(self, train_data,
for handler in train_end:
handler.train_end(estimator_ref)

+def _prepare_default_handlers(self, val_data, event_handlers):
+    event_handlers = event_handlers or []
+    default_handlers = []
+    train_metrics, val_metrics = self.prepare_loss_and_metrics()
+
+    if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
+        event_handlers.append(MetricHandler(train_metrics=train_metrics))
+        default_handlers.append("MetricHandler")
+
+    if val_data and not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
+        event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
+                                                val_metrics=val_metrics))
+        default_handlers.append("ValidationHandler")
+
+    if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
+        event_handlers.append(LoggingHandler(train_metrics=train_metrics,
+                                             val_metrics=val_metrics))
+        default_handlers.append("LoggingHandler")
+
+    # if there is a mix of user-defined event handlers and default event handlers,
+    # they should share the same set of loss and metrics
+    if default_handlers:
+        msg = "You are training with the following default event handlers: %s. " \
+              "They use loss and metrics from estimator.prepare_loss_and_metrics(). " \
+              "Please use the same set of metrics for all your other handlers." % \
+              ", ".join(default_handlers)
+        warnings.warn(msg)
+        references = []
+        for handler in event_handlers:
+            for attribute in dir(handler):
+                if any(keyword in attribute for keyword in ['metric', 'monitor']):
+                    reference = getattr(handler, attribute)
+                    if isinstance(reference, list):
+                        references += reference
+                    else:
+                        references.append(reference)
+        for metric in references:
+            if metric and metric not in train_metrics + val_metrics:
+                msg = "We have added the following default handlers for you: %s and used " \
+                      "estimator.prepare_loss_and_metrics() to pass metrics to " \
+                      "those handlers. Please use the same set of metrics " \
+                      "for all your handlers." % \
+                      ", ".join(default_handlers)
+                raise ValueError(msg)
+
+    event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
+    return event_handlers
+
def _categorize_handlers(self, event_handlers):
"""
categorize handlers into 6 event lists to avoid calling empty methods
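The consistency check added in _prepare_default_handlers can be exercised roughly as follows; this is a sketch under the same assumptions as the earlier snippets (est, train_loader, val_loader are illustrative), using only handler arguments that appear in this diff:

from mxnet.metric import Accuracy
from mxnet.gluon.contrib.estimator.event_handler import LoggingHandler

train_metrics, val_metrics = est.prepare_loss_and_metrics()

# OK: the user-supplied handler reuses the estimator's own metric objects, so it
# shares them with the MetricHandler/ValidationHandler defaults that get added.
est.fit(train_data=train_loader, val_data=val_loader, epochs=1,
        event_handlers=[LoggingHandler(train_metrics=train_metrics,
                                       val_metrics=val_metrics)])

# ValueError: this handler holds a metric object that is not in
# train_metrics + val_metrics, so it cannot be mixed with the default handlers.
est.fit(train_data=train_loader, val_data=val_loader, epochs=1,
        event_handlers=[LoggingHandler(train_metrics=[Accuracy()], val_metrics=[])])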
2 changes: 1 addition & 1 deletion python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -299,7 +299,7 @@ def __init__(self,
self.save_best_only = save_best_only
if self.save_best_only and not isinstance(self.monitor, EvalMetric):
raise ValueError("To save best model only, please provide one of the metric objects as monitor, "
"You can create these objects using estimator.prepare_loss_and_metric()")
"You can get these objects using estimator.prepare_loss_and_metric()")
self.epoch_period = epoch_period
self.batch_period = batch_period
self.num_batches = 0
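The constructor patched above appears to belong to the checkpointing handler; below is a hedged sketch of the intended usage, where CheckpointHandler and its model_dir parameter are assumptions based on this constructor rather than something confirmed by the diff:

from mxnet.gluon.contrib.estimator.event_handler import CheckpointHandler

train_metrics, val_metrics = est.prepare_loss_and_metrics()

# monitor must be one of the metric objects returned above (an EvalMetric);
# anything else combined with save_best_only=True raises the ValueError shown here.
checkpointer = CheckpointHandler(model_dir='./checkpoints',   # model_dir is assumed
                                 monitor=val_metrics[0],
                                 save_best_only=True)
est.fit(train_data=train_loader, val_data=val_loader, epochs=2,
        event_handlers=[checkpointer])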
16 changes: 0 additions & 16 deletions tests/nightly/Jenkinsfile
@@ -136,22 +136,6 @@ core_logic: {
utils.docker_run('ubuntu_nightly_cpu', 'nightly_test_javascript', false)
}
}
-},
-'Gluon estimator: GPU': {
-  node(NODE_LINUX_GPU) {
-    ws('workspace/estimator-test-gpu') {
-      utils.unpack_and_init('gpu', mx_lib)
-      utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator_gpu', true)
-    }
-  }
-},
-'Gluon estimator: CPU': {
-  node(NODE_LINUX_CPU) {
-    ws('workspace/estimator-test-cpu') {
-      utils.unpack_and_init('cpu', mx_lib)
-      utils.docker_run('ubuntu_nightly_cpu', 'nightly_estimator_cpu', false)
-    }
-  }
}
}
}
8 changes: 8 additions & 0 deletions tests/nightly/JenkinsfileForBinaries
@@ -106,6 +106,14 @@ core_logic: {
utils.docker_run('ubuntu_nightly_gpu', 'nightly_tutorial_test_ubuntu_python3_gpu', true, '1500m')
}
}
+},
+'Gluon estimator: GPU': {
+  node(NODE_LINUX_GPU) {
+    ws('workspace/estimator-test-gpu') {
+      utils.unpack_and_init('gpu', mx_lib)
+      utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator', true)
+    }
+  }
}
}
}