diff --git a/CHANGELOG.md b/CHANGELOG.md
index a32265dbe55f3..dfcacc2f62229 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -212,6 +212,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([#7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457))

+- Changed `Trainer` arg and functionality from `reload_dataloaders_every_epoch` to `reload_dataloaders_every_n_epochs` ([#5043](https://github.com/PyTorchLightning/pytorch-lightning/pull/5043))

 - Changed `WandbLogger(log_model={True/'all'})` to log models as artifacts ([#6231](https://github.com/PyTorchLightning/pytorch-lightning/pull/6231))

@@ -288,6 +289,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Deprecated `mode` parameter in `ModelSummary` in favor of `max_depth` ([#8062](https://github.com/PyTorchLightning/pytorch-lightning/pull/8062))

+- Deprecated `reload_dataloaders_every_epoch` argument of `Trainer` in favor of `reload_dataloaders_every_n_epochs` ([#5043](https://github.com/PyTorchLightning/pytorch-lightning/pull/5043))
+
+
 ### Removed

 - Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))

@@ -708,6 +712,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     [#6323](https://github.com/PyTorchLightning/pytorch-lightning/pull/6323),
     [#6211](https://github.com/PyTorchLightning/pytorch-lightning/pull/6211))

+
 ## [1.2.9] - 2021-04-20

 ### Fixed

@@ -752,8 +757,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Enforce an epoch scheduler interval when using SWA ([#6588](https://github.com/PyTorchLightning/pytorch-lightning/pull/6588))
 - Fixed TPU Colab hang issue, post training ([#6816](https://github.com/PyTorchLightning/pytorch-lightning/pull/6816))
 - Fixed a bug where `TensorBoardLogger` would give a warning and not log correctly to a symbolic link `save_dir` ([#6730](https://github.com/PyTorchLightning/pytorch-lightning/pull/6730))
-
-
 - Fixed bug where `predict` could not be used when `progress_bar_refresh_rate=0` ([#6884](https://github.com/PyTorchLightning/pytorch-lightning/pull/6884))
diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst
index 0983f0acb9eec..572ea5b4b4d09 100644
--- a/docs/source/common/trainer.rst
+++ b/docs/source/common/trainer.rst
@@ -1297,8 +1297,8 @@ Note:
     Lightning will set it to 20 in these environments if the user does not provide a value.
 - This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`.

-reload_dataloaders_every_epoch
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+reload_dataloaders_every_n_epochs
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 .. raw:: html

@@ -1308,19 +1308,20 @@ reload_dataloaders_every_epoch

 |

-Set to True to reload dataloaders every epoch.
+Set to a positive integer to reload dataloaders every n epochs.

 .. code-block:: python

-    # if False (default)
+    # if 0 (default)
     train_loader = model.train_dataloader()
     for epoch in epochs:
        for batch in train_loader:
            ...

-    # if True
+    # if a positive integer
     for epoch in epochs:
-        train_loader = model.train_dataloader()
+        if not epoch % reload_dataloaders_every_n_epochs:
+            train_loader = model.train_dataloader()
        for batch in train_loader:
            ...

 .. _replace-sampler-ddp:
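To make the documented behaviour above concrete, here is a minimal usage sketch of the renamed argument. It is illustrative only and not part of the patch; the `model` instance is assumed to be any `LightningModule`.

```python
import pytorch_lightning as pl

# Reload train/val dataloaders at the start of every 2nd epoch (epochs 2, 4, ...).
# The default of 0 keeps the old behaviour: dataloaders are loaded once and reused.
trainer = pl.Trainer(max_epochs=6, reload_dataloaders_every_n_epochs=2)
# trainer.fit(model)  # reloads happen during fit; `model` is any LightningModule
```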
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 50b058c3c24c2..e0b4f7c74e477 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -435,8 +435,9 @@ def train_dataloader(self) -> TRAIN_DATALOADERS:
         A collection of :class:`torch.utils.data.DataLoader` specifying training samples.
         In the case of multiple dataloaders, please see this :ref:`page `.

-        The dataloader you return will not be called every epoch unless you set
-        :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
+        The dataloader you return will not be reloaded unless you set
+        :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to
+        a positive integer.

         For data processing use the following pattern:
@@ -505,8 +506,9 @@ def test_dataloader(self) -> EVAL_DATALOADERS:
         r"""
         Implement one or multiple PyTorch DataLoaders for testing.

-        The dataloader you return will not be called every epoch unless you set
-        :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
+        The dataloader you return will not be reloaded unless you set
+        :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to
+        a positive integer.

         For data processing use the following pattern:
@@ -565,8 +567,9 @@ def val_dataloader(self) -> EVAL_DATALOADERS:
         r"""
         Implement one or multiple PyTorch DataLoaders for validation.

-        The dataloader you return will not be called every epoch unless you set
-        :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
+        The dataloader you return will not be reloaded unless you set
+        :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to
+        a positive integer.

         It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
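The docstring updates above describe when the dataloader hooks are re-invoked. As a sketch of a module that benefits from periodic reloading (an assumed example, not part of the diff), this `train_dataloader()` draws a fresh random dataset each time the Trainer re-calls it:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ResamplingModel(pl.LightningModule):
    """Illustrative module: returns a newly sampled dataset whenever reloaded."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def train_dataloader(self):
        # called once at the start of fit, and again on epochs where
        # `reload_dataloaders_every_n_epochs` triggers a reload
        data = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(data, batch_size=8)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```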
diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index 4ab939b2e97be..2f6e14b93b767 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -180,7 +180,7 @@ def reload_evaluation_dataloaders(self) -> None:
         model = self.trainer.lightning_module
         if self.trainer.testing:
             self.trainer.reset_test_dataloader(model)
-        elif self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch:
+        elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
             self.trainer.reset_val_dataloader(model)

     def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 555b9bf6b76a3..f2681398ff347 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -191,7 +191,7 @@ def on_advance_start(self) -> None:
         model = self.trainer.lightning_module

         # reset train dataloader
-        if self.current_epoch != 0 and self.trainer.reload_dataloaders_every_epoch:
+        if self.current_epoch != 0 and self.trainer._should_reload_dl_epoch:
             self.trainer.reset_train_dataloader(model)

         # TODO: specify the possible exception
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index b95db75bde8c3..92019edbeff56 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -16,6 +16,7 @@

 import pytorch_lightning as pl
 from pytorch_lightning.trainer.supporters import prefetch_iterator
+from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
@@ -28,7 +29,11 @@ def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_
         self.multiple_trainloader_mode = multiple_trainloader_mode

     def on_trainer_init(
-        self, check_val_every_n_epoch: int, reload_dataloaders_every_epoch: bool, prepare_data_per_node: bool
+        self,
+        check_val_every_n_epoch: int,
+        reload_dataloaders_every_n_epochs: int,
+        reload_dataloaders_every_epoch: bool,
+        prepare_data_per_node: bool,
     ) -> None:
         self.trainer.datamodule = None
         self.trainer.prepare_data_per_node = prepare_data_per_node
@@ -39,7 +44,21 @@ def on_trainer_init(
         )

         self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
-        self.trainer.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
+
+        if reload_dataloaders_every_epoch:
+            reload_dataloaders_every_n_epochs = int(reload_dataloaders_every_epoch)
+            rank_zero_deprecation(
+                "`reload_dataloaders_every_epoch` is deprecated in v1.4 and will be removed in v1.6."
+                " Please use `reload_dataloaders_every_n_epochs` in Trainer."
+            )
+
+        if not isinstance(reload_dataloaders_every_n_epochs, int) or (reload_dataloaders_every_n_epochs < 0):
+            raise MisconfigurationException(
+                "`reload_dataloaders_every_n_epochs` should be an int >= 0,"
+                f" got {reload_dataloaders_every_n_epochs}."
+            )
+
+        self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
         self.trainer._is_data_prepared = False

     def get_profiled_train_dataloader(self, train_dataloader):
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index a519b82234a43..54d0079b9255e 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -59,6 +59,7 @@ class TrainerProperties(ABC):
     accelerator_connector: AcceleratorConnector
     callbacks: List[Callback]
     checkpoint_connector: CheckpointConnector
+    reload_dataloaders_every_n_epochs: int
     limit_val_batches: int
     logger: LightningLoggerBase
     logger_connector: LoggerConnector
@@ -293,6 +294,12 @@ def progress_bar_dict(self) -> dict:
         )
         return {**standard_metrics, **pbar_metrics}

+    @property
+    def _should_reload_dl_epoch(self) -> bool:
+        """ Check if dataloader should be reloaded in the current epoch. """
+        n_epochs = self.reload_dataloaders_every_n_epochs
+        return n_epochs and (not self.current_epoch % n_epochs)
+
     @property
     def disable_validation(self) -> bool:
         """ Check if validation is disabled during training. """
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b984608c87d6d..b09249842b3d2 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -144,6 +144,7 @@ def __init__(
         profiler: Optional[Union[BaseProfiler, str]] = None,
         benchmark: bool = False,
         deterministic: bool = False,
+        reload_dataloaders_every_n_epochs: int = 0,
         reload_dataloaders_every_epoch: bool = False,
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
@@ -272,8 +273,15 @@ def __init__(
             num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
                 Set it to `-1` to run all batches in all validation dataloaders.

+            reload_dataloaders_every_n_epochs: Set to a non-negative integer to reload dataloaders every n epochs.
+                Default: 0
+
             reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.

+                .. deprecated:: v1.4
+                    ``reload_dataloaders_every_epoch`` has been deprecated in v1.4 and will be removed in v1.6.
+                    Please use ``reload_dataloaders_every_n_epochs``.
+
             replace_sampler_ddp: Explicitly enables or disables sampler replacement. If not specified this
                 will toggled automatically when DDP is used. By default it will add ``shuffle=True`` for
                 train sampler and ``shuffle=False`` for val/test sampler.
                If you want to customize it,
@@ -382,7 +390,8 @@ def __init__(

         # init data flags
         self.data_connector.on_trainer_init(
-            check_val_every_n_epoch, reload_dataloaders_every_epoch, prepare_data_per_node
+            check_val_every_n_epoch, reload_dataloaders_every_n_epochs, reload_dataloaders_every_epoch,
+            prepare_data_per_node
         )

         # init training tricks
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 30131cdcc80d2..c056ab1aa4fbf 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -455,9 +455,11 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx):
     assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1, dtype=torch.long) * 2)


-def test_dm_reload_dataloaders_every_epoch(tmpdir):
-    """Test datamodule, where trainer argument
-    reload_dataloaders_every_epoch is set to True/False"""
+def test_dm_reload_dataloaders_every_n_epochs(tmpdir):
+    """
+    Test datamodule, where trainer argument
+    reload_dataloaders_every_n_epochs is set to a non-negative integer
+    """

     class CustomBoringDataModule(BoringDataModule):

@@ -482,9 +484,9 @@ def train_dataloader(self):

     trainer = Trainer(
         default_root_dir=tmpdir,
-        max_epochs=2,
-        limit_train_batches=0.01,
-        reload_dataloaders_every_epoch=True,
+        max_epochs=3,
+        limit_train_batches=2,
+        reload_dataloaders_every_n_epochs=2,
     )
     trainer.fit(model, dm)
diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py
index ac5b56ea00086..69d2a45530607 100644
--- a/tests/deprecated_api/test_remove_1-6.py
+++ b/tests/deprecated_api/test_remove_1-6.py
@@ -88,6 +88,28 @@ def test_v1_6_0_ddp_spawn_sync_batchnorm():
         DDPSpawnPlugin(sync_batchnorm=False)


+def test_v1_6_0_reload_dataloaders_every_epoch(tmpdir):
+
+    model = BoringModel()
+
+    with pytest.deprecated_call(match='`reload_dataloaders_every_epoch` is deprecated in v1.4 and will be removed'):
+        trainer = Trainer(
+            default_root_dir=tmpdir,
+            limit_train_batches=0.3,
+            limit_val_batches=0.3,
+            reload_dataloaders_every_epoch=True,
+            max_epochs=3,
+        )
+    trainer.fit(model)
+    trainer.test()
+
+    # verify the sequence
+    calls = trainer.dev_debugger.dataloader_sequence_calls
+    expected_sequence = ['val_dataloader'] + ['train_dataloader', 'val_dataloader'] * 3 + ['test_dataloader']
+    for call, expected in zip(calls, expected_sequence):
+        assert call['name'] == expected
+
+
 def test_v1_6_0_tbptt_reduce_fx(tmpdir):

     class TestModel(BoringModel):
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 9e91c8dfe6e53..c4044935f4bd3 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -1304,7 +1304,7 @@ def test_dataloaders_load_only_once_val_interval(tmpdir):
         limit_train_batches=10,
         limit_val_batches=10,
         val_check_interval=0.3,
-        reload_dataloaders_every_epoch=True,
+        reload_dataloaders_every_n_epochs=1,
         max_epochs=3,
     )
     trainer.fit(model)
@@ -1368,17 +1368,16 @@ def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
     assert call['name'] == expected


-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-def test_dataloaders_load_every_epoch(tmpdir):
+@pytest.mark.parametrize("n", [1, 2])
+def test_dataloaders_load_every_n_epochs(tmpdir, n):

-    model = EvalModelTemplate()
+    model = BoringModel()

-    # logger file to get meta
     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_train_batches=0.3,
         limit_val_batches=0.3,
-        reload_dataloaders_every_epoch=True,
+        reload_dataloaders_every_n_epochs=n,
         max_epochs=3,
     )
     trainer.fit(model)
@@ -1386,26 +1385,29 @@ def test_dataloaders_load_every_epoch(tmpdir):
     trainer.test()

-    assert len(trainer.dev_debugger.val_dataloader_calls) == 4
-    assert len(trainer.dev_debugger.train_dataloader_calls) == 3
-    assert len(trainer.dev_debugger.test_dataloader_calls) == 1
-
     # verify the sequence
     calls = trainer.dev_debugger.dataloader_sequence_calls
-    expected_sequence = [
-        'val_dataloader',
-        'train_dataloader',
-        'val_dataloader',
-        'train_dataloader',
-        'val_dataloader',
-        'train_dataloader',
-        'val_dataloader',
-        'test_dataloader',
-    ]
+    expected_sequence = ['val_dataloader']
+    if n == 1:
+        expected_sequence += ['train_dataloader', 'val_dataloader'] * 3
+    elif n == 2:
+        expected_sequence += ['train_dataloader', 'val_dataloader'] * 2
+    expected_sequence += ['test_dataloader']
+
     for call, expected in zip(calls, expected_sequence):
         assert call['name'] == expected


+@pytest.mark.parametrize("n", ['test', -1])
+def test_dataloaders_load_every_n_epochs_exception(tmpdir, n):
+
+    with pytest.raises(MisconfigurationException, match='should be an int >'):
+        Trainer(
+            default_root_dir=tmpdir,
+            reload_dataloaders_every_n_epochs=n,
+        )
+
+
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 def test_dataloaders_load_every_epoch_no_sanity_check(tmpdir):

@@ -1426,7 +1428,7 @@ def validation_step(self, batch, batch_idx):
         limit_train_batches=0.3,
         limit_val_batches=0.3,
         num_sanity_val_steps=0,
-        reload_dataloaders_every_epoch=True,
+        reload_dataloaders_every_n_epochs=1,
         max_epochs=3,
         callbacks=[checkpoint_callback],
     )
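For reference, a standalone restatement of the reload condition added in `properties.py` and `fit_loop.py`, shown only to make the epoch arithmetic behind the test expectations explicit (the helper name is hypothetical, not part of the patch):

```python
def should_reload(current_epoch: int, n_epochs: int) -> bool:
    # mirrors Trainer._should_reload_dl_epoch combined with the
    # `current_epoch != 0` guard in FitLoop.on_advance_start:
    # reload on every n-th epoch after the initial load
    return bool(n_epochs) and current_epoch != 0 and current_epoch % n_epochs == 0


# with max_epochs=3 and n=1 the train dataloader is rebuilt on epochs 1 and 2,
# which together with the initial load gives the 3 'train_dataloader' calls
# asserted in test_dataloaders_load_every_n_epochs above
print([epoch for epoch in range(3) if should_reload(epoch, 1)])  # -> [1, 2]
```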