[MXNet-1349][Fit API] Add validation support and unit tests for fit() API #14442
Conversation
__all__ = ['Estimator']


class Estimator(object):
Wasn't this class included in this PR: #14346?
Yes, I have made changes in it for validation support.
I'd recommend either adding the changes for validation metrics as a commit to the existing PR #14346, or waiting until that PR gets merged and adding this as a single commit so that the diff is incremental.
@piyushghai I asked @abhinavs95 to open this PR so I can review it early. We can do a rebase and resolve the conflict once parent PR is merged.
Thanks for your contribution! Generally looks good; I added a few comments.
Remember to rebase and resolve the conflict once the parent PR is merged.
        label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
        return data, label

    def _test(self, val_data, batch_fn=None):
Maybe rename to _evaluate(self, eval_data), as this can be used for both the validation (at epoch end) and test (at train end) datasets.
Done!
self.train_loss_metrics.append(Loss(l.name))
self.test_loss_metrics.append(Loss(l.name))
self.train_stats['train_' + l.name] = []
self.train_stats['test_' + l.name] = []
use val_
Done!
    # only record the latest metric numbers after each batch
    self.train_stats['batch_' + metric.name] = 0.
for metric in self.test_metrics:
    self.train_stats['test_' + metric.name] = []
use val_
Done!
msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
for key in self._estimator.train_stats.keys():
    if do_validation:
        if key.startswith('train_') or key.startswith('test_'):
You can remove the if/else, and there's no need to pass do_validation. The logic can simply be: if the key starts with train_ or val_, log it, even if no validation is done.
Done! Also moved the val stats update loop to fit() to avoid a key error. If no val set is passed, 'nan' will be logged for the val stats.
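A minimal sketch of that logging flow (variable names assumed from the surrounding snippets, not the merged code): any stat whose key starts with 'train_' or 'val_' gets logged, and the 'nan' placeholder shows up when no validation set was passed.

msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
for key, history in estimator.train_stats.items():
    # history is a per-epoch list; val_ entries hold float('nan')
    # when fit() was called without a validation set
    if key.startswith('train_') or key.startswith('val_'):
        msg += '%s: %.4f, ' % (key, history[-1])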
@mxnet-label-bot add [Gluon, Test]
LGTM, please fix the CI failure
@@ -26,6 +26,7 @@
from ...context import Context, cpu, gpu, num_gpus
from ...io import DataIter
from ...metric import EvalMetric, Loss
import copy
This failed the sanity check: standard-library imports like copy must come first.
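For reference, a sketch of the ordering the sanity check (pylint) expects, with standard-library imports before package-relative ones:

# standard library imports first
import copy

# then package-relative imports
from ...context import Context, cpu, gpu, num_gpus
from ...io import DataIter
from ...metric import EvalMetric, Loss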
    self.train_stats['train_' + metric.name].append(metric.get()[1])
for metric in self.test_metrics + self.test_loss_metrics:
Can we rename test_metrics to val_metrics for consistency, since we are referring to them as validation stuff throughout?
We can use test_metrics for both the validation (at epoch end) and test (at train end) datasets.
                loss=loss,
                trainers=trainer,
                context=ctx)
est.fit(train_data=train_data,
What are we asserting against here?
This checks that the estimator works with no metric specified and doesn't throw any error/warning when it's successful.
So there's a particular happy path that you're trying to test here.
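A hypothetical sketch of such a happy-path test (the network, data shapes, and epochs argument are illustrative, not taken from this PR): fit() with no metrics argument should simply run to completion.

import mxnet as mx
from mxnet import gluon

def test_fit_without_metric():
    net = gluon.nn.Dense(4)
    net.initialize(mx.init.Xavier(), ctx=mx.cpu())
    loss = gluon.loss.L2Loss()
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01})
    # Estimator is the class under review; no metrics are passed
    est = Estimator(net=net, loss=loss, trainers=trainer, context=mx.cpu())
    X = mx.nd.random.uniform(shape=(8, 3))
    y = mx.nd.random.uniform(shape=(8, 4))
    train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(X, y), batch_size=4)
    # the test passes if no exception or warning is raised
    est.fit(train_data=train_data, epochs=1)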
@@ -64,17 +65,21 @@ def __init__(self, net,
            self.loss = [loss]
        else:
            self.loss = loss or []
        if not self.loss:
I think this if block is not correct; it should be something like:

if isinstance(loss, gluon.loss.Loss):
    self.loss = [loss]
elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]):
    self.loss = loss
else:
    raise ValueError("loss must be a Loss or a list of Loss, refer to gluon.loss.Loss:{}".format(loss))
Done!
if not isinstance(metric, EvalMetric):
    raise ValueError("metrics must be a Metric or a list of Metric, refer to mxnet.metric.EvalMetric")
# Use same metrics for validation
self.test_metrics = copy.deepcopy(self.train_metrics)
rename self.test_metrics -> self.val_metrics
Done!
for l in self.loss:
-    if not isinstance(loss, gluon.loss.Loss):
+    if not isinstance(l, gluon.loss.Loss):
        raise ValueError("loss must be a Loss or a list of Loss, refer to gluon.loss.Loss")

if isinstance(metrics, EvalMetric):
use logic similar to above
Done!
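For reference, "logic similar to above" applied to metrics would look roughly like this (a sketch; the error message matches the diff shown later in this conversation):

if isinstance(metrics, EvalMetric):
    self.train_metrics = [metrics]
elif isinstance(metrics, list) and all([isinstance(m, EvalMetric) for m in metrics]):
    self.train_metrics = metrics
else:
    raise ValueError("metrics must be a Metric or a list of Metric, "
                     "refer to mxnet.metric.EvalMetric:{}".format(metrics))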
@@ -156,7 +168,33 @@ def _batch_fn(self, batch, ctx, is_iterator=False):
        label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
        return data, label

    def _evaluate(self, val_data, batch_fn=None):
What if I just want to validate on a single item? Can I not pass X, y?
This can be done by wrapping X, y in a DataLoader/DataIter and passing it as val_data.
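A minimal sketch of that wrapping, assuming a single sample X with label y:

import mxnet as mx
from mxnet import gluon

X = mx.nd.random.uniform(shape=(1, 28 * 28))  # one sample
y = mx.nd.array([3])                          # its label

# wrap the single (X, y) pair so it can be passed as val_data
single_item = gluon.data.ArrayDataset(X, y)
val_data = gluon.data.DataLoader(single_item, batch_size=1)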
if isinstance(val_data, gluon.data.DataLoader):
    data, label = self._batch_fn(batch, self.context)
elif isinstance(val_data, DataIter):
    data, label = self._batch_fn(batch, self.context, is_iterator=True)
Same as above.
for _, batch in enumerate(val_data):
    if not batch_fn:
        if isinstance(val_data, gluon.data.DataLoader):
move this if/else into self._batch_fn
This check needs to happen before calling batch_fn, as val_data is not available to it.
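A sketch of the ordering being discussed, reconstructed from the snippets in this thread (the custom batch_fn branch at the end is assumed): the isinstance check runs on val_data before dispatching, because _batch_fn only ever receives the batch.

for _, batch in enumerate(val_data):
    if not batch_fn:
        # this check must live here: _batch_fn receives only the batch,
        # so it cannot inspect val_data itself
        if isinstance(val_data, gluon.data.DataLoader):
            data, label = self._batch_fn(batch, self.context)
        elif isinstance(val_data, DataIter):
            data, label = self._batch_fn(batch, self.context, is_iterator=True)
        else:
            raise ValueError("You are using a custom iteration, please also provide "
                             "batch_fn to extract data and label")
    else:
        data, label = batch_fn(batch, self.context)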
@@ -156,7 +168,33 @@ def _batch_fn(self, batch, ctx, is_iterator=False):
        label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
        return data, label

    def _evaluate(self, val_data, batch_fn=None):
can we expose this method?
Done!
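One hypothetical way to expose it, sketched as a thin public wrapper (the actual change may simply have renamed the method):

def evaluate(self, val_data, batch_fn=None):
    """Public entry point for evaluating on a validation or test dataset."""
    return self._evaluate(val_data, batch_fn=batch_fn)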
@@ -192,6 +230,9 @@ def fit(self, train_data,
                not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
            event_handlers.append(LoggingHandler(self))

        # Check for validation data
        do_validation = True if val_data else False
remove this.
Done!
def fit(self, train_data,
        val_data=None,
this shouldn't be optional
Users might want to train without a validation set. Although this is rare, keeping it optional provides a bit of flexibility.
I don't see a reason why users would not want a validation dataset; it's required to know that the model is not overfitting.
@roywei @nswamy @piyushghai Thanks for the review. I have addressed all the comments.
if isinstance(metrics, EvalMetric):
-    self.metrics = [metrics]
+    self.train_metrics = [metrics]
Can you infer it from the loss function? Use 'Accuracy' as the default when it's not passed?
'Accuracy' will only work for classification cases; for other cases it will give inaccurate results or even fail. Also, I'm not sure how we can infer metrics from the loss function, as there isn't a direct correlation between them. Do you have any suggestions?
I think we should still infer metrics from known Loss functions (at least from the examples you know)
Added the Accuracy metric as the default for SoftmaxCrossEntropy loss for now. Will add more in a follow-up PR.
    # record a history of metrics over each epoch
    self.train_stats['train_' + metric.name] = []
    # only record the latest metric numbers after each batch
    self.train_stats['batch_' + metric.name] = 0.
self.loss_metrics = []
for metric in self.val_metrics:
Can we have one for loop for self.train_metrics and self.val_metrics, since the value of both parameters is the same according to line 82? Something like for train_m, val_m in zip(self.train_metrics, self.val_metrics). Though the zip() operator stops after exhausting the shorter array, both arrays here are the same length, so we can use zip().
We want to keep the train and val metrics separate here. Currently we are using the same metrics for val and train, but future updates may involve separate user-specified val metrics, in which case combining this update loop won't work.
elif isinstance(val_data, DataIter):
    data, label = self._batch_fn(batch, self.context, is_iterator=True)
else:
    raise ValueError("You are using a custom iteration, please also provide "
It would be helpful to the end user if we could provide a more detailed exception. You can append something like this at the end: "or you can provide the data as a gluon.data.DataLoader or mx.io.DataIter". Please also change this statement in the fit() method if you are changing it here.
Thanks for pointing it out, updated!
Can you fix the context validation? Currently, if it's a list of contexts, it will fail.
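A hypothetical sketch of context validation that also accepts a list of contexts (the default fallback here is an assumption, not the merged behavior):

from mxnet.context import Context, cpu, gpu, num_gpus

def _check_context(context):
    # default to all available GPUs, else CPU (assumed fallback)
    if context is None:
        return [gpu(i) for i in range(num_gpus())] if num_gpus() > 0 else [cpu()]
    if isinstance(context, Context):
        return [context]
    if isinstance(context, list) and all(isinstance(c, Context) for c in context):
        return context
    raise ValueError("context must be a Context or a list of Context, "
                     "got: {}".format(context))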
@@ -77,6 +77,10 @@ def __init__(self, net,
            raise ValueError("metrics must be a Metric or a list of Metric, "
                             "refer to mxnet.metric.EvalMetric:{}".format(metrics))

        # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss()
        if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]):
Let's get this from a map of Loss -> [default metrics] in the next version.
Yes, tracking it using JIRA issue: https://issues.apache.org/jira/browse/MXNET-1364
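A hypothetical sketch of such a map; only the SoftmaxCrossEntropyLoss -> Accuracy entry is grounded in this PR, the other entries are illustrative guesses:

import mxnet as mx
from mxnet import gluon

# map each known Gluon loss class to sensible default metric classes
DEFAULT_METRICS = {
    gluon.loss.SoftmaxCrossEntropyLoss: [mx.metric.Accuracy],
    gluon.loss.L2Loss: [mx.metric.MSE],
    gluon.loss.L1Loss: [mx.metric.MAE],
}

def infer_default_metrics(losses):
    """Instantiate default metrics for a list of loss objects."""
    metrics = []
    for l in losses:
        for loss_cls, metric_classes in DEFAULT_METRICS.items():
            if isinstance(l, loss_cls):
                metrics.extend(m() for m in metric_classes)
    return metrics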
…API (apache#14442)

* added estimator unittests
* add more tests for estimator
* added validation logic
* added error handlers, unittests
* improve val stats
* fix pylint
* fix pylint
* update unit test
* fix tests
* fix tests
* updated metrics, val logic
* trigger ci
* trigger ci
* update metric, batch_fn error handler
* update context logic, add default metric
Description
Adding validation support and unit tests for fit() API.
This PR depends on the parent PR for fit() API #14346
JIRA epic: https://issues.apache.org/jira/projects/MXNET/issues/MXNET-1333
Design: https://cwiki.apache.org/confluence/display/MXNET/Gluon+Fit+API+-+Tech+Design