Skip to content

Commit

Permalink
[MXNet-1340][Fit API]Update train stats (apache#14494)
Browse files Browse the repository at this point in the history
* add train history

* update history

* update test

* avoid calling empty methods

* remove train history object

* fix pylint

* add unit test

* fix test

* update categorize handlers
  • Loading branch information
roywei committed May 15, 2019
1 parent 02e7c9b commit 92c3c21
Show file tree
Hide file tree
Showing 5 changed files with 280 additions and 181 deletions.
147 changes: 85 additions & 62 deletions python/mxnet/gluon/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import copy
import warnings

from .event_handler import LoggingHandler
from .event_handler import EventHandler, LoggingHandler
from ... import gluon, autograd
from ...context import Context, cpu, gpu, num_gpus
from ...io import DataIter
Expand All @@ -39,27 +39,26 @@ class Estimator(object):
Parameters
----------
loss : Loss or list of Loss
loss : gluon.loss.Loss or list of gluon.loss.Loss
Loss(objective functions) to calculate during training
metrics : EvalMetric or list of EvalMetric
Metrics for evaluating models
initializer : Initializer
initializer to initialize the network
trainers : Trainer or list of Trainer
Trainers to apply optimizers on network parameters
trainer : Trainer
Trainer to apply optimizer on network parameters
context : Context or list of Context
devices to run the training on
"""

def __init__(self, net,
loss=None,
loss,
metrics=None,
initializer=None,
trainers=None,
trainer=None,
context=None):

self.net = net
self.stop_training = False

if isinstance(loss, gluon.loss.Loss):
self.loss = [loss]
Expand All @@ -86,27 +85,14 @@ def __init__(self, net,

# store training statistics
self.train_stats = {}
self.train_stats['epochs'] = []
self.train_stats['learning_rate'] = []
# current step of the epoch
self.train_stats['step'] = ''
for metric in self.train_metrics:
# record a history of metrics over each epoch
self.train_stats['train_' + metric.name] = []
# only record the latest metric numbers after each batch
self.train_stats['batch_' + metric.name] = 0.
for metric in self.val_metrics:
self.train_stats['val_' + metric.name] = []

# separate train and validation
self.train_loss_metrics = []
self.val_loss_metrics = []
# using the metric wrapper for loss to record loss value
for l in self.loss:
self.train_loss_metrics.append(Loss(l.name))
self.val_loss_metrics.append(Loss(l.name))
self.train_stats['train_' + l.name] = []
self.train_stats['val_' + l.name] = []
# only record the latest loss numbers after each batch
self.train_stats['batch_' + l.name] = 0.

# handle context
if isinstance(context, Context):
Expand All @@ -127,15 +113,14 @@ def __init__(self, net,
raise ValueError("context must be a Context or a list of Context, "
"refer to mxnet.Context:{}".format(context))


# initialize the network
self.initializer = initializer
if self.initializer:
if self._is_initialized():
# if already initialized, re-init with user specified initializer
warnings.warn("Network already initialized, re-initializing with %s. "
"You don't need to pass initializer if you already "
"initialized your net."% type(self.initializer).__name__)
"initialized your net." % type(self.initializer).__name__)
self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=True)
else:
# initialize with user specified initializer
Expand All @@ -144,16 +129,17 @@ def __init__(self, net,
if not self._is_initialized():
self.net.initialize(ctx=self.context)

# handle trainers
if isinstance(trainers, gluon.Trainer):
self.trainers = [trainers]
elif not trainers:
# handle trainer
if not trainer:
warnings.warn("No trainer specified, default SGD optimizer "
"with learning rate 0.001 is used.")
self.trainers = [gluon.Trainer(self.net.collect_params(),
'sgd', {'learning_rate': 0.001})]
self.trainer = gluon.Trainer(self.net.collect_params(),
'sgd', {'learning_rate': 0.001})
elif not isinstance(trainer, gluon.Trainer):
raise ValueError("Trainer must be a Gluon Trainer instance, refer to "
"gluon.Trainer:{}".format(trainer))
else:
raise ValueError("Invalid trainer specified, please provide a valid gluon.Trainer")
self.trainer = trainer

def _is_initialized(self):
param_dict = self.net.collect_params()
Expand Down Expand Up @@ -212,8 +198,12 @@ def evaluate(self,
# update metrics
for metric in self.val_metrics:
metric.update(label, pred)
name, value = metric.get()
self.train_stats['val_' + name] = value
for loss, loss_metric, in zip(losses, self.val_loss_metrics):
loss_metric.update(0, [l for l in loss])
name, value = loss_metric.get()
self.train_stats['val_' + name] = value

def fit(self, train_data,
val_data=None,
Expand Down Expand Up @@ -241,27 +231,38 @@ def fit(self, train_data,
from a data batch and load into contexts(devices)
"""


self.epochs = epochs
self.max_epoch = epochs
if not batch_size:
batch_size = 32 * len(self.context)
self.batch_size = 32 * len(self.context)
else:
self.batch_size = batch_size
self.stop_training = False
self.samples = None
self.batch_idx = 0

event_handlers = event_handlers or []
# provide default logging handler
if not event_handlers or \
not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
event_handlers.append(LoggingHandler(self))
event_handlers.append(LoggingHandler())

# training begin
train_begin, epoch_begin, batch_begin, \
batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers)

