Potential Leakage of Information Across Folds in Kfold.py #12300
Comments
How are you loading the checkpoint? Are you using one of the checkpoints saved manually by the loop? You might want to use one
I got the same issue, one
To use 1 model checkpoint per fold, here is how I did it:
Now, the model checkpoint for k-fold will work properly ;)
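The commenter's actual snippet was not captured above. A minimal sketch of the per-fold-directory idea (the helper name `fold_checkpoint_dir` is illustrative, not from the thread; in Lightning terms it amounts to creating a fresh `ModelCheckpoint(dirpath=...)` for each fold):

```python
import os
import tempfile

def fold_checkpoint_dir(base_dir: str, fold: int) -> str:
    """Give each fold its own checkpoint directory, so a 'best checkpoint'
    lookup can never pick up a file written during an earlier fold."""
    path = os.path.join(base_dir, f"fold_{fold}")
    os.makedirs(path, exist_ok=True)
    return path

base = tempfile.mkdtemp()
dirs = [fold_checkpoint_dir(base, k) for k in range(3)]
# Each fold writes to and searches only its own directory.
print([os.path.basename(d) for d in dirs])  # ['fold_0', 'fold_1', 'fold_2']
```

Keeping the directories disjoint sidesteps the leakage entirely, at the cost of holding checkpoints for every fold on disk.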
I would suggest doing this instead:

```python
class KFoldModelCheckpoint(ModelCheckpoint):
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.fit_loop.current_fold == self.fold:
            super().on_train_epoch_end(trainer, pl_module)
```

(which is also more efficient, as checkpointing is skipped entirely for the inactive folds)
Thanks @carmocca, I find that we can also override
That works too, but only for those who like the risk that
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
🐛 Bug
I believe there can be leakage of information across folds when changing some parameters in the Kfold.py script. Suppose the user chooses to save every checkpoint. After the first fold finishes training, the second fold uses the same checkpoint directory as the first. So when the second fold finishes and the user loads the best checkpoint, it may actually be a checkpoint produced during the first fold's training.
To Reproduce
Run the script Kfold.py
Expected behavior
We expect that the training process of multiple folds is independent of one another.
Environment

- How you installed PyTorch (conda, pip, source):
- `torch.__config__.show()`:

Additional context
I think the solution may be to clear out previous checkpoints when starting a new fold. We would also need to reset the checkpoint state (e.g. reset the minimum validation loss when advancing to the next fold).
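The "reset between folds" idea can be sketched as follows. The attribute names mirror `ModelCheckpoint`'s `best_model_score` / `best_model_path`, but this is a simplified stand-in, not the real class:

```python
import math

class CheckpointState:
    """Simplified checkpoint bookkeeping: tracks the best (lowest) val loss."""
    def __init__(self):
        self.reset()

    def reset(self):
        # Called when advancing to the next fold, so fold k+1 cannot
        # inherit fold k's best score or best checkpoint path.
        self.best_model_score = math.inf
        self.best_model_path = ""

    def update(self, val_loss, path):
        if val_loss < self.best_model_score:
            self.best_model_score = val_loss
            self.best_model_path = path

state = CheckpointState()
state.update(0.5, "fold0_epoch0.ckpt")
state.update(0.3, "fold0_epoch1.ckpt")
state.reset()                # advancing to fold 1
state.update(0.4, "fold1_epoch0.ckpt")
print(state.best_model_path)  # fold1_epoch0.ckpt — fold 0's 0.3 no longer wins
```

Without the `reset()` call, fold 1's 0.4 loss would lose to fold 0's 0.3, and "best checkpoint" would silently point at a fold-0 file, which is exactly the leakage described above.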
cc @awaelchli @ananthsub @ninginthecloud @rohitgr7 @otaj @carmocca @justusschock