Skip to content

Commit

Permalink
[Estimator] refactor estimator and clarify docs (apache#16694)
Browse files Browse the repository at this point in the history
* refactor estimator and clarify docs

* fix info message and test

* clean up after releasing logging handler
  • Loading branch information
szha authored and yajiedesign committed Nov 6, 2019
1 parent 495df4d commit d51cb0f
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 107 deletions.
134 changes: 57 additions & 77 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,14 @@

from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
from .utils import _check_metrics
from .event_handler import _check_event_handlers
from .utils import _check_metrics, _suggest_metric_for_loss, _check_handler_metric_ref
from ...data import DataLoader
from ...loss import SoftmaxCrossEntropyLoss
from ...loss import Loss as gluon_loss
from ...trainer import Trainer
from ...utils import split_and_load
from .... import autograd
from ....context import Context, cpu, gpu, num_gpus
from ....metric import Accuracy
from ....metric import Loss as metric_loss

__all__ = ['Estimator']
Expand All @@ -48,8 +47,8 @@ class Estimator(object):
----------
net : gluon.Block
The model used for training.
loss : gluon.loss.Loss or list of gluon.loss.Loss
Loss(objective functions) to calculate during training.
loss : gluon.loss.Loss
Loss (objective) function to calculate during training.
metrics : EvalMetric or list of EvalMetric
Metrics for evaluating models.
initializer : Initializer
Expand All @@ -69,19 +68,17 @@ def __init__(self, net,

self.net = net
self.loss = self._check_loss(loss)
self.train_metrics = _check_metrics(metrics)
self._train_metrics = _check_metrics(metrics)
self._add_default_training_metrics()
self._add_validation_metrics()

self.context = self._check_context(context)
self._initialize(initializer)
self.trainer = self._check_trainer(trainer)

def _check_loss(self, loss):
if isinstance(loss, gluon_loss):
loss = [loss]
elif isinstance(loss, list) and all([isinstance(l, gluon_loss) for l in loss]):
loss = loss
else:
raise ValueError("loss must be a Loss or a list of Loss, "
if not isinstance(loss, gluon_loss):
raise ValueError("loss must be a Loss, "
"refer to gluon.loss.Loss:{}".format(loss))
return loss

Expand Down Expand Up @@ -166,31 +163,30 @@ def _get_data_and_label(self, batch, ctx, batch_axis=0):
label = split_and_load(label, ctx_list=ctx, batch_axis=batch_axis)
return data, label

def prepare_loss_and_metrics(self):
"""
Based on loss functions and training metrics in estimator
Create metric wrappers to record loss values,
Create copies of train loss/metric objects to record validation values
def _add_default_training_metrics(self):
if not self._train_metrics:
suggested_metric = _suggest_metric_for_loss(self.loss)
if suggested_metric:
self._train_metrics = [suggested_metric]
loss_name = self.loss.name.rstrip('1234567890')
self._train_metrics.append(metric_loss(loss_name))

Returns
-------
train_metrics, val_metrics
"""
if any(not hasattr(self, attribute) for attribute in
['train_metrics', 'val_metrics']):
# Use default mx.metric.Accuracy() for SoftmaxCrossEntropyLoss()
if not self.train_metrics and any([isinstance(l, SoftmaxCrossEntropyLoss) for l in self.loss]):
self.train_metrics = [Accuracy()]
self.val_metrics = []
for loss in self.loss:
# remove trailing numbers from loss name to avoid confusion
self.train_metrics.append(metric_loss(loss.name.rstrip('1234567890')))
for metric in self.train_metrics:
val_metric = copy.deepcopy(metric)
metric.name = "train " + metric.name
val_metric.name = "validation " + val_metric.name
self.val_metrics.append(val_metric)
return self.train_metrics, self.val_metrics
for metric in self._train_metrics:
metric.name = "training " + metric.name

def _add_validation_metrics(self):
    """Build ``self._val_metrics`` as independent copies of the training metrics.

    Every metric in ``self._train_metrics`` is deep-copied so that validation
    statistics accumulate separately from the training ones.

    Training metric names already start with ``"training "`` when this method
    runs. That prefix is replaced by ``"validation "`` instead of being kept,
    so the copy of ``"training accuracy"`` is named ``"validation accuracy"``
    rather than ``"validation training accuracy"``. Names without the
    training prefix simply get ``"validation "`` prepended.
    """
    train_prefix = "training "
    self._val_metrics = []
    for metric in self._train_metrics:
        val_metric = copy.deepcopy(metric)
        base_name = val_metric.name
        if base_name.startswith(train_prefix):
            # drop the training marker so the two prefixes do not stack
            base_name = base_name[len(train_prefix):]
        val_metric.name = "validation " + base_name
        self._val_metrics.append(val_metric)

# Attribute-style accessor for the training metrics.
# Returns the very list object stored in self._train_metrics (no copy is made).
@property
def train_metrics(self):
return self._train_metrics

# Attribute-style accessor for the validation metrics.
# Returns the very list object stored in self._val_metrics (no copy is made).
@property
def val_metrics(self):
return self._val_metrics

def evaluate_batch(self,
val_batch,
Expand All @@ -209,7 +205,7 @@ def evaluate_batch(self,
"""
data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
pred = [self.net(x) for x in data]
loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]
# update metrics
for metric in val_metrics:
if isinstance(metric, metric_loss):
Expand Down Expand Up @@ -275,7 +271,7 @@ def fit_batch(self, train_batch,

with autograd.record():
pred = [self.net(x) for x in data]
loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]

for l in loss:
l.backward()
Expand Down Expand Up @@ -377,63 +373,47 @@ def fit(self, train_data,
handler.train_end(estimator_ref)

def _prepare_default_handlers(self, val_data, event_handlers):
event_handlers = event_handlers or []
default_handlers = []
self.prepare_loss_and_metrics()
event_handlers = _check_event_handlers(event_handlers)
added_default_handlers = []

# no need to add to default handler check as StoppingHandler does not use metrics
event_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))
default_handlers.append("StoppingHandler")
added_default_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))

if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
event_handlers.append(MetricHandler(train_metrics=self.train_metrics))
default_handlers.append("MetricHandler")
added_default_handlers.append(MetricHandler(train_metrics=self.train_metrics))

if not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
# no validation handler
if val_data:
# add default validation handler if validation data found
event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
val_metrics=self.val_metrics))
default_handlers.append("ValidationHandler")
val_metrics = self.val_metrics
# add default validation handler if validation data found
added_default_handlers.append(ValidationHandler(val_data=val_data,
eval_fn=self.evaluate,
val_metrics=val_metrics))
else:
# set validation metrics to None if no validation data and no validation handler
val_metrics = []

if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
event_handlers.append(LoggingHandler(train_metrics=self.train_metrics,
val_metrics=val_metrics))
default_handlers.append("LoggingHandler")
added_default_handlers.append(LoggingHandler(train_metrics=self.train_metrics,
val_metrics=val_metrics))

# if there is a mix of user defined event handlers and default event handlers
# they should have the same set of loss and metrics
if default_handlers and len(event_handlers) != len(default_handlers):
msg = "You are training with the following default event handlers: %s. " \
"They use loss and metrics from estimator.prepare_loss_and_metrics(). " \
"Please use the same set of metrics for all your other handlers." % \
", ".join(default_handlers)
# they should have the same set of metrics
mixing_handlers = event_handlers and added_default_handlers

event_handlers.extend(added_default_handlers)

if mixing_handlers:
msg = "The following default event handlers are added: {}.".format(
", ".join([type(h).__name__ for h in added_default_handlers]))
warnings.warn(msg)
# check if all handlers has the same set of references to loss and metrics
references = []


# check if all handlers have the same set of references to metrics
known_metrics = set(self.train_metrics + self.val_metrics)
for handler in event_handlers:
for attribute in dir(handler):
if any(keyword in attribute for keyword in ['metric' or 'monitor']):
reference = getattr(handler, attribute)
if isinstance(reference, list):
references += reference
else:
references.append(reference)
# remove None metric references
references = set([ref for ref in references if ref])
for metric in references:
if metric not in self.train_metrics + self.val_metrics:
msg = "We have added following default handlers for you: %s and used " \
"estimator.prepare_loss_and_metrics() to pass metrics to " \
"those handlers. Please use the same set of metrics " \
"for all your handlers." % \
", ".join(default_handlers)
raise ValueError(msg)
_check_handler_metric_ref(handler, known_metrics)

event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
return event_handlers
Expand Down
67 changes: 45 additions & 22 deletions python/mxnet/gluon/contrib/estimator/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import, unused-argument
# pylint: disable=wildcard-import, unused-argument, too-many-ancestors
"""Gluon EventHandlers for Estimators"""

import logging
Expand All @@ -34,33 +34,47 @@
'StoppingHandler', 'MetricHandler', 'ValidationHandler',
'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler']

# Marker base class for estimator event handlers. It defines no members of
# its own; _check_event_handlers uses isinstance() against it to validate
# the handlers it is given.
class EventHandler(object):
pass

class TrainBegin(object):

def _check_event_handlers(handlers):
    """Validate ``handlers`` and return them as a new list.

    Parameters
    ----------
    handlers : EventHandler, iterable of EventHandler, or None
        A single handler, a collection of handlers, or a falsy value
        (``None``, empty list) meaning "no handlers".

    Returns
    -------
    list of EventHandler
        Always a freshly created list, never the object passed in, so the
        caller can extend or sort the result without modifying the list
        owned by the user.

    Raises
    ------
    ValueError
        If ``handlers`` is neither an EventHandler nor an iterable that
        contains only EventHandler objects.
    """
    if isinstance(handlers, EventHandler):
        return [handlers]
    if not handlers:
        return []
    try:
        # copy: the result gets extended and sorted in place downstream
        checked = list(handlers)
    except TypeError:
        # not iterable at all; reported below as the documented ValueError
        checked = None
    if checked is None or not all(isinstance(handler, EventHandler) for handler in checked):
        raise ValueError("handlers must be an EventHandler or a list of EventHandler, "
                         "got: {}".format(handlers))
    return checked


class TrainBegin(EventHandler):
def train_begin(self, estimator, *args, **kwargs):
pass


class TrainEnd(object):
class TrainEnd(EventHandler):
def train_end(self, estimator, *args, **kwargs):
pass


class EpochBegin(object):
class EpochBegin(EventHandler):
def epoch_begin(self, estimator, *args, **kwargs):
pass


class EpochEnd(object):
class EpochEnd(EventHandler):
def epoch_end(self, estimator, *args, **kwargs):
return False


class BatchBegin(object):
class BatchBegin(EventHandler):
def batch_begin(self, estimator, *args, **kwargs):
pass


class BatchEnd(object):
class BatchEnd(EventHandler):
def batch_end(self, estimator, *args, **kwargs):
return False

Expand Down Expand Up @@ -242,14 +256,16 @@ def __init__(self, file_name=None,
super(LoggingHandler, self).__init__()
self.logger = logging.getLogger(__name__)
self.logger.setLevel(logging.INFO)
stream_handler = logging.StreamHandler()
self.logger.addHandler(stream_handler)
self._added_logging_handlers = [logging.StreamHandler()]
# save logger to file only if file name or location is specified
if file_name or file_location:
file_name = file_name or 'estimator_log'
file_location = file_location or './'
file_handler = logging.FileHandler(os.path.join(file_location, file_name), mode=filemode)
self.logger.addHandler(file_handler)
self._added_logging_handlers.append(file_handler)
for handler in self._added_logging_handlers:
self.logger.addHandler(handler)

if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]:
raise ValueError("verbose level must be either LOG_PER_EPOCH or "
"LOG_PER_BATCH, received %s. "
Expand All @@ -265,6 +281,12 @@ def __init__(self, file_name=None,
# it will also shut down logging at train end
self.priority = np.Inf

def __del__(self):
    """Flush, detach and close the logging handlers this object registered.

    Each handler in ``self._added_logging_handlers`` is flushed, removed
    from ``self.logger`` and closed, in that order. Afterwards the list is
    emptied so that a repeated call is a no-op.

    A finalizer can run on an instance that never had these attributes
    set, so both attributes are looked up defensively instead of raising
    ``AttributeError`` from inside ``__del__``.
    """
    added_handlers = getattr(self, '_added_logging_handlers', None)
    if not added_handlers:
        return
    logger = getattr(self, 'logger', None)
    for handler in list(added_handlers):
        handler.flush()
        if logger is not None:
            logger.removeHandler(handler)
        handler.close()
    # forget the released handlers so they are not flushed/closed twice
    del added_handlers[:]

def train_begin(self, estimator, *args, **kwargs):
self.train_start = time.time()
trainer = estimator.trainer
Expand Down Expand Up @@ -393,8 +415,8 @@ def __init__(self,
self.model_prefix = model_prefix
self.save_best = save_best
if self.save_best and not isinstance(self.monitor, EvalMetric):
raise ValueError("To save best model only, please provide one of the metric objects as monitor, "
"You can get these objects using estimator.prepare_loss_and_metric()")
raise ValueError("To save best model only, please provide one of the metric objects "
"from estimator.train_metrics and estimator.val_metrics as monitor.")
self.epoch_period = epoch_period
self.batch_period = batch_period
self.current_batch = 0
Expand Down Expand Up @@ -487,10 +509,10 @@ def _save_checkpoint(self, estimator):
monitor_name, monitor_value = self.monitor.get()
# check if monitor exists in train stats
if np.isnan(monitor_value):
warnings.warn(RuntimeWarning('Skipping save best because %s is not updated, make sure you '
'pass one of the metric objects as monitor, '
'you can use estimator.prepare_loss_and_metrics to'
'create all metric objects', monitor_name))
warnings.warn(RuntimeWarning(
'Skipping save best because %s is not updated, make sure you pass one of the '
'metric objects estimator.train_metrics and estimator.val_metrics as monitor',
monitor_name))
else:
if self.monitor_op(monitor_value, self.best):
prefix = self.model_prefix + '-best'
Expand All @@ -517,7 +539,7 @@ def _save_symbol(self, estimator):
sym.save(symbol_file)
else:
self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock "
"to construct your model, can call net.hybridize() before passing to "
"to construct your model, and call net.hybridize() before passing to "
"Estimator in order to save model architecture as %s.", symbol_file)

def _save_params_and_trainer(self, estimator, file_prefix):
Expand Down Expand Up @@ -636,8 +658,9 @@ def __init__(self,
super(EarlyStoppingHandler, self).__init__()

if not isinstance(monitor, EvalMetric):
raise ValueError("Please provide one of the metric objects as monitor, "
"You can create these objects using estimator.prepare_loss_and_metric()")
raise ValueError(
"Please provide one of the metric objects from estimator.train_metrics and "
"estimator.val_metrics as monitor.")
if isinstance(monitor, CompositeEvalMetric):
raise ValueError("CompositeEvalMetric is not supported for EarlyStoppingHandler, "
"please specify a simple metric instead.")
Expand Down Expand Up @@ -693,9 +716,9 @@ def train_begin(self, estimator, *args, **kwargs):
def epoch_end(self, estimator, *args, **kwargs):
monitor_name, monitor_value = self.monitor.get()
if np.isnan(monitor_value):
warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects'
'as monitor, you can use estimator.prepare_loss_and_metrics to'
'create all metric objects', monitor_name))
warnings.warn(RuntimeWarning(
'%s is not updated, make sure you pass one of the metric objects from'
'estimator.train_metrics and estimator.val_metrics as monitor.', monitor_name))
else:
if self.monitor_op(monitor_value - self.min_delta, self.best):
self.best = monitor_value
Expand Down
Loading

0 comments on commit d51cb0f

Please sign in to comment.