# passing estimator to event handlers so they can access estimator information
# when a event is triggered
for handler in event_handlers:
handler.estimator = self

# training begin
for handler in train_begin:
handler.train_begin()

for epoch in range(epochs):
for epoch in range(self.max_epoch):
# epoch begin
self.train_stats['epochs'].append(epoch)
self.train_stats['learning_rate'].append(self.trainers[0].learning_rate)
self.current_epoch = epoch

for handler in event_handlers:
for handler in epoch_begin:
handler.epoch_begin()

for metric in self.train_metrics + self.train_loss_metrics:
Expand All @@ -282,7 +283,7 @@ def fit(self, train_data,
data, label = batch_fn(batch, self.context)

# batch begin
for handler in event_handlers:
for handler in batch_begin:
handler.batch_begin()

with autograd.record():
Expand All @@ -298,42 +299,64 @@ def fit(self, train_data,
# update train metrics
for metric in self.train_metrics:
metric.update(label, pred)
self.train_stats['batch_' + metric.name] = metric.get()[1]
# get metric name and current value and update train stats
name, value = metric.get()
self.train_stats['train_' + name] = value

# update loss
for loss, loss_metric, in zip(losses, self.train_loss_metrics):
loss_metric.update(0, [l for l in loss])
self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1]

try:
completed_samples = len(train_data._dataset) if i == len(train_data._dataset) - 1 \
else batch_size * (i + 1)
# We need to check if this is the last batch in the current epoch and select
# the value to print appropriately
self.train_stats['step'] = "{}/{}".format(completed_samples, len(train_data._dataset))
except AttributeError:
self.train_stats['step'] = i
name, value = loss_metric.get()
self.train_stats['train_' + name] = value

for trainer in self.trainers:
trainer.step(batch_size)
self.batch_idx = i
# record trained samples v.s. total samples if using Gluon DataLoader
if isinstance(train_data, gluon.data.DataLoader):
self.samples = "{}/{}".format(self.batch_size * (i + 1), len(train_data._dataset))

self.trainer.step(self.batch_size)
# batch end
for handler in event_handlers:
for handler in batch_end:
handler.batch_end()

if val_data:
self.evaluate(val_data, batch_fn)

for metric in self.train_metrics + self.train_loss_metrics:
self.train_stats['train_' + metric.name].append(metric.get()[1])
for metric in self.val_metrics + self.val_loss_metrics:
self.train_stats['val_' + metric.name].append(metric.get()[1])

# epoch end
for handler in event_handlers:
for handler in epoch_end:
handler.epoch_end()

if self.stop_training:
break

# train end
for handler in event_handlers:
for handler in train_end:
handler.train_end()

def _categorize_handlers(self, event_handlers):
"""
categorize handlers into 6 event lists to avoid calling empty methods
for example, only event handlers with train_begin method
implemented will be called at train begin
"""

train_begin = []
epoch_begin = []
batch_begin = []
batch_end = []
epoch_end = []
train_end = []
for handler in event_handlers:
if not handler.__class__.train_begin == EventHandler.train_begin:
train_begin.append(handler)
if not handler.__class__.epoch_begin == EventHandler.epoch_begin:
epoch_begin.append(handler)
if not handler.__class__.batch_begin == EventHandler.batch_begin:
batch_begin.append(handler)
if not handler.__class__.batch_end == EventHandler.batch_end:
batch_end.append(handler)
if not handler.__class__.epoch_end == EventHandler.epoch_end:
epoch_end.append(handler)
if not handler.__class__.train_end == EventHandler.train_end:
train_end.append(handler)
return train_begin, epoch_begin, batch_begin, batch_end, epoch_end, train_end
Loading

0 comments on commit 92c3c21

Please sign in to comment.