Potential Leakage of Information Across Folds in Kfold.py #12300

Closed
JinLi711 opened this issue Mar 11, 2022 · 7 comments
Labels: checkpointing (Related to checkpointing), loops (Related to the Loop API), won't fix (This will not be worked on)

Comments

JinLi711 commented Mar 11, 2022

🐛 Bug

I believe information can leak across folds when certain parameters of the Kfold.py script are changed. Say the user chooses to save every single checkpoint. After training on the first fold finishes, the second fold reuses the same checkpoint directory as the first fold. So if the second fold finishes training and the user decides to load the best checkpoint, the second fold may load a checkpoint produced during training on the first fold.

To Reproduce

Run the script Kfold.py

Expected behavior

We expect that the training process of multiple folds is independent of one another.

Environment

  • PyTorch Lightning Version: 1.6.0dev
  • PyTorch Version: 1.10.0+cu102
  • Python version: 3.7.11
  • OS: Linux
  • CUDA/cuDNN version: Using CPU
  • GPU models and configuration:
  • How you installed PyTorch (conda, pip, source):
  • If compiling from source, the output of torch.__config__.show():
  • Any other relevant information:

Additional context

I think the solution may be to clear out the previous checkpoints when starting a new fold. We would also need to reset the checkpoint state (e.g. reset the tracked minimum validation loss when advancing to the next fold).
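A minimal sketch of what that reset could look like, assuming the custom KFoldLoop from the example exposes a hook that runs before each fold starts (the hook name below is illustrative, not part of the example as shipped); the ModelCheckpoint attributes touched are the callback's tracked state:

import os
import shutil

def on_advance_start(self) -> None:  # assumed per-fold hook on the example's KFoldLoop
    for cb in self.trainer.checkpoint_callbacks:
        # remove files written during the previous fold so "best" can never
        # resolve to a checkpoint from another fold
        if cb.dirpath is not None and os.path.isdir(cb.dirpath):
            shutil.rmtree(cb.dirpath, ignore_errors=True)
        # reset the callback's tracked state for the new fold
        cb.best_model_score = None
        cb.best_model_path = ""
        cb.kth_best_model_path = ""
        cb.best_k_models = {}
        cb.current_score = None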

cc @awaelchli @ananthsub @ninginthecloud @rohitgr7 @otaj @carmocca @justusschock

carmocca (Contributor) commented

How are you loading the checkpoint? Are you using one of the checkpoints saved manually by the loop?
https://github.com/PyTorchLightning/pytorch-lightning/blob/1eff3b53c1ff9d362fc24a1e4fea6c0cfe78696b/pl_examples/loop_examples/kfold.py#L203
Or through ModelCheckpoint?

You might want to use one ModelCheckpoint instance per fold.

carmocca added the checkpointing label on Mar 21, 2022
cvlearn913 commented Mar 23, 2022

I ran into the same issue; a single ModelCheckpoint does not seem to work correctly. How do I use one ModelCheckpoint instance for each fold?

How are you loading the checkpoint? Are you using one of the checkpoints saved manually by the loop?

https://github.com/PyTorchLightning/pytorch-lightning/blob/1eff3b53c1ff9d362fc24a1e4fea6c0cfe78696b/pl_examples/loop_examples/kfold.py#L203

Or through ModelCheckpoint?
You might want to use one ModelCheckpoint instance per fold.

AlexTo commented Apr 4, 2022

To use one ModelCheckpoint per fold, here is how I did it:

  • In the model, log the metrics under a different name for each fold, e.g. val_loss becomes f"fold_{fold}-val_loss":
def validation_step(self, batch, batch_idx):
    ...
    # current_fold is exposed by the custom KFoldLoop from the example
    fold = self.trainer.fit_loop.current_fold
    self.log(f"fold_{fold}-val_loss", loss.item(), on_step=False, on_epoch=True)
    ...
  • Create multiple model checkpoint instances, each monitoring its own fold's validation loss (KFoldModelCheckpoint is defined in the next bullet):
model_checkpoints = [KFoldModelCheckpoint(
    filename="{" + f"fold_{f}-val_loss" + "}_{epoch}.pt",
    monitor=f"fold_{f}-val_loss",
    mode="min",
    every_n_epochs=1,
    save_top_k=3
) for f in range(num_folds)]
  • But note that the original ModelCheckpoint will throw an error, because the checkpoint callback for fold 0 can only monitor fold_0-val_loss, so during the other folds the metric fold_0-val_loss is not found. We can simply extend ModelCheckpoint to ignore folds that are not relevant:
class KFoldModelCheckpoint(ModelCheckpoint):
    # Same body as ModelCheckpoint._save_topk_checkpoint, except that a missing
    # fold-specific monitor is silently ignored instead of raising.
    def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
        if self.save_top_k == 0:
            return
        # validate metric
        if self.monitor is not None:
            if self.monitor not in monitor_candidates:
                if "fold" in self.monitor:
                    # fold-specific metric belongs to another fold: skip saving entirely
                    return
                else:
                    m = (
                        f"`ModelCheckpoint(monitor={self.monitor!r})` could not find the monitored key in the returned"
                        f" metrics: {list(monitor_candidates)}."
                        f" HINT: Did you call `log({self.monitor!r}, value)` in the `LightningModule`?"
                    )
                    if trainer.fit_loop.epoch_loop.val_loop._has_run:
                        raise MisconfigurationException(m)
                    warning_cache.warn(m)
            self._save_monitor_checkpoint(trainer, monitor_candidates)
        else:
            self._save_none_monitor_checkpoint(trainer, monitor_candidates)

Now, the model checkpoints for k-fold training will work properly ;)

carmocca (Contributor) commented Apr 4, 2022

But note that the original ModelCheckpoint will throw an error, because the checkpoint callback for fold 0 can only monitor fold_0-val_loss, so during the other folds the metric fold_0-val_loss is not found. We can simply extend ModelCheckpoint to ignore folds that are not relevant

I would suggest doing this instead:

class KFoldModelCheckpoint(ModelCheckpoint):
    # assumes `self.fold` is set, e.g. via an extra constructor argument
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.fit_loop.current_fold == self.fold:
            super().on_train_epoch_end(trainer, pl_module)

(and the same for on_train_batch_end and on_validation_end)

which is more efficient, since checkpointing is then skipped entirely for the other folds.
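A sketch of that suggestion with all three hooks overridden, assuming the fold index is passed in at construction time and that trainer.fit_loop.current_fold is the attribute exposed by the custom KFoldLoop; *args/**kwargs forwarding is used so the sketch does not depend on the exact hook signatures of a particular Lightning version:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class KFoldModelCheckpoint(ModelCheckpoint):
    def __init__(self, fold: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.fold = fold  # the only fold this callback is responsible for

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs) -> None:
        if trainer.fit_loop.current_fold == self.fold:
            super().on_train_batch_end(trainer, pl_module, *args, **kwargs)

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        if trainer.fit_loop.current_fold == self.fold:
            super().on_train_epoch_end(trainer, pl_module)

    def on_validation_end(self, trainer, pl_module) -> None:
        if trainer.fit_loop.current_fold == self.fold:
            super().on_validation_end(trainer, pl_module)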

AlexTo commented Apr 6, 2022

Thanks @carmocca. I found that we can also override _should_skip_saving_checkpoint, which looks even more semantically correct: _should_skip_saving_checkpoint is called at the beginning of on_train_batch_end, on_train_epoch_end, and on_validation_end, so we only need to modify one function ;)
KFoldModelCheckpoint can then be refactored to be quite concise:

class KFoldModelCheckpoint(ModelCheckpoint):

    def __init__(self, fold, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.fold = fold

    def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
        # skip saving entirely whenever this is not the callback's own fold
        return trainer.fit_loop.current_fold != self.fold or super()._should_skip_saving_checkpoint(trainer)
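
For reference, a hedged usage sketch of how the per-fold callbacks might be wired up; num_folds, the metric names, and the Trainer arguments are illustrative, and the Trainer is otherwise assumed to be configured as in the kfold.py example:

import pytorch_lightning as pl

num_folds = 5  # illustrative
checkpoint_callbacks = [
    KFoldModelCheckpoint(
        fold=f,
        monitor=f"fold_{f}-val_loss",
        mode="min",
        save_top_k=3,
    )
    for f in range(num_folds)
]
trainer = pl.Trainer(max_epochs=10, callbacks=checkpoint_callbacks)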

carmocca (Contributor) commented Apr 7, 2022

That works too, but only for those who like the risk that _should_skip_saving_checkpoint could disappear (it's protected!)

stale bot commented Jun 6, 2022

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale bot added the won't fix label on Jun 6, 2022
stale bot closed this as completed on Jun 19, 2022