Checkpointing is broken on TPUs #2700
Comments
@lezwon mind having a look?
sure :]
Thanks, @lezwon. You might want to check this fix here ibeltagy@a5c8d18, which works but I don't like it. I also tried calling the functions inside
@ibeltagy Nice work :] I'll check out your solution and try to come back with a fix for this.
@ibeltagy I am able to reload the checkpoint successfully; however, the training fails due to some XLA device issue. Is it the same error you face? Could you share a notebook reproducing this issue?
Loading the checkpoint only fails when loading without a TPU device available, as
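(Not a fix taken from this thread, just a hedged illustration of the usual workaround when a checkpoint must be loadable on a machine without a TPU: save a CPU copy of the weights; model is a placeholder name here.)

import torch

# Tensors serialized straight from the XLA device need torch_xla at load
# time; saving a CPU copy keeps the checkpoint loadable anywhere.
cpu_state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
torch.save(cpu_state_dict, 'checkpoint_cpu.ckpt')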
Is this still an issue? I am experiencing the exact same problem.
🐛 Bug
PyTorch/XLA saves checkpoints using the following syntax, which is not supported in pytorch-lightning.
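(The saving call referred to is presumably along these lines; a minimal sketch in which model and checkpoint_path are placeholder names.)

import torch_xla.core.xla_model as xm

# xm.save mirrors torch.save, but it synchronizes the TPU cores and
# moves tensors to CPU before writing the file to disk.
xm.save(model.state_dict(), checkpoint_path)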
It is a little tricky to support because xm.save() has a barrier inside it and it checks for rank=0, while torch.save doesn't. This means torch.save should be called only on the process with rank=0 (which pytorch-lightning does), but xm.save() should be called by all processes (or it will wait forever at the barrier). This means the pytorch-lightning code that checks for the rank (here) will need to be switched off on TPUs.
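(A hedged sketch of the constraint described above, not pytorch-lightning's actual implementation; save_checkpoint, is_tpu, and global_rank are hypothetical names.)

import torch
import torch_xla.core.xla_model as xm

def save_checkpoint(state_dict, path, is_tpu, global_rank):
    if is_tpu:
        # xm.save has a rendezvous (barrier) and its own rank-0 check inside,
        # so every TPU process must reach this call or training hangs.
        xm.save(state_dict, path)
    elif global_rank == 0:
        # torch.save has no barrier, so only the rank-0 process should write.
        torch.save(state_dict, path)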
To Reproduce
ptl.Trainer(checkpoint_callback=[ModelCheckpoint(...)], num_tpu_cores=8)
ptl.Trainer(resume_from_checkpoint='path_to_saved_checkpoint', num_tpu_cores=8)
Expected behavior
Loading the checkpoint successfully.
Environment
pytorch-lightning==v0.8.5
Additional Context
Thanks to @matt-peters for finding the bug and suggesting the solution mentioned above.