Resume training from the last checkpoint #5325
Comments
Hi! Thanks for your contribution, great first issue! |
At least, the issue may be in the training loop. |
@Borda any updates? |
@sourabh-nutonomy Regarding your scenario, what do you expect to happen: continue at the same state (global step) where it was interrupted, e.g. at 20% of the first epoch, or train the first epoch again? |
@Borda Yes, I understand that we cannot shuffle the data the same way without seeding the training. But I think that can be left to the users to figure out. They can either seed their training or they can ignore it if they want to. But I am expecting that the training restarts at the same global step where the interruption happened (so after 20% of the first epoch in the example). I can also describe my use case, which explains the need to start at the same global step. We currently use AWS spot instances to train our nets. They can go down at any time, so we need to resume training from the same global step when the experiment restarts. If we start from the beginning of the epoch, in the worst case we can get stuck in an infinite loop where we keep retraining the same epoch (if the spot instances keep restarting). If we start from the next epoch (like what is happening currently), we would miss lots of gradient updates; we are missing 80% of the steps in epoch 1 in the example we are discussing. If the experiment restarts multiple times, it will further exacerbate the problem. Even when we run our experiments on our in-house cluster and not spot instances, our datasets are huge, so if we retrain the first epoch, we lose a few hours. |
I see it would be great to allow it, but I would rather classify it as a feature instead of a bug, because this is the expected behavior... |
@Borda Can you confirm this will be part of Lightning 1.2? Any timeline on that? Not having this feature is blocking us from adopting Lightning. |
Unfortunately, it won't be available in 1.2 but we are prioritizing this feature for our next release!! |
I got into this issue from the one requesting the ability to save every n steps (link). Restricting checkpointing to once per epoch is hugely inconvenient, especially when the dataset is large, which is the main reason I chose pytorch-lightning to organise the experiments in a more systematic way. |
I think #6429 is a requirement for this. @carmocca @awaelchli @justusschock I'll update that issue with a more concrete proposal |
What if you seed everything with a random number under the hood and then save this seed in the module's checkpoint? Maybe that could keep runs stochastic but reproducible. |
Sure, this is one of the things required to support this, but not the only one! |
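For illustration, here is a minimal sketch of the seed-in-checkpoint idea floated above, written as if done manually in a LightningModule. The on_save_checkpoint/on_load_checkpoint hooks and pl.seed_everything are real Lightning APIs; treating a re-applied seed as sufficient for mid-epoch reproducibility is an assumption, and, as the reply notes, only one of the pieces required.

```python
import random

import pytorch_lightning as pl
from torch import nn


class SeededModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)
        # Draw a random seed once and apply it, so the run stays stochastic
        # across experiments but reproducible within a single run.
        self.run_seed = random.randint(0, 2**31 - 1)
        pl.seed_everything(self.run_seed)

    def on_save_checkpoint(self, checkpoint):
        # Store the seed alongside the regular training state.
        checkpoint["run_seed"] = self.run_seed

    def on_load_checkpoint(self, checkpoint):
        # Re-apply the same seed when resuming from this checkpoint.
        self.run_seed = checkpoint["run_seed"]
        pl.seed_everything(self.run_seed)
```

Even with the seed restored, the dataloader's position within the interrupted epoch is not recovered, which is why fault-tolerant training needs more than this.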
Any news on this one? |
Still work in progress 📈 |
I have also just run into this issue. Hopefully support for it in the next release. 👍 |
Is resume training possible with the current version using the iteration-based method? Is there a way to get approximately the same results in the current version (assuming the seed is different and the run is not exactly reproduced)? Personally, the priority of this issue is quite high. |
@Keiku There are several open PRs related to this. So far this is not yet supported, but we are actively working on it. |
Probably due to the seed problem, the validation accuracies differ, but it seems that resuming training was possible.
Clean run (20 epochs of CIFAR10 training): validation acc 85.18.
Interrupted run (interrupted in the middle of 10 epochs of CIFAR10 training): validation acc 84.9.
I'm using |
Hi @Keiku, can you kindly share how you resume training? Directly using |
@Wuziyi616 Please refer to my implementation of cifar10. Keiku/PyTorch-Lightning-CIFAR10: "Not too complicated" training code for CIFAR-10 by PyTorch Lightning https://github.com/Keiku/PyTorch-Lightning-CIFAR10 |
Great, thanks! So far I think the major modification you did is adding the Anyway, thanks for your help! |
@Wuziyi616 I have the same understanding as you in terms of this problem. You may consider saving checkpoints more frequently, e.g. every few iterations rather than only once per epoch. I am also waiting for the fault-tolerant training feature of pytorch-lightning. |
Yes, I completely agree with you. Thanks a lot! |
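For reference, a sketch of the step-based checkpointing workaround suggested above. The directory and interval are hypothetical; `every_n_train_steps` and `save_last` are real `ModelCheckpoint` options, and `ckpt_path` is the `Trainer.fit` argument used to resume.

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Checkpoint every N optimizer steps instead of only once per epoch, and keep
# an always-up-to-date "last.ckpt".
checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/",      # hypothetical output directory
    save_last=True,
    every_n_train_steps=500,     # hypothetical interval
)

trainer = pl.Trainer(max_epochs=20, callbacks=[checkpoint_cb])
# trainer.fit(model, train_dataloaders=train_loader)

# After an interruption, resume from the most recent checkpoint. Without
# fault-tolerant training this still restarts at an epoch boundary, but the
# frequent checkpoints bound how much optimizer/weight progress is lost.
# trainer.fit(model, train_dataloaders=train_loader,
#             ckpt_path="checkpoints/last.ckpt")
```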
I think this feature is vital for fault tolerant training. Where is this currently at? |
@AntreasAntoniou See this comment. #5325 (comment) |
The docs for fault-tolerance are here: And there's a proposal to simplify the checkpoint management: #11912 |
Hi @carmocca , I looked at the link you posted. I can't find an example. How can I resume the training after a failure? |
You want to set |
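A hedged sketch of what this usually looks like, assuming the reply above refers to the experimental environment variable described in the fault-tolerance docs linked earlier (the directory is a placeholder):

```python
import os

# Assumption: the experimental fault-tolerant mode is enabled via this
# environment variable, per the fault-tolerance docs referenced above. It has
# to be set before the Trainer is created.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=20, default_root_dir="runs/")  # hypothetical dir
# trainer.fit(model, train_dataloaders=train_loader)

# If the process dies mid-epoch, relaunching the same script is meant to pick
# the run back up from the automatically saved mid-epoch state.
```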
Hi, I'm still facing this with pytorch-forecasting==0.10.1. Error trace:
KeyError Traceback (most recent call last)
/tmp/ipykernel_10535/1750406609.py in <module>
4 train_dataloaders=train_dataloader,
5 val_dataloaders=val_dataloader,
----> 6 ckpt_path=best_model_path,
7 )
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
581 self.strategy._lightning_module = model
582 call._call_and_handle_interrupt(
--> 583 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
584 )
585
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/call.py in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
36 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
37 else:
---> 38 return trainer_fn(*args, **kwargs)
39
40 except _TunerExitException:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
622 model_connected=self.lightning_module is not None,
623 )
--> 624 self._run(model, ckpt_path=self.ckpt_path)
625
626 assert self.state.stopped
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
1055 # restore optimizers, etc.
1056 log.detail(f"{self.__class__.__name__}: restoring training state")
-> 1057 self._checkpoint_connector.restore_training_state()
1058
1059 self._checkpoint_connector.resume_end()
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py in restore_training_state(self)
294 if self.trainer.state.fn == TrainerFn.FITTING:
295 # restore optimizers and schedulers state
--> 296 self.restore_optimizers_and_schedulers()
297
298 def restore_precision_plugin_state(self) -> None:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py in restore_optimizers_and_schedulers(self)
409 " This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`."
410 )
--> 411 self.restore_optimizers()
412
413 if "lr_schedulers" not in self._loaded_checkpoint:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py in restore_optimizers(self)
424
425 # restore the optimizers
--> 426 self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint)
427
428 def restore_lr_schedulers(self) -> None:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/strategies/strategy.py in load_optimizer_state_dict(self, checkpoint)
366 optimizer_states = checkpoint["optimizer_states"]
367 for optimizer, opt_state in zip(self.optimizers, optimizer_states):
--> 368 optimizer.load_state_dict(opt_state)
369 _optimizer_to_device(optimizer, self.root_device)
370
/opt/conda/lib/python3.7/site-packages/torch/optim/optimizer.py in load_state_dict(self, state_dict)
242 param_groups = [
243 update_group(g, ng) for g, ng in zip(groups, saved_groups)]
--> 244 self.__setstate__({'state': state, 'param_groups': param_groups})
245
246 def zero_grad(self, set_to_none: bool = False):
/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/optim.py in __setstate__(self, state)
131 def __setstate__(self, state: dict) -> None:
132 super().__setstate__(state)
--> 133 self.radam_buffer = state["radam_buffer"]
134 self.alpha = state["alpha"]
135 self.k = state["k"]
KeyError: 'radam_buffer' |
@himanshudoi It looks like you are trying to reload an optimizer with a checkpoint that used a different optimizer class |
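One quick way to check that diagnosis, assuming the checkpoint is a regular Lightning .ckpt file (the path below is hypothetical): load it with torch.load and inspect the saved optimizer state, whose hyperparameter keys reflect the optimizer class that wrote it.

```python
import torch

# Hypothetical path to the checkpoint passed as ckpt_path above.
ckpt = torch.load("checkpoints/best.ckpt", map_location="cpu")

for i, opt_state in enumerate(ckpt.get("optimizer_states", [])):
    # Each entry follows torch's optimizer state_dict format; the param_group
    # keys (lr, betas, alpha, k, ...) hint at which optimizer class saved it.
    print(f"optimizer {i}:", sorted(opt_state["param_groups"][0].keys()))
```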
🐛 Bug
What am I trying to do?
Resume training from the last checkpoint: checkpoints are saved with save_last=True and the last checkpoint is passed to the resume_from_checkpoint argument of the Trainer.
Expected Behavior
The Trainer is supposed to restore the entire training state (epoch, step, LR scheduler) according to the documentation here.
Actual Behavior
If the training is interrupted during an epoch, the ModelCheckpoint callback correctly saves the model and the training state. However, when we resume training, the training actually starts from the next epoch. So let's say we interrupted training when 20% of the first epoch had finished. When we resume training, the trainer actually starts from the second epoch, thereby skipping 80% of the steps of the first epoch.
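For context, a minimal sketch of the setup described above (model, paths, and epoch counts are placeholders; `save_last` and `resume_from_checkpoint` are the options named in this report):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(dirpath="checkpoints/", save_last=True)

# First run, interrupted partway through the first epoch:
trainer = pl.Trainer(max_epochs=5, callbacks=[checkpoint_cb])
# trainer.fit(model, train_dataloader)

# Resumed run: weights, optimizer state, and epoch counter are restored, but
# training continues from the next epoch rather than the interrupted step.
trainer = pl.Trainer(
    max_epochs=5,
    callbacks=[checkpoint_cb],
    resume_from_checkpoint="checkpoints/last.ckpt",
)
# trainer.fit(model, train_dataloader)
```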
Please reproduce using the BoringModel
https://colab.research.google.com/drive/1f32UPGA8eINrz17ykt9krCfjllI3krOI?usp=sharing
To Reproduce
Please follow the steps in the BoringModel above. I am not sure if the tensorboard snippets will be visible so I have taken screenshots here:
Clean Run
Interrupted Run
Environment
cc @Borda @carmocca @justusschock @awaelchli @ninginthecloud