diff --git a/CHANGELOG.md b/CHANGELOG.md
index 04c25eefcad66..e23630076eec6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -281,6 +281,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+
+- Fixed attaching train and validation dataloaders when `reload_dataloaders_every_epoch=True` and `num_sanity_val_steps=0` ([#7207](https://github.com/PyTorchLightning/pytorch-lightning/pull/7207))
+
+
 - Added a barrier in the accelerator `teardown` to synchronize processes before execution finishes ([#6814](https://github.com/PyTorchLightning/pytorch-lightning/pull/6814))
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 8a007086fb380..cef70e2bf7811 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -188,11 +188,17 @@ def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx,
         # reset batch logger internals
         self.trainer.logger_connector.on_train_batch_end()
 
-    def reset_train_val_dataloaders(self, model):
-        if self.trainer.train_dataloader is None or not self.trainer.reload_dataloaders_every_epoch:
+    def reset_train_val_dataloaders(self, model) -> None:
+        """
+        Resets train and val dataloaders if none are attached to the trainer.
+
+        The val dataloader must be initialized before training loop starts, as the training loop
+        inspects the val dataloader to determine whether to run the evaluation loop.
+        """
+        if self.trainer.train_dataloader is None:
             self.trainer.reset_train_dataloader(model)
 
-        if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
+        if self.trainer.val_dataloaders is None:
             self.trainer.reset_val_dataloader(model)
 
     def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 2a744c9c05c73..a935fbd401e7e 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -25,6 +25,7 @@
 
 import tests.helpers.pipelines as tpipes
 from pytorch_lightning import Callback, seed_everything, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6
 from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
@@ -1199,7 +1200,16 @@ def test_dataloaders_load_every_epoch(tmpdir):
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 def test_dataloaders_load_every_epoch_no_sanity_check(tmpdir):
-    model = EvalModelTemplate()
+    class TestModel(BoringModel):
+
+        def validation_step(self, batch, batch_idx):
+            self.log("dummy_val", 5.0)
+            return super().validation_step(batch, batch_idx)
+
+    model = TestModel()
+
+    # This callback tests that the evaluation metrics are available by the time we run checkpointing
+    checkpoint_callback = ModelCheckpoint(monitor="dummy_val", save_top_k=1)
 
     # logger file to get meta
     trainer = Trainer(
@@ -1209,21 +1219,32 @@ def test_dataloaders_load_every_epoch_no_sanity_check(tmpdir):
         num_sanity_val_steps=0,
         reload_dataloaders_every_epoch=True,
         max_epochs=3,
+        callbacks=[checkpoint_callback],
     )
     trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
     trainer.test()
 
-    assert len(trainer.dev_debugger.val_dataloader_calls) == 3
+    assert len(trainer.dev_debugger.val_dataloader_calls) == 4
     assert len(trainer.dev_debugger.train_dataloader_calls) == 3
     assert len(trainer.dev_debugger.test_dataloader_calls) == 1
 
     # verify the sequence
     calls = trainer.dev_debugger.dataloader_sequence_calls
+
     expected_sequence = [
         'train_dataloader',
         'val_dataloader',
+        # This has subsequent calls to val_dataloader
+        # because the training loop runs the evaluation loop,
+        # which reloads the val dataloader again.
+        # We cannot yet rely on trainer.current_epoch=0 to skip reloading
+        # the val dataloader on the first epoch because this only tracks the training epoch
+        # meaning multiple passes through the validation data within a single training epoch
+        # would not have the dataloader reloaded.
+        # This breaks the assumption behind reload_dataloaders_every_epoch=True
+        'val_dataloader',
         'train_dataloader',
         'val_dataloader',
         'train_dataloader',
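For context beyond the patch itself, below is a minimal, self-contained sketch of the user-facing setup this fix targets: `reload_dataloaders_every_epoch=True` combined with `num_sanity_val_steps=0` and a `ModelCheckpoint` that monitors a validation metric. The `ToyModel` module, its random tensors, and the `val_metric` name are illustrative only, and the snippet assumes a PyTorch Lightning release from this era in which `reload_dataloaders_every_epoch` is still a `Trainer` argument.

```python
# Minimal sketch (not part of the PR) of the configuration exercised by the fix.
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class ToyModel(pl.LightningModule):
    """Hypothetical module that logs a validation metric for ModelCheckpoint to monitor."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, = batch
        return self(x).sum()

    def validation_step(self, batch, batch_idx):
        x, = batch
        # Logged per epoch; ModelCheckpoint looks this metric up when it runs.
        self.log("val_metric", self(x).sum())

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # Requested again every epoch because reload_dataloaders_every_epoch=True.
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    def val_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)


if __name__ == "__main__":
    trainer = pl.Trainer(
        max_epochs=2,
        num_sanity_val_steps=0,               # sanity check disabled, so it cannot attach the val dataloader
        reload_dataloaders_every_epoch=True,  # combination that previously left the val dataloader unattached
        callbacks=[ModelCheckpoint(monitor="val_metric", save_top_k=1)],
    )
    trainer.fit(ToyModel())  # with the fix, "val_metric" exists by the time checkpointing runs
```

With the patched `reset_train_val_dataloaders`, both dataloaders are reset whenever none are attached, independent of `reload_dataloaders_every_epoch`, so the training loop can inspect the val dataloader and the checkpoint callback can find the monitored metric even when the sanity check is skipped.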