-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Stop loading a few properties if checkpoint's dirpath
has changed
#12045
Stop loading a few properties if checkpoint's dirpath
has changed
#12045
Conversation
Co-authored-by: Rohit Gupta <[email protected]>
Here is a full example of what I mean in my comments above.
import os
import shutil
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
class RandomDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return self.len
class BoringModel(LightningModule):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(32, 2)
def forward(self, x):
return self.layer(x)
def training_step(self, batch, batch_idx):
loss = self(batch).sum()
self.log("latest_is_best", self.global_step)
return {"loss": loss}
def configure_optimizers(self):
return torch.optim.SGD(self.layer.parameters(), lr=0.1)
def run():
train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
shutil.rmtree("./first", ignore_errors=True)
shutil.rmtree("./after-reload", ignore_errors=True)
model = BoringModel()
checkpoint = ModelCheckpoint(
dirpath="./first", monitor="latest_is_best", mode="max", save_top_k=3, save_last=True,
every_n_train_steps=1,
filename='{epoch}-{step}-{latest_is_best}',
)
trainer = Trainer(
default_root_dir=os.getcwd(),
limit_train_batches=5,
max_epochs=1,
enable_model_summary=False,
callbacks=[checkpoint],
)
trainer.fit(model, train_dataloaders=train_data)
# NOTE: last exists here, but not later on.
assert os.path.isfile("./first/last.ckpt")
# ------------------------
# RESUME WITH CHANGED PATH
checkpoint = ModelCheckpoint(
dirpath="./after-reload", monitor="latest_is_best", mode="max", save_top_k=3, save_last=True,
every_n_train_steps=1,
filename='{epoch}-{step}-{latest_is_best}',
)
trainer = Trainer(
default_root_dir=os.getcwd(),
limit_train_batches=5,
max_epochs=2, # MORE EPOCHS
enable_model_summary=False,
callbacks=[checkpoint],
)
trainer.fit(model, train_dataloaders=train_data, ckpt_path="./first/epoch=0-step=3-latest_is_best=3.0.ckpt")
# OBSERVE: last.ckpt has disappeared from the first folder
# BUG?
assert not os.path.isfile("./first/last.ckpt")
if __name__ == "__main__":
run() |
@awaelchli - Thank you so much for taking a look. You are right, the words were misleading, and they should be fixed now. Regarding the bug you mentioned, that's legit, and I'll create a follow-up PR to fix it. This also involves re-thinking if |
dirpath
has changeddirpath
has changed
Codecov Report
@@ Coverage Diff @@
## master #12045 +/- ##
========================================
- Coverage 92% 88% -4%
========================================
Files 205 205
Lines 17440 17472 +32
========================================
- Hits 15980 15336 -644
- Misses 1460 2136 +676 |
What does this PR do?
dirpath
is provided, current state ofon_load_checkpoint
will track all the properties (likebest_model_score
,kth_best_model_path
, etc.). We should only tracklast_model_path
andbest_model_path
and stop trackingbest_model_score
,kth_best_model_path
,kth_value
,best_k_models
if the checkpoint path has changed. Please see issue Resuming training with ModelCheckpoint can delete checkpoints in other runs #11379 and comment by @rohitgr7 here: Resuming training with ModelCheckpoint can delete checkpoints in other runs #11379 (comment).__resolve_ckpt_dir
has been moved fromon_pretrain_routine_start
to_setup
function. Please see the discussion here for more context.Fixes #11379
Does your PR introduce any breaking changes? If yes, please list them.
Yes, if the checkpoint has changed on resumed training, the following properties won't be reloaded:
best_model_score
kth_best_model_path
kth_value
best_k_models
Only
last_model_path
andbest_model_path
will be tracked.Additionally, a warning will be raised if the checkpoint path has changed.
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:
Did you have fun?
Make sure you had fun coding 🙃