Add dataloader arg to Trainer.test() #1393
Comments
Hi! Thanks for your contribution, and great first issue!
I am in favour of adding this option, but first, let's see how it fits the API.
test is meant to operate ONLY on the test set. It's meant to keep people from using the test set when they shouldn't, haha (i.e. only right before publication or right before production use).

Additions that I'm not sure align well:

Additions that are good:
By the way, I'm interested in how to "train a model using 5-fold cross-validation" in PL.
Let's do this:
* Add test_dataloaders to test method
* Remove test_dataloaders from .fit()
* Fix code comment
* Fix tests
* Add test_dataloaders to test method (Lightning-AI#1393)
* Fix failing tests
* Update docs (Lightning-AI#1393)
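The commit list above moves the test dataloader from `fit()` to `test()`. A minimal toy sketch of that API shape, assuming a stand-in `Trainer` class (this is NOT the real `pytorch_lightning.Trainer`; the method bodies are illustrative only):

```python
# Toy stand-in illustrating the signature change: the test dataloader
# is passed to test(), not fit(). Not the real Lightning Trainer.
class Trainer:
    def fit(self, model, train_dataloader=None, val_dataloaders=None):
        # After the change, fit() no longer accepts test_dataloaders.
        self.model = model

    def test(self, model=None, test_dataloaders=None):
        # test() now receives its dataloader directly, so it can be
        # called on any dataset, multiple times, with any trained model.
        model = model if model is not None else self.model
        return [model(batch) for batch in test_dataloaders]


trainer = Trainer()
trainer.fit(lambda x: x * 2)                      # "model" is a stand-in callable
results = trainer.test(test_dataloaders=[1, 2, 3])
```

This mirrors the calling structure of `fit()`, which is the motivation stated in the issue below.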
Hey @rohitgr7! The link seems to be broken; could you point to any other resource? Thanks!
🚀 Feature
It would be nice if you could use a model for inference using:
Trainer.test(model, test_dataloaders=test_loader)
Motivation
This will match the calling structure for Trainer.fit() and allow for test to be called on any dataset, multiple times.
Pitch
Here's a use case. After training a model using 5-fold cross-validation, you may want to stack the 5 checkpoints across multiple models, which will require a) out-of-fold (OOF) predictions and b) the 5 test predictions (which will be averaged). It would be cool if a & b could be generated as follows:
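The snippet the issue ends on appears to have been lost. A hypothetical pure-Python sketch of that stacking workflow follows; `kfold_indices`, `train_and_predict`, and `cross_val_stack` are illustrative stand-ins (a trivial mean-predictor stands in for fitting a model), not Lightning APIs:

```python
def kfold_indices(n, k):
    """Split range(n) into k contiguous folds of (almost) equal size."""
    fold_sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    folds, start = [], 0
    for size in fold_sizes:
        folds.append(list(range(start, start + size)))
        start += size
    return folds


def train_and_predict(train_xs, train_ys, xs):
    """Stand-in for fit-then-predict: a constant mean-predictor."""
    mean_y = sum(train_ys) / len(train_ys)
    return [mean_y for _ in xs]


def cross_val_stack(xs, ys, test_xs, k=5):
    """Return (a) out-of-fold predictions and (b) fold-averaged test predictions."""
    oof = [None] * len(xs)
    per_fold_test_preds = []
    for val_idx in kfold_indices(len(xs), k):
        held_out = set(val_idx)
        train_idx = [i for i in range(len(xs)) if i not in held_out]
        train_xs = [xs[i] for i in train_idx]
        train_ys = [ys[i] for i in train_idx]
        # (a) predict on the held-out fold -> out-of-fold predictions
        val_preds = train_and_predict(train_xs, train_ys, [xs[i] for i in val_idx])
        for i, p in zip(val_idx, val_preds):
            oof[i] = p
        # (b) predict on the test set with this fold's model
        per_fold_test_preds.append(train_and_predict(train_xs, train_ys, test_xs))
    # Average the k test predictions element-wise
    avg_test = [sum(col) / k for col in zip(*per_fold_test_preds)]
    return oof, avg_test
```

With the proposed feature, each `train_and_predict` call on the test set would become a `trainer.test(model, test_dataloaders=test_loader)` call, one per fold checkpoint.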
Alternatives
Maybe I'm misunderstanding how test works and there is an easier way? Or perhaps the best way to do this is to write an inference function as you would in pure PyTorch?
Additional context