Commit

enable on_load_checkpoint for datamodule for all trainer_fn (#10238)
rohitgr7 authored Nov 1, 2021
1 parent 45c45dc commit 6609b2e
Showing 3 changed files with 13 additions and 3 deletions.
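
For context on what this commit enables: a `LightningDataModule` can implement the checkpoint hooks, and with this change the load hook is also invoked when a checkpoint is restored for validation, testing, or prediction, not only for fitting. A minimal sketch of such a datamodule (illustrative names, not code from this commit):

import pytorch_lightning as pl


class StatefulDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.vocab = None  # state we want to survive a checkpoint round-trip

    def on_save_checkpoint(self, checkpoint):
        # persist datamodule state alongside the model weights
        checkpoint["vocab"] = self.vocab

    def on_load_checkpoint(self, checkpoint):
        # restore datamodule state; after this commit the hook also runs for
        # trainer.validate/test/predict, not only trainer.fit
        self.vocab = checkpoint["vocab"]
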
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -174,6 +174,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Quantization aware training observers are now disabled by default during validating/testing/predicting stages ([#8540](https://github.com/PyTorchLightning/pytorch-lightning/pull/8540))
 
 
+- Enabled `on_load_checkpoint` for `LightningDataModule` for all `trainer_fn` ([#10238](https://github.com/PyTorchLightning/pytorch-lightning/pull/10238))
+
+
 ### Deprecated
 
 - Deprecated Trainer argument `terminate_on_nan` in favor of `detect_anomaly`([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175))
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -1100,8 +1100,8 @@ def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None
         # restore modules after setup
         self.checkpoint_connector.resume_start(checkpoint_path)
         self.checkpoint_connector.restore_model()
+        self.checkpoint_connector.restore_datamodule()
         if self.state.fn == TrainerFn.FITTING:
-            self.checkpoint_connector.restore_datamodule()
             # restore callback states
             self.checkpoint_connector.restore_callbacks()

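Read together, the hunk moves `restore_datamodule()` out of the `TrainerFn.FITTING` guard, so the restore path after this change presumably reads roughly as follows (a reconstruction from the lines shown above, not the complete method; only the callback restore is assumed to stay behind the fitting check):

def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
    # restore modules after setup
    self.checkpoint_connector.resume_start(checkpoint_path)
    self.checkpoint_connector.restore_model()
    # datamodule state is now restored for every trainer_fn, not only fitting
    self.checkpoint_connector.restore_datamodule()
    if self.state.fn == TrainerFn.FITTING:
        # restore callback states
        self.checkpoint_connector.restore_callbacks()
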
11 changes: 9 additions & 2 deletions tests/core/test_datamodules.py
@@ -24,6 +24,7 @@
 
 from pytorch_lightning import LightningDataModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import AttributeDict
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -306,7 +307,7 @@ def test_train_val_loop_only(tmpdir):
     assert trainer.callback_metrics["train_loss"] < 1.0
 
 
-def test_dm_checkpoint_save(tmpdir):
+def test_dm_checkpoint_save_and_load(tmpdir):
     class CustomBoringModel(BoringModel):
         def validation_step(self, batch, batch_idx):
             out = super().validation_step(batch, batch_idx)
@@ -334,13 +335,19 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
     )
 
     # fit model
-    trainer.fit(model, dm)
+    trainer.fit(model, datamodule=dm)
     assert trainer.state.finished, f"Training failed with {trainer.state}"
     checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
     checkpoint = torch.load(checkpoint_path)
     assert dm.__class__.__name__ in checkpoint
     assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__
 
+    for trainer_fn in TrainerFn:
+        trainer.state.fn = trainer_fn
+        with mock.patch.object(dm, "on_load_checkpoint") as dm_mock:
+            trainer._restore_modules_and_callbacks(checkpoint_path)
+            dm_mock.assert_called_once()
+
 
 def test_full_loop(tmpdir):
     reset_seed()
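The new test lines iterate over every `TrainerFn` member, set it as the trainer's current state, and assert that the datamodule's `on_load_checkpoint` is called when `_restore_modules_and_callbacks` restores the checkpoint. In user code, the practical effect is that passing a checkpoint to the non-fit entry points now restores datamodule state as well; a hedged usage sketch (model, datamodule, and checkpoint path are placeholders defined elsewhere):

trainer = Trainer()
trainer.fit(model, datamodule=dm)

# After this commit, the datamodule's on_load_checkpoint hook also fires here:
trainer.validate(model, datamodule=dm, ckpt_path=checkpoint_path)
trainer.test(model, datamodule=dm, ckpt_path=checkpoint_path)
trainer.predict(model, datamodule=dm, ckpt_path=checkpoint_path)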
