diff --git a/CHANGELOG.md b/CHANGELOG.md
index 642e8dd25c436..40f7690e5b352 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -363,6 +363,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Disable quantization aware training observers by default during validating/testing/predicting stages ([#8540](https://github.com/PyTorchLightning/pytorch-lightning/pull/8540))
 
+- Enabled `on_load_checkpoint` for `LightningDataModule` for all `trainer_fn` ([#10238](https://github.com/PyTorchLightning/pytorch-lightning/pull/10238))
+
+
 ### Deprecated
 
 - Deprecated trainer argument `terminate_on_nan` in favour of `detect_anomaly`([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175))
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 74522424c5326..96d2fb58de905 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1043,8 +1043,8 @@ def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
         # restore modules after setup
         self.checkpoint_connector.resume_start(checkpoint_path)
         self.checkpoint_connector.restore_model()
+        self.checkpoint_connector.restore_datamodule()
         if self.state.fn == TrainerFn.FITTING:
-            self.checkpoint_connector.restore_datamodule()
             # restore callback states
             self.checkpoint_connector.restore_callbacks()
 
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 539767ac2d686..51b51bfbd011a 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -24,6 +24,7 @@
 
 from pytorch_lightning import LightningDataModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import AttributeDict
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -306,7 +307,7 @@ def test_train_val_loop_only(tmpdir):
     assert trainer.callback_metrics["train_loss"] < 1.0
 
 
-def test_dm_checkpoint_save(tmpdir):
+def test_dm_checkpoint_save_and_load(tmpdir):
     class CustomBoringModel(BoringModel):
         def validation_step(self, batch, batch_idx):
             out = super().validation_step(batch, batch_idx)
@@ -334,13 +335,19 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
     )
 
     # fit model
-    trainer.fit(model, dm)
+    trainer.fit(model, datamodule=dm)
     assert trainer.state.finished, f"Training failed with {trainer.state}"
     checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
     checkpoint = torch.load(checkpoint_path)
     assert dm.__class__.__name__ in checkpoint
     assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__
 
+    for trainer_fn in TrainerFn:
+        trainer.state.fn = trainer_fn
+        with mock.patch.object(dm, "on_load_checkpoint") as dm_mock:
+            trainer._restore_modules_and_callbacks(checkpoint_path)
+            dm_mock.assert_called_once()
+
 
 def test_full_loop(tmpdir):
     reset_seed()
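
For reference, here is a minimal sketch of the behavior this PR enables. The `NormalizingDataModule` class and the `normalizer_mean` checkpoint key below are hypothetical, invented for illustration; the point is that a `LightningDataModule` which persists state through the checkpoint hooks now has `on_load_checkpoint` invoked when a checkpoint is restored for validating, testing, and predicting, not only while fitting.

```python
from typing import Any, Dict

from pytorch_lightning import LightningDataModule


class NormalizingDataModule(LightningDataModule):
    """Hypothetical datamodule that keeps preprocessing state in the checkpoint."""

    def __init__(self) -> None:
        super().__init__()
        self.mean = 0.0  # computed during fit, needed again at eval time

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # persist datamodule state alongside the model weights
        checkpoint["normalizer_mean"] = self.mean

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # restore datamodule state from the checkpoint
        self.mean = checkpoint["normalizer_mean"]
```

Previously, `Trainer._restore_modules_and_callbacks` called `restore_datamodule()` only when `trainer.state.fn == TrainerFn.FITTING`, so a call such as `trainer.test(model, datamodule=dm, ckpt_path=checkpoint_path)` would skip `on_load_checkpoint` and leave `mean` at its default. Moving `restore_datamodule()` above the `FITTING` check makes the hook fire for every `TrainerFn`, which is what the loop added to `test_dm_checkpoint_save_and_load` asserts.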