
[MXNet-1340][Fit API]Update train stats #14494

Merged: 10 commits, Apr 3, 2019
Changes from all commits
147 changes: 85 additions & 62 deletions python/mxnet/gluon/estimator/estimator.py
@@ -22,7 +22,7 @@
import copy
import warnings

- from .event_handler import LoggingHandler
+ from .event_handler import EventHandler, LoggingHandler
from ... import gluon, autograd
from ...context import Context, cpu, gpu, num_gpus
from ...io import DataIter
@@ -39,27 +39,26 @@ class Estimator(object):

Parameters
----------
- loss : Loss or list of Loss
+ loss : gluon.loss.Loss or list of gluon.loss.Loss
Loss(objective functions) to calculate during training
metrics : EvalMetric or list of EvalMetric
Metrics for evaluating models
initializer : Initializer
initializer to initialize the network
- trainers : Trainer or list of Trainer
-     Trainers to apply optimizers on network parameters
+ trainer : Trainer
+     Trainer to apply optimizer on network parameters
context : Context or list of Context
devices to run the training on
"""

def __init__(self, net,
- loss=None,
+ loss,
metrics=None,
initializer=None,
- trainers=None,
+ trainer=None,
context=None):

self.net = net
Member:
do you want to set self._estimator = None?

Member Author:
This is the estimator class; only event handlers should have self._estimator?

+ self.stop_training = False

if isinstance(loss, gluon.loss.Loss):
self.loss = [loss]
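
For orientation, a minimal usage sketch of the new constructor contract; the network, metric, and import path here are illustrative assumptions, not part of this diff:

    import mxnet as mx
    from mxnet import gluon
    from mxnet.gluon.estimator import Estimator  # assumes the package re-exports Estimator

    net = gluon.nn.Dense(10)  # stand-in network for illustration
    ce_loss = gluon.loss.SoftmaxCrossEntropyLoss()

    # `loss` is now a required argument; metrics stay optional
    est = Estimator(net, loss=ce_loss, metrics=mx.metric.Accuracy())
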
@@ -86,27 +85,14 @@ def __init__(self, net,

# store training statistics
self.train_stats = {}
- self.train_stats['epochs'] = []
- self.train_stats['learning_rate'] = []
- # current step of the epoch
- self.train_stats['step'] = ''
- for metric in self.train_metrics:
-     # record a history of metrics over each epoch
-     self.train_stats['train_' + metric.name] = []
-     # only record the latest metric numbers after each batch
-     self.train_stats['batch_' + metric.name] = 0.
- for metric in self.val_metrics:
-     self.train_stats['val_' + metric.name] = []

# separate train and validation
self.train_loss_metrics = []
self.val_loss_metrics = []
# using the metric wrapper for loss to record loss value
for l in self.loss:
self.train_loss_metrics.append(Loss(l.name))
self.val_loss_metrics.append(Loss(l.name))
-     self.train_stats['train_' + l.name] = []
-     self.train_stats['val_' + l.name] = []
-     # only record the latest loss numbers after each batch
-     self.train_stats['batch_' + l.name] = 0.

# handle context
if isinstance(context, Context):
@@ -127,15 +113,14 @@
raise ValueError("context must be a Context or a list of Context, "
"refer to mxnet.Context:{}".format(context))


# initialize the network
self.initializer = initializer
if self.initializer:
if self._is_initialized():
# if already initialized, re-init with user specified initializer
warnings.warn("Network already initialized, re-initializing with %s. "
"You don't need to pass initializer if you already "
"initialized your net."% type(self.initializer).__name__)
"initialized your net." % type(self.initializer).__name__)
self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=True)
else:
# initialize with user specified initializer
@@ -144,16 +129,17 @@
if not self._is_initialized():
self.net.initialize(ctx=self.context)

- # handle trainers
- if isinstance(trainers, gluon.Trainer):
-     self.trainers = [trainers]
- elif not trainers:
+ # handle trainer
+ if not trainer:
warnings.warn("No trainer specified, default SGD optimizer "
"with learning rate 0.001 is used.")
-     self.trainers = [gluon.Trainer(self.net.collect_params(),
-                                    'sgd', {'learning_rate': 0.001})]
+     self.trainer = gluon.Trainer(self.net.collect_params(),
+                                  'sgd', {'learning_rate': 0.001})
+ elif not isinstance(trainer, gluon.Trainer):
+     raise ValueError("Trainer must be a Gluon Trainer instance, refer to "
+                      "gluon.Trainer:{}".format(trainer))
else:
raise ValueError("Invalid trainer specified, please provide a valid gluon.Trainer")
self.trainer = trainer
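
Continuing the sketch above, the single-trainer contract might be exercised as follows; the optimizer choice is illustrative:

    # exactly one gluon.Trainer is accepted now; anything else raises ValueError
    trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 1e-3})
    est = Estimator(net, loss=ce_loss, trainer=trainer)

    # omitting trainer falls back to SGD with learning rate 0.001 (with a warning)
    est_default = Estimator(net, loss=ce_loss)
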

def _is_initialized(self):
param_dict = self.net.collect_params()
@@ -212,8 +198,12 @@ def evaluate(self,
# update metrics
for metric in self.val_metrics:
metric.update(label, pred)
+     name, value = metric.get()
+     self.train_stats['val_' + name] = value
for loss, loss_metric, in zip(losses, self.val_loss_metrics):
loss_metric.update(0, [l for l in loss])
+     name, value = loss_metric.get()
+     self.train_stats['val_' + name] = value
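
With this change, each train_stats key holds the latest scalar value rather than a per-epoch list. A hedged sketch of reading the stats after training; exact key names depend on the metric and loss block names actually used (Gluon block names may carry an index suffix such as softmaxcrossentropyloss0):

    # assuming an Accuracy metric and the ce_loss from the sketch above
    latest_train_acc = est.train_stats['train_accuracy']
    latest_val_acc = est.train_stats['val_accuracy']
    # losses are tracked through the Loss metric wrapper under the same scheme
    latest_val_loss = est.train_stats['val_' + ce_loss.name]
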

def fit(self, train_data,
val_data=None,
@@ -241,27 +231,38 @@ def fit(self, train_data,
from a data batch and load into contexts(devices)
"""


- self.epochs = epochs
+ self.max_epoch = epochs
if not batch_size:
-     batch_size = 32 * len(self.context)
+     self.batch_size = 32 * len(self.context)
+ else:
+     self.batch_size = batch_size
+ self.stop_training = False
+ self.samples = None
+ self.batch_idx = 0

event_handlers = event_handlers or []
# provide default logging handler
if not event_handlers or \
not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
-     event_handlers.append(LoggingHandler(self))
+     event_handlers.append(LoggingHandler())

# training begin
+ train_begin, epoch_begin, batch_begin, \
+     batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers)

+ # passing estimator to event handlers so they can access estimator information
+ # when an event is triggered
+ for handler in event_handlers:
+     handler.estimator = self
Member Author:
@nswamy This will avoid asking the user to pass the estimator during event handler construction, reference: #14462 (comment)

Member:
I am wondering how the user of the handler will know that an estimator will be initialized here? Also, can you have a setter and getter for the estimator in Handler, and not call handler.setEstimator(e) if handler.getEstimator() is not None?

Member Author (@roywei, Apr 2, 2019):
@nswamy when the user calls est.fit(xxx, event_handlers=XXX), this already associates the event handlers with an estimator instance. I'm just helping the user pass this estimator so they don't need to do so during event handler construction.
The getter and setter are already implemented through the property interface; handler.estimator = self is actually the setter method of the property estimator.

# training begin
for handler in train_begin:
handler.train_begin()

- for epoch in range(epochs):
+ for epoch in range(self.max_epoch):
# epoch begin
- self.train_stats['epochs'].append(epoch)
- self.train_stats['learning_rate'].append(self.trainers[0].learning_rate)
+ self.current_epoch = epoch

- for handler in event_handlers:
+ for handler in epoch_begin:
handler.epoch_begin()

for metric in self.train_metrics + self.train_loss_metrics:
@@ -282,7 +283,7 @@ def fit(self, train_data,
data, label = batch_fn(batch, self.context)

# batch begin
- for handler in event_handlers:
+ for handler in batch_begin:
handler.batch_begin()

with autograd.record():
@@ -298,42 +299,64 @@ def fit(self, train_data,
# update train metrics
for metric in self.train_metrics:
metric.update(label, pred)
-     self.train_stats['batch_' + metric.name] = metric.get()[1]
+     # get metric name and current value and update train stats
+     name, value = metric.get()
+     self.train_stats['train_' + name] = value

# update loss
for loss, loss_metric, in zip(losses, self.train_loss_metrics):
loss_metric.update(0, [l for l in loss])
-     self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1]
- try:
-     completed_samples = len(train_data._dataset) if i == len(train_data._dataset) - 1 \
-         else batch_size * (i + 1)
-     # We need to check if this is the last batch in the current epoch and select
-     # the value to print appropriately
-     self.train_stats['step'] = "{}/{}".format(completed_samples, len(train_data._dataset))
- except AttributeError:
-     self.train_stats['step'] = i
+     name, value = loss_metric.get()
+     self.train_stats['train_' + name] = value

- for trainer in self.trainers:
-     trainer.step(batch_size)
+ self.batch_idx = i
+ # record trained samples vs. total samples if using Gluon DataLoader
+ if isinstance(train_data, gluon.data.DataLoader):
self.samples = "{}/{}".format(self.batch_size * (i + 1), len(train_data._dataset))

+ self.trainer.step(self.batch_size)
# batch end
- for handler in event_handlers:
+ for handler in batch_end:
handler.batch_end()

if val_data:
self.evaluate(val_data, batch_fn)

- for metric in self.train_metrics + self.train_loss_metrics:
-     self.train_stats['train_' + metric.name].append(metric.get()[1])
- for metric in self.val_metrics + self.val_loss_metrics:
-     self.train_stats['val_' + metric.name].append(metric.get()[1])

# epoch end
- for handler in event_handlers:
+ for handler in epoch_end:
handler.epoch_end()

+ if self.stop_training:
+     break

# train end
- for handler in event_handlers:
+ for handler in train_end:
handler.train_end()

+ def _categorize_handlers(self, event_handlers):
+     """
+     Categorize handlers into 6 event lists to avoid calling empty methods;
+     for example, only event handlers with a train_begin method
+     implemented will be called at train begin.
+     """
+
+     train_begin = []
+     epoch_begin = []
+     batch_begin = []
+     batch_end = []
+     epoch_end = []
+     train_end = []
+     for handler in event_handlers:
+         if not handler.__class__.train_begin == EventHandler.train_begin:
+             train_begin.append(handler)
+         if not handler.__class__.epoch_begin == EventHandler.epoch_begin:
+             epoch_begin.append(handler)
+         if not handler.__class__.batch_begin == EventHandler.batch_begin:
+             batch_begin.append(handler)
+         if not handler.__class__.batch_end == EventHandler.batch_end:
+             batch_end.append(handler)
+         if not handler.__class__.epoch_end == EventHandler.epoch_end:
+             epoch_end.append(handler)
+         if not handler.__class__.train_end == EventHandler.train_end:
+             train_end.append(handler)
+     return train_begin, epoch_begin, batch_begin, batch_end, epoch_end, train_end
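
The handler.__class__.method == EventHandler.method comparison works because a subclass that overrides a method binds a different function object on the class. A small self-contained illustration of the pattern; the class names are invented for the example:

    class Base(object):
        def train_begin(self):
            pass

    class Verbose(Base):
        def train_begin(self):          # overridden: a different function object
            print("starting training")

    class Silent(Base):
        pass                            # inherits Base.train_begin unchanged

    print(Verbose.train_begin == Base.train_begin)  # False -> handler gets dispatched
    print(Silent.train_begin == Base.train_begin)   # True  -> handler is skipped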