
Resume training from the last checkpoint #5325

Closed
sourabh-nutonomy opened this issue Jan 2, 2021 · 31 comments

Labels: fault tolerance · feature (Is an improvement or enhancement) · help wanted (Open to be worked on)

@sourabh-nutonomy

sourabh-nutonomy commented Jan 2, 2021

🐛 Bug

What am I trying to do?

  1. Create a ModelCheckpoint callback with save_last=True.
  2. Interrupt training the model in the middle of an epoch.
  3. Restart training using the resume_from_checkpoint argument of the Trainer (see the sketch below).
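
For reference, a minimal sketch of this setup, assuming a hypothetical LightningModule named MyModel and the 1.1.x-era resume_from_checkpoint Trainer argument:

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# MyModel stands in for any LightningModule; the paths are illustrative.
checkpoint_cb = ModelCheckpoint(dirpath="checkpoints/", save_last=True)

# First run: gets interrupted partway through an epoch (Ctrl+C, spot preemption, ...).
trainer = pl.Trainer(max_epochs=10, callbacks=[checkpoint_cb])
trainer.fit(MyModel())

# Second run: resume from the last checkpoint written by ModelCheckpoint.
trainer = pl.Trainer(
    max_epochs=10,
    callbacks=[checkpoint_cb],
    resume_from_checkpoint="checkpoints/last.ckpt",
)
trainer.fit(MyModel())
```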

Expected Behavior

The Trainer is supposed to restore the entire training state (epoch, step, LR scheduler) according to the documentation here.

Actual Behavior

If the training is interrupted during an epoch, the ModelCheckpoint callback correctly saves the model and the training state. However, when we resume training, the training actually starts from the next epoch. So let's say we interrupted training when 20% of the first epoch had finished. When we resume training, the trainer actually starts from the second epoch, thereby skipping 80% of the steps of the first epoch.

Please reproduce using the BoringModel

https://colab.research.google.com/drive/1f32UPGA8eINrz17ykt9krCfjllI3krOI?usp=sharing

To Reproduce

Please follow the steps in the BoringModel above. I am not sure if the tensorboard snippets will be visible so I have taken screenshots here:

Clean Run

[screenshot: clean_run]

Interrupted Run

[screenshot: interrupted_run]

Environment

  • CUDA:
    • GPU:
      • Tesla T4
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.19.4
    • pyTorch_debug: True
    • pyTorch_version: 1.7.0+cu101
    • pytorch-lightning: 1.1.2
    • tqdm: 4.41.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.6.9
    • version: #1 SMP Thu Jul 23 08:00:38 PDT 2020

cc @Borda @carmocca @justusschock @awaelchli @ninginthecloud

@sourabh-nutonomy sourabh-nutonomy added bug Something isn't working help wanted Open to be worked on labels Jan 2, 2021
@github-actions
Contributor

github-actions bot commented Jan 2, 2021

Hi! thanks for your contribution!, great first issue!

@tchaton tchaton added this to the 1.1.x milestone Jan 4, 2021
@tchaton tchaton added the priority: 0 High priority task label Jan 4, 2021
@Borda Borda self-assigned this Jan 4, 2021
@edenlightning
Contributor

@Borda any updates?

@Borda Borda added priority: 1 Medium priority task and removed priority: 0 High priority task labels Jan 12, 2021
@Borda
Member

Borda commented Jan 12, 2021

@sourabh-nutonomy regarding your scenario, what do you expect to happen: continue from the same state (global step) at which it was interrupted, e.g. 20% of the first epoch, or train the first epoch again?
One complication is that if you shuffle your data, there is no simple way to just continue training unless you seed your training...
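
For context on the seeding point, the usual way to make the shuffle order reproducible across restarts is to seed everything up front; a minimal sketch (the seed value and checkpoint path are illustrative):

```python
import pytorch_lightning as pl

# Seeding the Python, NumPy and PyTorch RNGs makes DataLoader shuffling
# reproducible across restarts, which is a prerequisite for resuming mid-epoch
# at the same global step.
pl.seed_everything(42)

trainer = pl.Trainer(
    deterministic=True,  # trade some speed for reproducible operations
    resume_from_checkpoint="checkpoints/last.ckpt",
)
```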

@sourabh-nutonomy
Author

@Borda Yes, I understand that we cannot shuffle the data the same way without seeding the training. But I think that can be left to the users to figure out. They can either seed their training or they can ignore it if they want to. But I am expecting that the training restarts at the same global step when the interruption happened (so after 20% of the first epoch in the example). I can also describe my use-case which would explain the need to start at the same global step.

We currently use AWS spot instances to train our nets. They can go down anytime so we need to resume training from the same global step when the experiment restarts. If we start from the beginning of the epoch, in the worst case we can get stuck in an infinite loop where we keep retraining the same epoch (if the spot instances keep restarting). If we start from the next epoch (like what is happening currently), we would miss lots of gradient updates. Like we are missing 80% of steps in epoch 1 in the example we are discussing. Now if the experiment restarts multiple times, it will further exacerbate the problem.

Even when we run our experiments on our in-house cluster and not spot instances, our datasets are huge. So if we retrain the first epoch, we lose a few hours.

@Borda
Member

Borda commented Jan 15, 2021

I see, it would be great to allow this, but I would rather classify it as a feature instead of a bug because the current behavior is expected...
Just checking, there will be multiple places in the codebase that need to be changed accordingly, e.g.:
https://github.com/PyTorchLightning/pytorch-lightning/blob/9ebbfece5e2c56bb5300cfffafb129e399492469/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py#L214

@Borda Borda added feature Is an improvement or enhancement and removed bug Something isn't working labels Jan 15, 2021
@Borda Borda modified the milestones: 1.1.x, 1.2 Jan 15, 2021
@Alex-nutonomy

@Borda Can you confirm this will be part of Lightning 1.2? Any timeline on that? Not having this feature is blocking us from adopting Lightning.

@edenlightning
Contributor

Unfortunately, it won't be available in 1.2 but we are prioritizing this feature for our next release!!

@edenlightning edenlightning added priority: 0 High priority task and removed priority: 1 Medium priority task labels Feb 9, 2021
@junjy007

junjy007 commented Mar 7, 2021

I got into this issue from the one requesting the ability to save a checkpoint every n steps (link).

Restricting checkpointing to once per epoch is hugely inconvenient, especially when the dataset is large, which is the main reason I chose pytorch-lightning to organise the experiments in a more systematic way.
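
For readers arriving later: newer releases do support checkpointing on a step interval rather than only at epoch boundaries; a minimal sketch, assuming PyTorch Lightning >= 1.3 (interval and path are illustrative):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Write a checkpoint (and keep last.ckpt up to date) every 1000 training steps,
# so an interruption loses at most ~1000 steps of work instead of a whole epoch.
checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/",
    save_last=True,
    every_n_train_steps=1000,
)
```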

@ananthsub
Contributor

I think #6429 is a requirement for this. @carmocca @awaelchli @justusschock I'll update that issue with a more concrete proposal

@toliz

toliz commented Mar 25, 2021

What if you seed everything with a random number under the hood and then save this seed in the module's checkpoint?

Maybe that could keep runs stochastic but reproducible.
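
One way to prototype that idea with the public hooks is sketched below; the rng_seed key and the hook bodies are illustrative, not how Lightning's own fault tolerance is implemented:

```python
import random

import pytorch_lightning as pl

class SeededModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Draw a random seed once and make the whole run reproducible from it.
        self._seed = random.randint(0, 2**31 - 1)
        pl.seed_everything(self._seed)

    def on_save_checkpoint(self, checkpoint):
        # Persist the seed alongside the model/optimizer state.
        checkpoint["rng_seed"] = self._seed

    def on_load_checkpoint(self, checkpoint):
        # Re-seed with the stored value so shuffling/augmentation repeat on resume.
        self._seed = checkpoint["rng_seed"]
        pl.seed_everything(self._seed)
```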

@carmocca
Contributor

> What if you seed everything with a random number under the hood and then save this seed in the module's checkpoint?

Sure, this is one of the things required to support this, but not the only one!

@edenlightning edenlightning removed the priority: 0 High priority task label Apr 7, 2021
@edenlightning edenlightning modified the milestones: v1.3, v1.4 Apr 27, 2021
@GiuliaLanzillotta

Any news on this one?

@carmocca
Contributor

Still work in progress 📈

@cn4750

cn4750 commented Jun 26, 2021

I have also just run into this issue. Hopefully it will be supported in the next release. 👍

@edenlightning edenlightning modified the milestones: v1.4, v1.5 Jul 6, 2021
@Keiku
Contributor

Keiku commented Jul 12, 2021

Is resuming training possible with the current version using an iteration-based method? Is there a way to get approximately the same results with the current version? (Assuming that the seed is different and the run is not exactly reproduced.)

Personally, the priority of this issue is quite high.

@justusschock
Member

justusschock commented Jul 12, 2021

@Keiku There are several open PRs related to this. So far this is not yet supported, but we are actively working on it.

@Keiku
Contributor

Keiku commented Jul 17, 2021

Probably due to the seed problem, the validation accuracies are different, but it seems that resuming training was possible.

Clean Run (20 epochs on CIFAR10 Training)

Validation Acc: 85.18

Interrupted Run (Interrupted in the middle of 10 epochs on CIFAR10 Training)

Validation Acc: 84.9

I'm using pytorch-lightning == 1.3.8, but with the current functionality, I can do this.
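
The pattern described there boils down to pointing the Trainer at the last checkpoint when one exists; a minimal sketch against the 1.3.x API (paths are illustrative):

```python
import os

import pytorch_lightning as pl

last_ckpt = "checkpoints/last.ckpt"

trainer = pl.Trainer(
    max_epochs=20,
    # If a previous run left a checkpoint behind, restore the model weights,
    # optimizer/LR-scheduler state and the epoch counter from it; otherwise
    # start from scratch. Note this resumes at an epoch boundary, not mid-epoch.
    resume_from_checkpoint=last_ckpt if os.path.exists(last_ckpt) else None,
)
```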

@awaelchli
Contributor

Fault-tolerant training will be an experimental feature in 1.5. The first version of the docs was recently added here. #9130 will track the progress of the remaining work.

@Wuziyi616

Hi @Keiku, can you kindly share how you resume training? Directly using Trainer(resume_from_checkpoint=xxx)?

@Keiku
Contributor

Keiku commented Oct 14, 2021

@Wuziyi616 Please refer to my implementation of cifar10.

Keiku/PyTorch-Lightning-CIFAR10: "Not too complicated" training code for CIFAR-10 by PyTorch Lightning https://github.com/Keiku/PyTorch-Lightning-CIFAR10

@Wuziyi616

Great, thanks! So far I think the main modification you made is adding the resume_from_checkpoint argument when creating the Trainer, which I tried and it seems to work well now. For example, if you interrupt in the middle of epoch 10, then by loading the checkpoint saved at the end of epoch 9 you can restart training from the beginning of epoch 10. But I think one point in this issue is that we actually want to resume training from exactly where we were interrupted (the middle of epoch 10), since with a very large dataset we don't want to discard the iterations we have already trained. And this feature seems like it will be supported in PyTorch Lightning 1.5 (not yet released)...

Anyway, thanks for your help!

@Keiku
Contributor

Keiku commented Oct 14, 2021

@Wuziyi616 I have the same understanding of this problem as you. You may consider saving checkpoints more frequently, at a smaller iteration (or epoch) interval. I am also waiting for the fault-tolerant training feature of pytorch-lightning.

@Wuziyi616

Yes, I completely agree with you. Thanks a lot!

@AntreasAntoniou

I think this feature is vital for fault tolerant training. Where is this currently at?

@Keiku
Contributor

Keiku commented Feb 24, 2022

@AntreasAntoniou See this comment. #5325 (comment)

@carmocca
Contributor

The docs for fault-tolerance are here:
https://pytorch-lightning.readthedocs.io/en/stable/advanced/fault_tolerant_training.html

And there's a proposal to simplify the checkpoint management: #11912
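
For completeness, in the 1.5/1.6 series the mechanism described in those docs is opt-in via an environment variable rather than a Trainer flag; a minimal sketch (the feature is experimental, so details may change):

```python
import os

import pytorch_lightning as pl

# Opt in to experimental fault-tolerant training before creating the Trainer.
# When enabled, an interrupted fit() leaves behind an auto-saved checkpoint that
# also captures mid-epoch progress, so rerunning the same script can pick up
# close to where it failed.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

trainer = pl.Trainer(max_epochs=20)
```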

@mfoglio

mfoglio commented Jun 2, 2022

Hi @carmocca, I looked at the link you posted but I can't find an example. How can I resume training after a failure?

@carmocca
Contributor

You want to call trainer.fit(..., ckpt_path="a_ckpt_to_resume_from").
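
Put together, a minimal sketch for 1.5+, where resumption moved from the Trainer constructor to fit; model and train_loader stand in for your own LightningModule and DataLoader, and the paths are illustrative:

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

trainer = pl.Trainer(
    max_epochs=20,
    callbacks=[ModelCheckpoint(dirpath="checkpoints/", save_last=True)],
)
# ckpt_path restores the model weights, optimizer/LR-scheduler state and loop
# counters from the checkpoint before training continues.
trainer.fit(model, train_dataloaders=train_loader, ckpt_path="checkpoints/last.ckpt")
```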

@himanshudoi

himanshudoi commented Nov 18, 2022

Hi, I'm still facing a KeyError: 'radam_buffer' when calling Trainer.fit with the ckpt_path of an interrupted checkpoint.

pytorch-forecasting==0.10.1
pytorch-lightning==1.8.1

Error trace:

KeyError                                  Traceback (most recent call last)
/tmp/ipykernel_10535/1750406609.py in <module>
      4     train_dataloaders=train_dataloader,
      5     val_dataloaders=val_dataloader,
----> 6     ckpt_path=best_model_path,
      7 )

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    581         self.strategy._lightning_module = model
    582         call._call_and_handle_interrupt(
--> 583             self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    584         )
    585 

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/call.py in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     36             return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     37         else:
---> 38             return trainer_fn(*args, **kwargs)
     39 
     40     except _TunerExitException:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    622             model_connected=self.lightning_module is not None,
    623         )
--> 624         self._run(model, ckpt_path=self.ckpt_path)
    625 
    626         assert self.state.stopped

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
   1055         # restore optimizers, etc.
   1056         log.detail(f"{self.__class__.__name__}: restoring training state")
-> 1057         self._checkpoint_connector.restore_training_state()
   1058 
   1059         self._checkpoint_connector.resume_end()

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py in restore_training_state(self)
    294         if self.trainer.state.fn == TrainerFn.FITTING:
    295             # restore optimizers and schedulers state
--> 296             self.restore_optimizers_and_schedulers()
    297 
    298     def restore_precision_plugin_state(self) -> None:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py in restore_optimizers_and_schedulers(self)
    409                     " This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`."
    410                 )
--> 411             self.restore_optimizers()
    412 
    413         if "lr_schedulers" not in self._loaded_checkpoint:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py in restore_optimizers(self)
    424 
    425         # restore the optimizers
--> 426         self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint)
    427 
    428     def restore_lr_schedulers(self) -> None:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/strategies/strategy.py in load_optimizer_state_dict(self, checkpoint)
    366         optimizer_states = checkpoint["optimizer_states"]
    367         for optimizer, opt_state in zip(self.optimizers, optimizer_states):
--> 368             optimizer.load_state_dict(opt_state)
    369             _optimizer_to_device(optimizer, self.root_device)
    370 

/opt/conda/lib/python3.7/site-packages/torch/optim/optimizer.py in load_state_dict(self, state_dict)
    242         param_groups = [
    243             update_group(g, ng) for g, ng in zip(groups, saved_groups)]
--> 244         self.__setstate__({'state': state, 'param_groups': param_groups})
    245 
    246     def zero_grad(self, set_to_none: bool = False):

/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/optim.py in __setstate__(self, state)
    131     def __setstate__(self, state: dict) -> None:
    132         super().__setstate__(state)
--> 133         self.radam_buffer = state["radam_buffer"]
    134         self.alpha = state["alpha"]
    135         self.k = state["k"]

KeyError: 'radam_buffer'

@carmocca
Contributor

@himanshudoi It looks like you are trying to reload an optimizer with a checkpoint that used a different optimizer class
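
A quick way to check is to inspect the optimizer state stored in the checkpoint; a rough sketch (the filename is illustrative, optimizer_states is the key visible in the traceback above):

```python
import torch

ckpt = torch.load("interrupted.ckpt", map_location="cpu")

# One entry per optimizer returned by configure_optimizers. If these state dicts
# were written by a different optimizer class than the one being restored (e.g.
# Ranger/RAdam vs. plain Adam), its __setstate__ may look for keys, such as
# 'radam_buffer' here, that the checkpoint never stored.
for opt_state in ckpt["optimizer_states"]:
    print(list(opt_state.keys()))
```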
