-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNet-1349][Fit API]Add validation support and unit tests for fit() API #14442
Changes from 6 commits
9f334da
c81132a
69e118b
7750027
1eafd3a
d9b7480
5d7b58e
7d9137a
353e3d3
b843f56
305d1bf
d07052a
abf6a68
f88515f
282957e
5f77df9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,7 @@ | |
from ...context import Context, cpu, gpu, num_gpus | ||
from ...io import DataIter | ||
from ...metric import EvalMetric, Loss | ||
import copy | ||
|
||
__all__ = ['Estimator'] | ||
|
||
|
@@ -64,17 +65,21 @@ def __init__(self, net, | |
self.loss = [loss] | ||
else: | ||
self.loss = loss or [] | ||
if not self.loss: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think this if loop is not correct
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! |
||
raise ValueError("No loss specified, refer to gluon.loss.Loss") | ||
for l in self.loss: | ||
if not isinstance(loss, gluon.loss.Loss): | ||
if not isinstance(l, gluon.loss.Loss): | ||
raise ValueError("loss must be a Loss or a list of Loss, refer to gluon.loss.Loss") | ||
|
||
if isinstance(metrics, EvalMetric): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use logic similar to above There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! |
||
self.metrics = [metrics] | ||
self.train_metrics = [metrics] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you infer from the loss function? use 'Accuracy' as default when not passed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 'Accuracy' will only work for classification cases, for other cases it will give inaccurate resutls or even fail. Also I'm not sure how we can infer metrics from loss function as there isn't a direct correlation between them, do you have any suggestions? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should still infer metrics from known Loss functions (at least from the examples you know) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added Accuracy metric as default for SoftmaxCrossEntropy loss for now. Will add more in a followup PR. |
||
else: | ||
self.metrics = metrics or [] | ||
for metric in self.metrics: | ||
self.train_metrics = metrics or [] | ||
for metric in self.train_metrics: | ||
if not isinstance(metric, EvalMetric): | ||
raise ValueError("metrics must be a Metric or a list of Metric, refer to mxnet.metric.EvalMetric") | ||
# Use same metrics for validation | ||
self.test_metrics = copy.deepcopy(self.train_metrics) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rename There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! |
||
|
||
self.initializer = initializer | ||
# store training statistics | ||
|
@@ -83,16 +88,21 @@ def __init__(self, net, | |
self.train_stats['learning_rate'] = [] | ||
# current step of the epoch | ||
self.train_stats['step'] = '' | ||
for metric in self.metrics: | ||
for metric in self.train_metrics: | ||
# record a history of metrics over each epoch | ||
self.train_stats['train_' + metric.name] = [] | ||
# only record the latest metric numbers after each batch | ||
self.train_stats['batch_' + metric.name] = 0. | ||
self.loss_metrics = [] | ||
for metric in self.test_metrics: | ||
self.train_stats['val_' + metric.name] = [] | ||
self.train_loss_metrics = [] | ||
self.test_loss_metrics = [] | ||
# using the metric wrapper for loss to record loss value | ||
for l in self.loss: | ||
self.loss_metrics.append(Loss(l.name)) | ||
self.train_loss_metrics.append(Loss(l.name)) | ||
self.test_loss_metrics.append(Loss(l.name)) | ||
self.train_stats['train_' + l.name] = [] | ||
self.train_stats['val_' + l.name] = [] | ||
# only record the latest loss numbers after each batch | ||
self.train_stats['batch_' + l.name] = 0. | ||
|
||
|
@@ -130,11 +140,13 @@ def __init__(self, net, | |
self.trainers = [trainers] | ||
else: | ||
self.trainers = trainers or [] | ||
if not self.trainers: | ||
warnings.warn("No trainer specified, default SGD optimizer " | ||
"with learning rate 0.001 is used.") | ||
self.trainers = [gluon.Trainer(self.net.collect_params(), | ||
'sgd', {'learning_rate': 0.001})] | ||
if not self.trainers: | ||
warnings.warn("No trainer specified, default SGD optimizer " | ||
"with learning rate 0.001 is used.") | ||
self.trainers = [gluon.Trainer(self.net.collect_params(), | ||
'sgd', {'learning_rate': 0.001})] | ||
else: | ||
raise ValueError("Invalid trainer specified, please provide a valid gluon.Trainer") | ||
|
||
def _is_initialized(self): | ||
param_dict = self.net.collect_params() | ||
|
@@ -156,7 +168,33 @@ def _batch_fn(self, batch, ctx, is_iterator=False): | |
label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0) | ||
return data, label | ||
|
||
def _evaluate(self, val_data, batch_fn=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what if i just want to validate on a single item? can i not pass X, y? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can be done by wrapping X, y in a dataloader/dataiter and passing it as val_data There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we expose this method? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! |
||
for metric in self.test_metrics + self.test_loss_metrics: | ||
metric.reset() | ||
|
||
for i, batch in enumerate(val_data): | ||
if not batch_fn: | ||
if isinstance(val_data, gluon.data.DataLoader): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. move this if/else into into self._batch_fn There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this check needs to be before calling |
||
data, label = self._batch_fn(batch, self.context) | ||
elif isinstance(val_data, DataIter): | ||
data, label = self._batch_fn(batch, self.context, is_iterator=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same above |
||
else: | ||
raise ValueError("You are using a custom iteration, please also provide " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be helpful to the end user if could provide more detailed exception. you can append the below statement at the end, something like this: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for pointing it out, updated! |
||
"batch_fn to extract data and label") | ||
else: | ||
data, label = batch_fn(batch, self.context) | ||
pred = [self.net(x) for x in data] | ||
losses = [] | ||
for loss in self.loss: | ||
losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)]) | ||
# update metrics | ||
for metric in self.test_metrics: | ||
metric.update(label, pred) | ||
for loss, loss_metric, in zip(losses, self.test_loss_metrics): | ||
loss_metric.update(0, [l for l in loss]) | ||
|
||
def fit(self, train_data, | ||
val_data=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this shouldn't be optional There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Users might want to train without a validation set. Although this is rare, still keeping it optional provides a bit of flexibility There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see a reason why users would not a validation dataset, its required to know that the model is not overfitting/ |
||
epochs=1, | ||
batch_size=None, | ||
event_handlers=None, | ||
|
@@ -192,6 +230,9 @@ def fit(self, train_data, | |
not any(isinstance(handler, LoggingHandler) for handler in event_handlers): | ||
event_handlers.append(LoggingHandler(self)) | ||
|
||
# Check for validation data | ||
do_validation = True if val_data else False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! |
||
|
||
# training begin | ||
for handler in event_handlers: | ||
handler.train_begin() | ||
|
@@ -204,7 +245,7 @@ def fit(self, train_data, | |
for handler in event_handlers: | ||
handler.epoch_begin() | ||
|
||
for metric in self.metrics + self.loss_metrics: | ||
for metric in self.train_metrics + self.train_loss_metrics: | ||
metric.reset() | ||
|
||
for i, batch in enumerate(train_data): | ||
|
@@ -233,11 +274,11 @@ def fit(self, train_data, | |
for l in loss: | ||
l.backward() | ||
|
||
# update metrics | ||
for metric in self.metrics: | ||
# update train metrics | ||
for metric in self.train_metrics: | ||
metric.update(label, pred) | ||
self.train_stats['batch_' + metric.name] = metric.get()[1] | ||
for loss, loss_metric, in zip(losses, self.loss_metrics): | ||
for loss, loss_metric, in zip(losses, self.train_loss_metrics): | ||
loss_metric.update(0, [l for l in loss]) | ||
self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1] | ||
|
||
|
@@ -253,8 +294,14 @@ def fit(self, train_data, | |
for handler in event_handlers: | ||
handler.batch_end() | ||
|
||
for metric in self.metrics + self.loss_metrics: | ||
if do_validation: | ||
self._evaluate(val_data, batch_fn) | ||
|
||
for metric in self.train_metrics + self.train_loss_metrics: | ||
self.train_stats['train_' + metric.name].append(metric.get()[1]) | ||
for metric in self.test_metrics + self.test_loss_metrics: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we rename There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we can use |
||
self.train_stats['val_' + metric.name].append(metric.get()[1]) | ||
|
||
# epoch end | ||
for handler in event_handlers: | ||
handler.epoch_end() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this failed sanity check: import it first