Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNet-1349][Fit API]Add validation support and unit tests for fit() API #14442

Merged
merged 16 commits into from
Mar 25, 2019
81 changes: 64 additions & 17 deletions python/mxnet/gluon/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# pylint: disable=wildcard-import
"""Gluon Estimator"""

import copy
import warnings

from .event_handler import LoggingHandler
Expand Down Expand Up @@ -64,17 +65,21 @@ def __init__(self, net,
self.loss = [loss]
else:
self.loss = loss or []
if not self.loss:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this if loop is not correct

if isinstance(loss, gluon.loss.Loss):
   self.loss = [loss]
elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]):
   self.loss = loss
else:
  raise ValueError("loss must be a Loss or a list of Loss, refer to gluon.loss.Loss:{}".format(loss))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

raise ValueError("No loss specified, refer to gluon.loss.Loss")
for l in self.loss:
if not isinstance(loss, gluon.loss.Loss):
if not isinstance(l, gluon.loss.Loss):
raise ValueError("loss must be a Loss or a list of Loss, refer to gluon.loss.Loss")

if isinstance(metrics, EvalMetric):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use logic similar to above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

self.metrics = [metrics]
self.train_metrics = [metrics]
Copy link
Member

@nswamy nswamy Mar 21, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you infer from the loss function? use 'Accuracy' as default when not passed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'Accuracy' will only work for classification cases, for other cases it will give inaccurate resutls or even fail. Also I'm not sure how we can infer metrics from loss function as there isn't a direct correlation between them, do you have any suggestions?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should still infer metrics from known Loss functions (at least from the examples you know)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added Accuracy metric as default for SoftmaxCrossEntropy loss for now. Will add more in a followup PR.

else:
self.metrics = metrics or []
for metric in self.metrics:
self.train_metrics = metrics or []
for metric in self.train_metrics:
if not isinstance(metric, EvalMetric):
raise ValueError("metrics must be a Metric or a list of Metric, refer to mxnet.metric.EvalMetric")
# Use same metrics for validation
self.test_metrics = copy.deepcopy(self.train_metrics)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename self.test_metrics-> self.val_metrics

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!


