Skip to content

Commit

Permalink
[MXNet-1349][Fit API] Add validation support and unit tests for fit() API (apache#14442)
Browse files Browse the repository at this point in the history

* added estimator unittests

* add more tests for estimator

* added validation logic

* added error handlers, unittests

* improve val stats

* fix pylint

* fix pylint

* update unit test

* fix tests

* fix tests

* updated metrics, val logic

* trigger ci

* trigger ci

* update metric, batch_fn error handler

* update context logic, add default metric
  • Loading branch information
abhinavs95 authored and roywei committed May 15, 2019
1 parent 5b1eb20 commit 02e7c9b
Show file tree
Hide file tree
Showing 3 changed files with 370 additions and 25 deletions.
116 changes: 92 additions & 24 deletions python/mxnet/gluon/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
# pylint: disable=wildcard-import
"""Gluon Estimator"""

import copy
import warnings

from .event_handler import LoggingHandler
from ... import gluon, autograd
from ...context import Context, cpu, gpu, num_gpus
from ...io import DataIter
from ...metric import EvalMetric, Loss
from ...metric import EvalMetric, Loss, Accuracy

__all__ = ['Estimator']

Expand Down Expand Up @@ -62,44 +63,57 @@ def __init__(self, net,

if isinstance(loss, gluon.loss.Loss):
self.loss = [loss]
elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]):
self.loss = loss
else:
self.loss = loss or []
for l in self.loss:
if not isinstance(loss, gluon.loss.Loss):
raise ValueError("loss must be a Loss or a list of Loss, refer to gluon.loss.Loss")
raise ValueError("loss must be a Loss or a list of Loss, "
"refer to gluon.loss.Loss:{}".format(loss))

if isinstance(metrics, EvalMetric):
self.metrics = [metrics]
self.train_metrics = [metrics]
else:
self.metrics = metrics or []
for metric in self.metrics:
if not isinstance(metric, EvalMetric):
raise ValueError("metrics must be a Metric or a list of Metric, refer to mxnet.metric.EvalMetric")
self.train_metrics = metrics or []
if not all([isinstance(metric, EvalMetric) for metric in self.train_metrics]):
raise ValueError("metrics must be a Metric or a list of Metric, "
"refer to mxnet.metric.EvalMetric:{}".format(metrics))

# Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss()
if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]):
self.train_metrics = [Accuracy()]

# Use same metrics for validation
self.val_metrics = copy.deepcopy(self.train_metrics)

self.initializer = initializer
# store training statistics
self.train_stats = {}
self.train_stats['epochs'] = []
self.train_stats['learning_rate'] = []
# current step of the epoch
self.train_stats['step'] = ''
for metric in self.metrics:
for metric in self.train_metrics:
# record a history of metrics over each epoch
self.train_stats['train_' + metric.name] = []
# only record the latest metric numbers after each batch
self.train_stats['batch_' + metric.name] = 0.
self.loss_metrics = []
for metric in self.val_metrics:
self.train_stats['val_' + metric.name] = []
self.train_loss_metrics = []
self.val_loss_metrics = []
# using the metric wrapper for loss to record loss value
for l in self.loss:
self.loss_metrics.append(Loss(l.name))
self.train_loss_metrics.append(Loss(l.name))
self.val_loss_metrics.append(Loss(l.name))
self.train_stats['train_' + l.name] = []
self.train_stats['val_' + l.name] = []
# only record the latest loss numbers after each batch
self.train_stats['batch_' + l.name] = 0.

