Don't use global logger. Make it specific to each estimator.
gluon.contrib.estimator used a global Logger obtained via
`logging.getLogger('gluon.contrib.estimator.event_handlers')`. This logger used
to be configured every time a gluon.contrib.estimator.LoggingHandler was
created, which is a bug. We can't modify a global Logger instance whenever the
user creates an Estimator and a LoggingHandler.

Instead, this commit separates the LoggingHandler (responsible for logging
metadata during estimator.fit) from the configuration of the Logger.

We expose the Logger as an attribute of the Estimator and configure it to write
to stdout by default. Instructions are given on how users can configure
Estimator.logger to log to a file instead.
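
For example, a minimal sketch of the intended usage (`net`, `ce_loss`, and
`'train.log'` are placeholder names, not part of this commit, and the import
assumes the package re-exports Estimator):

```python
import logging

from mxnet.gluon.contrib.estimator import Estimator

# The Estimator now owns a per-instance Logger that writes INFO-level
# messages to stdout by default.
est = Estimator(net, loss=ce_loss)

# To additionally write the logs to a file, attach a standard FileHandler.
est.logger.addHandler(logging.FileHandler('train.log'))
```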
leezu committed Nov 14, 2019
1 parent ae9647b commit 2bcdedc
Showing 3 changed files with 79 additions and 85 deletions.
25 changes: 24 additions & 1 deletion python/mxnet/gluon/contrib/estimator/estimator.py
@@ -20,6 +20,8 @@
"""Gluon Estimator"""

import copy
import logging
import sys
import warnings

from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
@@ -57,6 +59,25 @@ class Estimator(object):
Trainer to apply optimizer on network parameters.
context : Context or list of Context
Device(s) to run the training on.
"""

logger = None
"""logging.Logger object associated with the Estimator.
The logger is used for all logs generated by this estimator and its
handlers. A new logging.Logger is created during Estimator construction and
configured to write all logs with level logging.INFO or higher to
sys.stdout.
You can modify the logging settings using the standard Python methods. For
example, to save logs to a file in addition to printing them to stdout
output, you can attach a logging.FileHandler to the logger.
>>> est = Estimator(net, loss)
>>> import logging
>>> est.logger.addHandler(logging.FileHandler(filename))
"""

def __init__(self, net,
@@ -65,13 +86,15 @@ def __init__(self, net,
initializer=None,
trainer=None,
context=None):

self.net = net
self.loss = self._check_loss(loss)
self._train_metrics = _check_metrics(metrics)
self._add_default_training_metrics()
self._add_validation_metrics()

self.logger = logging.Logger(name='Estimator', level=logging.INFO)
self.logger.addHandler(logging.StreamHandler(sys.stdout))

self.context = self._check_context(context)
self._initialize(initializer)
self.trainer = self._check_trainer(trainer)
131 changes: 50 additions & 81 deletions python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -21,19 +21,21 @@

import logging
import os
import sys
import time
import warnings

import numpy as np

from ....metric import EvalMetric, CompositeEvalMetric
from ....metric import CompositeEvalMetric, EvalMetric
from ....metric import Loss as metric_loss
from .utils import _check_metrics

__all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd',
'StoppingHandler', 'MetricHandler', 'ValidationHandler',
'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler']


class EventHandler(object):
pass

@@ -194,7 +196,6 @@ def __init__(self,
# order to be called among all callbacks
# validation metrics need to be calculated before other callbacks can access them
self.priority = -np.Inf
self.logger = logging.getLogger(__name__)

def train_begin(self, estimator, *args, **kwargs):
# reset epoch and batch counter
@@ -211,7 +212,7 @@ def batch_end(self, estimator, *args, **kwargs):
for monitor in self.val_metrics:
name, value = monitor.get()
msg += '%s: %.4f, ' % (name, value)
self.logger.info(msg.rstrip(','))
estimator.logger.info(msg.rstrip(','))

def epoch_end(self, estimator, *args, **kwargs):
self.current_epoch += 1
@@ -228,12 +229,6 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
Parameters
----------
file_name : str
File name to save the logs.
file_location : str
File location to save the logs.
filemode : str, default 'a'
Logging file mode, default using append mode.
verbose : int, default LOG_PER_EPOCH
Limit the granularity of metrics displayed during training process.
verbose=LOG_PER_EPOCH: display metrics every epoch
@@ -247,25 +242,10 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
LOG_PER_EPOCH = 1
LOG_PER_BATCH = 2

def __init__(self, file_name=None,
file_location=None,
filemode='a',
verbose=LOG_PER_EPOCH,
def __init__(self, verbose=LOG_PER_EPOCH,
train_metrics=None,
val_metrics=None):
super(LoggingHandler, self).__init__()
self.logger = logging.getLogger(__name__)
self.logger.setLevel(logging.INFO)
self._added_logging_handlers = [logging.StreamHandler()]
# save logger to file only if file name or location is specified
if file_name or file_location:
file_name = file_name or 'estimator_log'
file_location = file_location or './'
file_handler = logging.FileHandler(os.path.join(file_location, file_name), mode=filemode)
self._added_logging_handlers.append(file_handler)
for handler in self._added_logging_handlers:
self.logger.addHandler(handler)

if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]:
raise ValueError("verbose level must be either LOG_PER_EPOCH or "
"LOG_PER_BATCH, received %s. "
@@ -281,24 +261,18 @@ def __init__(self, file_name=None,
# it will also shut down logging at train end
self.priority = np.Inf

def __del__(self):
for handler in self._added_logging_handlers:
handler.flush()
self.logger.removeHandler(handler)
handler.close()

def train_begin(self, estimator, *args, **kwargs):
self.train_start = time.time()
trainer = estimator.trainer
optimizer = trainer.optimizer.__class__.__name__
lr = trainer.learning_rate
self.logger.info("Training begin: using optimizer %s "
"with current learning rate %.4f ",
optimizer, lr)
estimator.logger.info("Training begin: using optimizer %s "
"with current learning rate %.4f ",
optimizer, lr)
if estimator.max_epoch:
self.logger.info("Train for %d epochs.", estimator.max_epoch)
estimator.logger.info("Train for %d epochs.", estimator.max_epoch)
else:
self.logger.info("Train for %d batches.", estimator.max_batch)
estimator.logger.info("Train for %d batches.", estimator.max_batch)
# reset all counters
self.current_epoch = 0
self.batch_index = 0
@@ -311,13 +285,7 @@ def train_end(self, estimator, *args, **kwargs):
for metric in self.train_metrics + self.val_metrics:
name, value = metric.get()
msg += '%s: %.4f, ' % (name, value)
self.logger.info(msg.rstrip(', '))
# make a copy of handler list and remove one by one
# as removing handler will edit the handler list
for handler in self.logger.handlers[:]:
handler.close()
self.logger.removeHandler(handler)
logging.shutdown()
estimator.logger.info(msg.rstrip(', '))

def batch_begin(self, estimator, *args, **kwargs):
if self.verbose == self.LOG_PER_BATCH:
@@ -334,14 +302,14 @@ def batch_end(self, estimator, *args, **kwargs):
# only log current training loss & metric after each batch
name, value = metric.get()
msg += '%s: %.4f, ' % (name, value)
self.logger.info(msg.rstrip(', '))
estimator.logger.info(msg.rstrip(', '))
self.batch_index += 1

def epoch_begin(self, estimator, *args, **kwargs):
if self.verbose >= self.LOG_PER_EPOCH:
self.epoch_start = time.time()
self.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
self.current_epoch, estimator.trainer.learning_rate)
estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
self.current_epoch, estimator.trainer.learning_rate)

def epoch_end(self, estimator, *args, **kwargs):
if self.verbose >= self.LOG_PER_EPOCH:
Expand All @@ -350,7 +318,7 @@ def epoch_end(self, estimator, *args, **kwargs):
for monitor in self.train_metrics + self.val_metrics:
name, value = monitor.get()
msg += '%s: %.4f, ' % (name, value)
self.logger.info(msg.rstrip(', '))
estimator.logger.info(msg.rstrip(', '))
self.current_epoch += 1
self.batch_index = 0

@@ -424,7 +392,6 @@ def __init__(self,
self.max_checkpoints = max_checkpoints
self.resume_from_checkpoint = resume_from_checkpoint
self.saved_checkpoints = []
self.logger = logging.getLogger(__name__)
if self.save_best:
if mode not in ['auto', 'min', 'max']:
warnings.warn('ModelCheckpoint mode %s is unknown, '
@@ -443,14 +410,14 @@ def __init__(self,
else:
# use greater for accuracy and f1 and less otherwise
if 'acc' or 'f1' in self.monitor.get()[0].lower():
self.logger.info("`greater` operator will be used to determine "
"if %s has improved, please use `min` for mode "
"if you want otherwise", self.monitor.get()[0])
warnings.warn("`greater` operator will be used to determine if %s has improved. "
"Please specify `mode='min'` to use the `less` operator. "
"Specify `mode='max' to disable this warning.`", self.monitor.get()[0])
self.monitor_op = np.greater
else:
self.logger.info("`less` operator will be used to determine "
"if %s has improved, please use `max` for mode "
"if you want otherwise", self.monitor.get()[0])
warnings.warn("`less` operator will be used to determine if %s has improved. "
"Please specify `mode='max'` to use the `greater` operator. "
"Specify `mode='min' to disable this warning.`", self.monitor.get()[0])
self.monitor_op = np.less

def train_begin(self, estimator, *args, **kwargs):
@@ -501,9 +468,9 @@ def _save_checkpoint(self, estimator):
prefix = "%s-epoch%dbatch%d" % (self.model_prefix, save_epoch_number, save_batch_number)
self._save_params_and_trainer(estimator, prefix)
if self.verbose > 0:
self.logger.info('[Epoch %d] CheckpointHandler: trained total %d batches, '
'saving model at %s with prefix: %s',
self.current_epoch, self.current_batch + 1, self.model_dir, prefix)
estimator.logger.info('[Epoch %d] CheckpointHandler: trained total %d batches, '
'saving model at %s with prefix: %s',
self.current_epoch, self.current_batch + 1, self.model_dir, prefix)

if self.save_best:
monitor_name, monitor_value = self.monitor.get()
@@ -519,28 +486,30 @@ def _save_checkpoint(self, estimator):
self._save_params_and_trainer(estimator, prefix)
self.best = monitor_value
if self.verbose > 0:
self.logger.info('[Epoch %d] CheckpointHandler: '
'%s improved from %0.5f to %0.5f, '
'updating best model at %s with prefix: %s',
self.current_epoch, monitor_name,
self.best, monitor_value, self.model_dir, prefix)
estimator.logger.info('[Epoch %d] CheckpointHandler: '
'%s improved from %0.5f to %0.5f, '
'updating best model at %s with prefix: %s',
self.current_epoch, monitor_name,
self.best, monitor_value, self.model_dir, prefix)
else:
if self.verbose > 0:
self.logger.info('[Epoch %d] CheckpointHandler: '
'%s did not improve from %0.5f, '
'skipping updating best model',
self.current_batch, monitor_name,
self.best)
estimator.logger.info('[Epoch %d] CheckpointHandler: '
'%s did not improve from %0.5f, '
'skipping updating best model',
self.current_batch, monitor_name,
self.best)

def _save_symbol(self, estimator):
symbol_file = os.path.join(self.model_dir, self.model_prefix + '-symbol.json')
if hasattr(estimator.net, '_cached_graph') and estimator.net._cached_graph:
sym = estimator.net._cached_graph[1]
sym.save(symbol_file)
else:
self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock "
"to construct your model, and call net.hybridize() before passing to "
"Estimator in order to save model architecture as %s.", symbol_file)
estimator.logger.info(
"Model architecture(symbol file) is not saved, please use HybridBlock "
"to construct your model, and call net.hybridize() before passing to "
"Estimator in order to save model architecture as %s.",
symbol_file)

def _save_params_and_trainer(self, estimator, file_prefix):
param_file = os.path.join(self.model_dir, file_prefix + '.params')
@@ -579,7 +548,7 @@ def _resume_from_checkpoint(self, estimator):
msg += "%d batches" % estimator.max_batch
else:
msg += "%d epochs" % estimator.max_epoch
self.logger.info(msg)
estimator.logger.info(msg)
else:
msg = "CheckpointHandler: Checkpoint resumed from epoch %d batch %d, " \
"continue to train for " % (self.trained_epoch, self.trained_batch)
@@ -607,7 +576,7 @@ def _resume_from_checkpoint(self, estimator):
assert os.path.exists(trainer_file), "Failed to load checkpoint, %s does not exist" % trainer_file
estimator.net.load_parameters(param_file, ctx=estimator.context)
estimator.trainer.load_states(trainer_file)
self.logger.warning(msg)
estimator.logger.warning(msg)

def _find_max_iteration(self, dir, prefix, start, end, saved_checkpoints=None):
error_msg = "Error parsing checkpoint file, please check your " \
Expand Down Expand Up @@ -672,7 +641,6 @@ def __init__(self,
self.stopped_epoch = 0
self.current_epoch = 0
self.stop_training = False
self.logger = logging.getLogger(__name__)

if mode not in ['auto', 'min', 'max']:
warnings.warn('EarlyStopping mode %s is unknown, '
@@ -688,14 +656,14 @@ def __init__(self,
self.monitor_op = np.greater
else:
if 'acc' or 'f1' in self.monitor.get()[0].lower():
self.logger.info("`greater` operator is used to determine "
"if %s has improved, please use `min` for mode "
"if you want otherwise", self.monitor.get()[0])
warnings.warn("`greater` operator will be used to determine if %s has improved. "
"Please specify `mode='min'` to use the `less` operator. "
"Specify `mode='max' to disable this warning.`", self.monitor.get()[0])
self.monitor_op = np.greater
else:
self.logger.info("`less` operator is used to determine "
"if %s has improved, please use `max` for mode "
"if you want otherwise", self.monitor.get()[0])
warnings.warn("`less` operator will be used to determine if %s has improved. "
"Please specify `mode='max'` to use the `greater` operator. "
"Specify `mode='min' to disable this warning.`", self.monitor.get()[0])
self.monitor_op = np.less

if self.monitor_op == np.greater: # pylint: disable=comparison-with-callable
@@ -733,5 +701,6 @@ def epoch_end(self, estimator, *args, **kwargs):

def train_end(self, estimator, *args, **kwargs):
if self.stopped_epoch > 0:
self.logger.info('[Epoch %d] EarlyStoppingHanlder: early stopping due to %s not improving',
self.stopped_epoch, self.monitor.get()[0])
estimator.logger.info('[Epoch %d] EarlyStoppingHanlder: '
'early stopping due to %s not improving',
self.stopped_epoch, self.monitor.get()[0])
8 changes: 5 additions & 3 deletions tests/python/unittest/test_gluon_event_handler.py
@@ -16,6 +16,7 @@
# under the License.

import os
import logging

import mxnet as mx
from common import TemporaryDirectory
@@ -143,11 +144,12 @@ def test_logging():
ce_loss = loss.SoftmaxCrossEntropyLoss()
acc = mx.metric.Accuracy()
est = estimator.Estimator(net, loss=ce_loss, metrics=acc)

est.logger.addHandler(logging.FileHandler(output_dir))

train_metrics = est.train_metrics
val_metrics = est.val_metrics
logging_handler = event_handler.LoggingHandler(file_name=file_name,
file_location=tmpdir,
train_metrics=train_metrics,
logging_handler = event_handler.LoggingHandler(train_metrics=train_metrics,
val_metrics=val_metrics)
est.fit(test_data, event_handlers=[logging_handler], epochs=3)
assert logging_handler.batch_index == 0
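
Taken together, the updated test suggests the following minimal usage sketch
after this change (`net`, `ce_loss`, `acc`, and `train_data` are placeholders;
module paths assume the layout of the changed files above):

```python
import logging

from mxnet.gluon.contrib.estimator import estimator, event_handler

# The Estimator creates and owns its logger (INFO level, stdout by default).
est = estimator.Estimator(net, loss=ce_loss, metrics=acc)

# Optionally mirror the logs to a file through the standard logging API.
est.logger.addHandler(logging.FileHandler('estimator.log'))

# LoggingHandler no longer takes file_name/file_location/filemode; it only
# controls verbosity and writes through estimator.logger.
handler = event_handler.LoggingHandler(train_metrics=est.train_metrics,
                                       val_metrics=est.val_metrics)

est.fit(train_data, event_handlers=[handler], epochs=3)
```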
