Commit f2693f9
enable on_load_checkpoint for datamodule
rohitgr7 committed Oct 29, 2021
1 parent a5235d5 commit f2693f9
Showing 2 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -1014,8 +1014,8 @@ def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None
         # restore modules after setup
         self.checkpoint_connector.resume_start(checkpoint_path)
         self.checkpoint_connector.restore_model()
+        self.checkpoint_connector.restore_datamodule()
         if self.state.fn == TrainerFn.FITTING:
-            self.checkpoint_connector.restore_datamodule()
             # restore callback states
             self.checkpoint_connector.restore_callbacks()

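Before this change, restore_datamodule() ran only when trainer.state.fn == TrainerFn.FITTING, so a datamodule's on_load_checkpoint hook was skipped whenever a checkpoint was loaded for validation, testing, or prediction. Hoisting the call out of the if block restores datamodule state on every trainer entry point. A minimal sketch of the hook pair this enables (the class name and split_seed state are illustrative, not taken from this commit):

    from typing import Any, Dict

    from pytorch_lightning import LightningDataModule


    class SeededSplitDataModule(LightningDataModule):
        """Hypothetical datamodule whose split seed should survive a resume."""

        def __init__(self, split_seed: int = 42):
            super().__init__()
            self.split_seed = split_seed

        def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            # Stash datamodule state in the checkpoint dict alongside the model weights.
            checkpoint["split_seed"] = self.split_seed

        def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            # After this commit, this hook also fires when a checkpoint is loaded
            # for validate/test/predict, not only while fitting.
            self.split_seed = checkpoint["split_seed"]
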
11 changes: 9 additions & 2 deletions tests/core/test_datamodules.py
@@ -25,6 +25,7 @@

 from pytorch_lightning import LightningDataModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities import AttributeDict
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -308,7 +309,7 @@ def test_train_val_loop_only(tmpdir):
     assert trainer.callback_metrics["train_loss"] < 1.0


-def test_dm_checkpoint_save(tmpdir):
+def test_dm_checkpoint_save_and_load(tmpdir):
     class CustomBoringModel(BoringModel):
         def validation_step(self, batch, batch_idx):
             out = super().validation_step(batch, batch_idx)
@@ -336,13 +337,19 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
     )

     # fit model
-    trainer.fit(model, dm)
+    trainer.fit(model, datamodule=dm)
     assert trainer.state.finished, f"Training failed with {trainer.state}"
     checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
     checkpoint = torch.load(checkpoint_path)
     assert dm.__class__.__name__ in checkpoint
     assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__

+    for trainer_fn in TrainerFn:
+        trainer.state.fn = trainer_fn
+        with mock.patch.object(dm, "on_load_checkpoint") as dm_mock:
+            trainer._restore_modules_and_callbacks(checkpoint_path)
+            dm_mock.assert_called_once()
+

 def test_full_loop(tmpdir):
     reset_seed()
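The new test loop assigns every TrainerFn value to trainer.state.fn and asserts that the datamodule's on_load_checkpoint mock fires exactly once per call to _restore_modules_and_callbacks, i.e. for every trainer function rather than only fitting. From the user's side, the same behavior surfaces when resuming an evaluation-only run; a hedged usage sketch (SeededSplitDataModule is the hypothetical module above, and the ckpt_path="best" argument is an assumption about the Trainer API of this era, not part of the commit):

    from pytorch_lightning import Trainer

    model = BoringModel()  # stand-in model from the test helpers
    dm = SeededSplitDataModule()
    trainer = Trainer(max_epochs=1)

    # Fit once so the default ModelCheckpoint callback writes a checkpoint.
    trainer.fit(model, datamodule=dm)

    # Resuming for validation now restores datamodule state as well:
    # dm.on_load_checkpoint receives the checkpoint before evaluation starts.
    trainer.validate(model, datamodule=dm, ckpt_path="best")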