# handle context
if isinstance(context, Context):
self.context = [context]
if not context:
elif isinstance(context, list) and all([isinstance(c, Context) for c in context]):
self.context = context
elif not context:
if num_gpus() > 0:
# only use 1 GPU by default
if num_gpus() > 1:
Expand All @@ -109,8 +123,13 @@ def __init__(self, net,
self.context = [gpu(0)]
else:
self.context = [cpu()]
else:
raise ValueError("context must be a Context or a list of Context, "
"refer to mxnet.Context:{}".format(context))


# initialize the network
self.initializer = initializer
if self.initializer:
if self._is_initialized():
# if already initialized, re-init with user specified initializer
Expand All @@ -128,13 +147,13 @@ def __init__(self, net,
# handle trainers
if isinstance(trainers, gluon.Trainer):
self.trainers = [trainers]
else:
self.trainers = trainers or []
if not self.trainers:
elif not trainers:
warnings.warn("No trainer specified, default SGD optimizer "
"with learning rate 0.001 is used.")
self.trainers = [gluon.Trainer(self.net.collect_params(),
'sgd', {'learning_rate': 0.001})]
else:
raise ValueError("Invalid trainer specified, please provide a valid gluon.Trainer")

def _is_initialized(self):
param_dict = self.net.collect_params()
Expand All @@ -156,7 +175,48 @@ def _batch_fn(self, batch, ctx, is_iterator=False):
label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
return data, label

def evaluate(self,
             val_data,
             batch_fn=None):
    """Evaluate model on validation data.

    Resets the validation metrics and validation loss metrics, runs the
    network on every batch of ``val_data`` and accumulates the results into
    ``self.val_metrics`` and ``self.val_loss_metrics``. Nothing is returned;
    callers read the results from those metric objects.

    Parameters
    ----------
    val_data : DataLoader or DataIter
        validation data with data and labels
    batch_fn : function
        custom batch function to extract data and label
        from a data batch and load into contexts(devices)

    Raises
    ------
    ValueError
        If ``batch_fn`` is not provided and ``val_data`` is neither a
        ``gluon.data.DataLoader`` nor a ``DataIter``.
    """
    for metric in self.val_metrics + self.val_loss_metrics:
        metric.reset()

    is_data_iter = isinstance(val_data, DataIter)
    if is_data_iter:
        # An mx.io.DataIter does not rewind itself when iterated again, so
        # without an explicit reset every evaluation after the first full
        # pass (e.g. the second epoch of fit()) would see zero batches.
        val_data.reset()

    for batch in val_data:
        if batch_fn:
            data, label = batch_fn(batch, self.context)
        elif isinstance(val_data, gluon.data.DataLoader):
            data, label = self._batch_fn(batch, self.context)
        elif is_data_iter:
            data, label = self._batch_fn(batch, self.context, is_iterator=True)
        else:
            raise ValueError("You are using a custom iteration, please also provide "
                             "batch_fn to extract data and label. Alternatively, you "
                             "can provide the data as gluon.data.DataLoader or "
                             "mx.io.DataIter")

        # one prediction per data slice (data is iterated slice by slice)
        pred = [self.net(x) for x in data]
        # losses[i] holds the per-slice values of the i-th loss function
        losses = [[loss_fn(y_hat, y) for y_hat, y in zip(pred, label)]
                  for loss_fn in self.loss]

        # update metrics
        for metric in self.val_metrics:
            metric.update(label, pred)
        # the first argument passed to the loss metric's update() is a placeholder 0
        for batch_losses, loss_metric in zip(losses, self.val_loss_metrics):
            loss_metric.update(0, batch_losses)

def fit(self, train_data,
val_data=None,
epochs=1,
batch_size=None,
event_handlers=None,
Expand Down Expand Up @@ -204,7 +264,7 @@ def fit(self, train_data,
for handler in event_handlers:
handler.epoch_begin()

for metric in self.metrics + self.loss_metrics:
for metric in self.train_metrics + self.train_loss_metrics:
metric.reset()

for i, batch in enumerate(train_data):
Expand All @@ -215,7 +275,9 @@ def fit(self, train_data,
data, label = self._batch_fn(batch, self.context, is_iterator=True)
else:
raise ValueError("You are using a custom iteration, please also provide "
"batch_fn to extract data and label")
"batch_fn to extract data and label. Alternatively, you "
"can provide the data as gluon.data.DataLoader or "
"mx.io.DataIter")
else:
data, label = batch_fn(batch, self.context)

Expand All @@ -233,11 +295,11 @@ def fit(self, train_data,
for l in loss:
l.backward()

# update metrics
for metric in self.metrics:
# update train metrics
for metric in self.train_metrics:
metric.update(label, pred)
self.train_stats['batch_' + metric.name] = metric.get()[1]
for loss, loss_metric, in zip(losses, self.loss_metrics):
for loss, loss_metric, in zip(losses, self.train_loss_metrics):
loss_metric.update(0, [l for l in loss])
self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1]

Expand All @@ -257,8 +319,14 @@ def fit(self, train_data,
for handler in event_handlers:
handler.batch_end()

for metric in self.metrics + self.loss_metrics:
if val_data:
self.evaluate(val_data, batch_fn)

for metric in self.train_metrics + self.train_loss_metrics:
self.train_stats['train_' + metric.name].append(metric.get()[1])
for metric in self.val_metrics + self.val_loss_metrics:
self.train_stats['val_' + metric.name].append(metric.get()[1])

# epoch end
for handler in event_handlers:
handler.epoch_end()
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/gluon/estimator/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def epoch_end(self):
epoch = self._estimator.train_stats['epochs'][-1]
msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
for key in self._estimator.train_stats.keys():
if key.startswith('train_') or key.startswith('test_'):
if key.startswith('train_') or key.startswith('val_'):
msg += key + ': ' + '%.4f ' % self._estimator.train_stats[key][epoch]
self.logger.info(msg)

Expand Down
Loading

0 comments on commit 02e7c9b

Please sign in to comment.