self.initializer = initializer
# store training statistics
Expand All @@ -83,16 +88,21 @@ def __init__(self, net,
self.train_stats['learning_rate'] = []
# current step of the epoch
self.train_stats['step'] = ''
for metric in self.metrics:
for metric in self.train_metrics:
# record a history of metrics over each epoch
self.train_stats['train_' + metric.name] = []
# only record the latest metric numbers after each batch
self.train_stats['batch_' + metric.name] = 0.
self.loss_metrics = []
for metric in self.test_metrics:
self.train_stats['val_' + metric.name] = []
self.train_loss_metrics = []
self.test_loss_metrics = []
# using the metric wrapper for loss to record loss value
for l in self.loss:
self.loss_metrics.append(Loss(l.name))
self.train_loss_metrics.append(Loss(l.name))
self.test_loss_metrics.append(Loss(l.name))
self.train_stats['train_' + l.name] = []
self.train_stats['val_' + l.name] = []
# only record the latest loss numbers after each batch
self.train_stats['batch_' + l.name] = 0.

Expand Down Expand Up @@ -130,11 +140,13 @@ def __init__(self, net,
self.trainers = [trainers]
else:
self.trainers = trainers or []
if not self.trainers:
warnings.warn("No trainer specified, default SGD optimizer "
"with learning rate 0.001 is used.")
self.trainers = [gluon.Trainer(self.net.collect_params(),
'sgd', {'learning_rate': 0.001})]
if not self.trainers:
warnings.warn("No trainer specified, default SGD optimizer "
"with learning rate 0.001 is used.")
self.trainers = [gluon.Trainer(self.net.collect_params(),
'sgd', {'learning_rate': 0.001})]
else:
raise ValueError("Invalid trainer specified, please provide a valid gluon.Trainer")

def _is_initialized(self):
param_dict = self.net.collect_params()
Expand All @@ -156,7 +168,33 @@ def _batch_fn(self, batch, ctx, is_iterator=False):
label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
return data, label

def _evaluate(self, val_data, batch_fn=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if i just want to validate on a single item? can i not pass X, y?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be done by wrapping X, y in a dataloader/dataiter and passing it as val_data

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we expose this method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

for metric in self.test_metrics + self.test_loss_metrics:
metric.reset()

for _, batch in enumerate(val_data):
if not batch_fn:
if isinstance(val_data, gluon.data.DataLoader):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move this if/else into into self._batch_fn

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this check needs to be before calling batch_fn as val_data is not available to it

data, label = self._batch_fn(batch, self.context)
elif isinstance(val_data, DataIter):
data, label = self._batch_fn(batch, self.context, is_iterator=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same above

else:
raise ValueError("You are using a custom iteration, please also provide "
Copy link
Contributor

@karan6181 karan6181 Mar 21, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful to the end user if could provide more detailed exception. you can append the below statement at the end, something like this: or you can provide the data in terms of gluon.data.DataLoader or mx.io.DataIter. Please also change this statement in fit() method if you are changing it here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing it out, updated!

"batch_fn to extract data and label")
else:
data, label = batch_fn(batch, self.context)
pred = [self.net(x) for x in data]
losses = []
for loss in self.loss:
losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)])
# update metrics
for metric in self.test_metrics:
metric.update(label, pred)
for loss, loss_metric, in zip(losses, self.test_loss_metrics):
loss_metric.update(0, [l for l in loss])

def fit(self, train_data,
val_data=None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this shouldn't be optional

Copy link
Contributor Author

@abhinavs95 abhinavs95 Mar 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users might want to train without a validation set. Although this is rare, still keeping it optional provides a bit of flexibility

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see a reason why users would not a validation dataset, its required to know that the model is not overfitting/

epochs=1,
batch_size=None,
event_handlers=None,
Expand Down Expand Up @@ -192,6 +230,9 @@ def fit(self, train_data,
not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
event_handlers.append(LoggingHandler(self))

# Check for validation data
do_validation = True if val_data else False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!


# training begin
for handler in event_handlers:
handler.train_begin()
Expand All @@ -204,7 +245,7 @@ def fit(self, train_data,
for handler in event_handlers:
handler.epoch_begin()

for metric in self.metrics + self.loss_metrics:
for metric in self.train_metrics + self.train_loss_metrics:
metric.reset()

for i, batch in enumerate(train_data):
Expand Down Expand Up @@ -233,11 +274,11 @@ def fit(self, train_data,
for l in loss:
l.backward()

# update metrics
for metric in self.metrics:
# update train metrics
for metric in self.train_metrics:
metric.update(label, pred)
self.train_stats['batch_' + metric.name] = metric.get()[1]
for loss, loss_metric, in zip(losses, self.loss_metrics):
for loss, loss_metric, in zip(losses, self.train_loss_metrics):
loss_metric.update(0, [l for l in loss])
self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1]

Expand All @@ -253,8 +294,14 @@ def fit(self, train_data,
for handler in event_handlers:
handler.batch_end()

for metric in self.metrics + self.loss_metrics:
if do_validation:
self._evaluate(val_data, batch_fn)

for metric in self.train_metrics + self.train_loss_metrics:
self.train_stats['train_' + metric.name].append(metric.get()[1])
for metric in self.test_metrics + self.test_loss_metrics:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rename test_metrics to val_metrics for consistency, since we are referring to them as validation stuff throughout

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can use test_metrics for both validation (at epoch end) and test (at train end) dataset

self.train_stats['val_' + metric.name].append(metric.get()[1])

# epoch end
for handler in event_handlers:
handler.epoch_end()
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/gluon/estimator/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def epoch_end(self):
epoch = self._estimator.train_stats['epochs'][-1]
msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
for key in self._estimator.train_stats.keys():
if key.startswith('train_') or key.startswith('test_'):
if key.startswith('train_') or key.startswith('val_'):
msg += key + ': ' + '%.4f ' % self._estimator.train_stats[key][epoch]
self.logger.info(msg)

Expand Down
Loading