From 2d9f6507d6e46a1f94b7cee865f4e66ed0a80d65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 22 Jun 2021 12:07:53 +0200 Subject: [PATCH 001/157] rename training_loop -> epoch_Loop --- pytorch_lightning/core/lightning.py | 4 +- pytorch_lightning/core/optimizer.py | 2 +- pytorch_lightning/loops/fit_loop.py | 45 ++++++++++--------- tests/deprecated_api/test_remove_1-5.py | 2 +- .../loops/test_evaluation_loop_flow.py | 8 ++-- .../loops/test_training_loop_flow_scalar.py | 12 ++--- tests/trainer/test_trainer.py | 16 +++---- 7 files changed, 45 insertions(+), 44 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index fb0b19899561d..e7c9852968b36 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1371,7 +1371,7 @@ def training_step(...): # backward self._running_manual_backward = True - self.trainer.fit_loop.training_loop.batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs) + self.trainer.fit_loop.epoch_loop.batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs) self._running_manual_backward = False def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None: @@ -1471,7 +1471,7 @@ def optimizer_step( If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter to ``optimizer.step()`` function as shown in the examples. This ensures that ``training_step()``, ``optimizer.zero_grad()``, ``backward()`` are called within - :meth:`~pytorch_lightning.trainer.fit_loop.training_loop.batch_loop.TrainingBatchLoop.advance`. + :meth:`~pytorch_lightning.loops.training_batch_loop.TrainingBatchLoop.advance`. Args: epoch: Current epoch diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 1da8a7af36221..3572a79b9bd84 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -120,7 +120,7 @@ def toggle_model(self, sync_grad: bool = True): during the accumulation phase. Setting `sync_grad` to False will block this synchronization and improve performance. 
""" - with self._trainer.fit_loop.training_loop.batch_loop.block_ddp_sync_behaviour(not sync_grad): + with self._trainer.fit_loop.epoch_loop.batch_loop.block_ddp_sync_behaviour(not sync_grad): self._toggle_model() yield self._untoggle_model() diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index c7c2585feb129..a0f9aed4068d3 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -51,8 +51,9 @@ def __init__( super().__init__() self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs - self.training_loop = TrainingEpochLoop(min_steps, max_steps) + self.epoch_loop = TrainingEpochLoop(min_steps, max_steps) self.validation_loop = EvaluationDataLoaderLoop() + self.results = ResultCollection(training=True) @property def results(self) -> ResultCollection: @@ -75,59 +76,59 @@ def current_epoch(self, value: int) -> None: @property def global_step(self) -> int: """Returns the global step""" - return self.training_loop.global_step + return self.epoch_loop.global_step @global_step.setter def global_step(self, value: int) -> None: - """Sets the global step (forwards to training_loop)""" - self.training_loop.global_step = value + """Sets the global step (forwards to epoch_loop)""" + self.epoch_loop.global_step = value @property def total_batch_idx(self) -> int: """Returns the total number of batches already run (across all epochs)""" - return self.training_loop.total_batch_idx + return self.epoch_loop.total_batch_idx @property def batch_idx(self) -> int: """Returns the number of batches already run within this epoch""" - return self.training_loop.iteration_count + return self.epoch_loop.iteration_count @property def split_idx(self) -> int: """Returns the index of the current batch split (within the current batch) for bptt""" - return self.training_loop.split_idx + return self.epoch_loop.split_idx @property def min_steps(self) -> int: # TODO(@justusschock): Why aren't we using the attribute in this class? """Returns the minimum numnber of steps to run""" - return self.training_loop.min_steps + return self.epoch_loop.min_steps @property def max_steps(self) -> int: """Returns the maximum number of steps to run""" - return self.training_loop.max_steps + return self.epoch_loop.max_steps @max_steps.setter def max_steps(self, value: int) -> None: - """Sets the maximum number of steps (forwards to training_loop)""" + """Sets the maximum number of steps (forwards to epoch_loop)""" # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided - self.training_loop.max_steps = value + self.epoch_loop.max_steps = value @property def running_loss(self) -> TensorRunningAccum: """Returns the running loss""" - return self.training_loop.batch_loop.running_loss + return self.epoch_loop.batch_loop.running_loss @property def _skip_backward(self) -> bool: """ Determines whether the loop will skip backward during automatic optimization. """ - return self.training_loop.batch_loop._skip_backward + return self.epoch_loop.batch_loop._skip_backward @_skip_backward.setter def _skip_backward(self, value: bool) -> None: """ Determines whether the loop will skip backward during automatic optimization. 
""" - self.training_loop.batch_loop._skip_backward = value + self.epoch_loop.batch_loop._skip_backward = value @property def done(self) -> bool: @@ -165,7 +166,7 @@ def skip(self) -> bool: def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) - self.training_loop.connect(trainer) + self.epoch_loop.connect(trainer) self.validation_loop.connect(trainer) def reset(self) -> None: @@ -193,7 +194,7 @@ def on_advance_start(self) -> None: self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module) # stores accumulated grad fractions per batch - self.training_loop.batch_loop.accumulated_loss = TensorRunningAccum( + self.epoch_loop.batch_loop.accumulated_loss = TensorRunningAccum( window_length=self.trainer.accumulate_grad_batches ) @@ -204,7 +205,7 @@ def advance(self) -> None: with self.trainer.profiler.profile("run_training_epoch"): # run train epoch - epoch_output = self.training_loop.run(train_dataloader) + epoch_output = self.epoch_loop.run(train_dataloader) if epoch_output is None: return @@ -220,10 +221,10 @@ def advance(self) -> None: def on_advance_end(self) -> None: """Updates the LR schedulers and does some internal bookkeeping""" - if self.training_loop.batches_seen == 0: + if self.epoch_loop.batches_seen == 0: return - self.training_loop.update_lr_schedulers('epoch', update_plateau_schedulers=True) + self.epoch_loop.update_lr_schedulers('epoch', update_plateau_schedulers=True) did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.skip if did_train_only: @@ -241,10 +242,10 @@ def on_run_end(self) -> None: # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates # when a checkpoint was saved at the last step - self.training_loop.global_step -= 1 + self.epoch_loop.global_step -= 1 # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406 self._check_checkpoint_callback(should_update=True, is_last=True) - self.training_loop.global_step += 1 + self.epoch_loop.global_step += 1 # hook self.trainer.call_hook("on_train_end") @@ -266,7 +267,7 @@ def on_run_end(self) -> None: def should_accumulate(self) -> bool: """Whether the gradients should be accumulated""" - return self.training_loop.batch_loop.should_accumulate() + return self.epoch_loop.batch_loop.should_accumulate() def _check_checkpoint_callback(self, should_update: bool, is_last: bool = False): """Checks if checkpointing needs to be done""" diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index f8595390dd768..70bcc71d0a2a6 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -244,7 +244,7 @@ def on_train_epoch_end(self, outputs): # noqa with pytest.deprecated_call(match="old signature will be removed in v1.5"): trainer.fit(model) - trainer.fit_loop.training_loop._warning_cache.clear() + trainer.fit_loop.epoch_loop._warning_cache.clear() class NewSignature(Callback): diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py index c9eb997c98dd6..14cb4ce4ae7f8 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -69,7 +69,7 @@ def backward(self, loss, optimizer, optimizer_idx): # simulate training manually trainer.state.stage = RunningStage.TRAINING 
batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.fit_loop.training_loop.batch_loop.run(batch, batch_idx, 0) + out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 train_step_out = out.training_step_output @@ -79,7 +79,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.fit_loop.training_loop.batch_loop.training_step_and_backward( + opt_closure_result = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward( batch, batch_idx, 0, @@ -140,7 +140,7 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.fit_loop.training_loop.batch_loop.run(batch, batch_idx, 0) + out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 train_step_out = out.training_step_output @@ -150,7 +150,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.fit_loop.training_loop.batch_loop.training_step_and_backward( + opt_closure_result = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) assert opt_closure_result['loss'].item() == 171 diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index 0e57797a80890..9b438aea45f87 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -149,7 +149,7 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.fit_loop.training_loop.batch_loop.run(batch, batch_idx, 0) + out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 train_step_out = out.training_step_output @@ -159,7 +159,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.fit_loop.training_loop.batch_loop.training_step_and_backward( + opt_closure_result = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward( batch, batch_idx, 0, @@ -227,7 +227,7 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.fit_loop.training_loop.batch_loop.run(batch, batch_idx, 0) + out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0) assert out.signal == 0 train_step_out = out.training_step_output @@ -237,7 +237,7 @@ def backward(self, loss, optimizer, optimizer_idx): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure_result = trainer.fit_loop.training_loop.batch_loop.training_step_and_backward( + opt_closure_result = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) assert opt_closure_result['loss'].item() == 171 
@@ -313,7 +313,7 @@ def training_step(self, batch, batch_idx): # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): - out = trainer.fit_loop.training_loop.batch_loop.run(batch, batch_idx, 0) + out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0) if not batch_idx % 2: assert out.training_step_output == [[]] assert out.signal == 0 @@ -358,7 +358,7 @@ def train_dataloader(self): # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): - out = trainer.fit_loop.training_loop.batch_loop.run(batch, batch_idx, 0) + out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0) if not batch_idx % 2: assert out.training_step_output == [[]] assert out.signal == 0 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index f75ed3ac340f4..7d29376efbb0b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -926,7 +926,7 @@ def test_gradient_clipping(tmpdir): default_root_dir=tmpdir, ) - old_training_step_and_backward = trainer.fit_loop.training_loop.batch_loop.training_step_and_backward + old_training_step_and_backward = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ @@ -940,7 +940,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.fit_loop.training_loop.batch_loop.training_step_and_backward = training_step_and_backward + trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward = training_step_and_backward # for the test model.prev_called_batch_idx = 0 @@ -964,7 +964,7 @@ def test_gradient_clipping_by_value(tmpdir): default_root_dir=tmpdir ) - old_training_step_and_backward = trainer.fit_loop.training_loop.batch_loop.training_step_and_backward + old_training_step_and_backward = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ @@ -980,7 +980,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.fit_loop.training_loop.batch_loop.training_step_and_backward = training_step_and_backward + trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward = training_step_and_backward # for the test model.prev_called_batch_idx = 0 @@ -1005,7 +1005,7 @@ def test_gradient_clipping_fp16(tmpdir): default_root_dir=tmpdir, ) - old_training_step_and_backward = trainer.fit_loop.training_loop.batch_loop.training_step_and_backward + old_training_step_and_backward = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): """ @@ -1019,7 +1019,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.fit_loop.training_loop.batch_loop.training_step_and_backward = training_step_and_backward + trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward = training_step_and_backward model.prev_called_batch_idx = 0 trainer.fit(model) @@ -1044,7 +1044,7 @@ def test_gradient_clipping_by_value_fp16(tmpdir): default_root_dir=tmpdir, ) - old_training_step_and_backward = trainer.fit_loop.training_loop.batch_loop.training_step_and_backward + old_training_step_and_backward = trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward def training_step_and_backward(split_batch, batch_idx, opt_idx, 
optimizer, hiddens): """ @@ -1060,7 +1060,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde return ret_val - trainer.fit_loop.training_loop.batch_loop.training_step_and_backward = training_step_and_backward + trainer.fit_loop.epoch_loop.batch_loop.training_step_and_backward = training_step_and_backward model.prev_called_batch_idx = 0 trainer.fit(model) From 03470f16625a9803af44f756d6254c68e9d1635c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 22 Jun 2021 12:11:58 +0200 Subject: [PATCH 002/157] EvaluationDataLoaderLoop -> EvaluationLoop --- CHANGELOG.md | 2 +- pytorch_lightning/loops/__init__.py | 2 +- pytorch_lightning/loops/dataloader/__init__.py | 2 +- .../loops/dataloader/evaluation_dataloader_loop.py | 2 +- pytorch_lightning/trainer/properties.py | 6 +++--- tests/trainer/loops/test_evaluation_loop.py | 2 +- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0a423d6968f9..d1bfed8ab2ceb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -148,7 +148,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Removed the `on_epoch` guard from the "should stop" validation check ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701)) * Refactored internal loop interface; added new classes `FitLoop`, `TrainingEpochLoop`, `TrainingBatchLoop` ([#7871](https://github.com/PyTorchLightning/pytorch-lightning/pull/7871)) * Removed `pytorch_lightning/trainer/training_loop.py` ([#7985](https://github.com/PyTorchLightning/pytorch-lightning/pull/7985)) - * Refactored evaluation loop interface; added new classes `DataLoaderLoop`, `EvaluationDataLoaderLoop`, `EvaluationEpochLoop` ([#7990](https://github.com/PyTorchLightning/pytorch-lightning/pull/7990)) + * Refactored evaluation loop interface; added new classes `DataLoaderLoop`, `EvaluationLoop`, `EvaluationEpochLoop` ([#7990](https://github.com/PyTorchLightning/pytorch-lightning/pull/7990)) * Removed `pytorch_lightning/trainer/evaluation_loop.py` ([#8056](https://github.com/PyTorchLightning/pytorch-lightning/pull/8056)) * Restricted public access to several internal functions ([#8024](https://github.com/PyTorchLightning/pytorch-lightning/pull/8024)) * Refactored trainer `_run_*` functions and separate evaluation loops ([#8065](https://github.com/PyTorchLightning/pytorch-lightning/pull/8065)) diff --git a/pytorch_lightning/loops/__init__.py b/pytorch_lightning/loops/__init__.py index f908bd4df05a5..2e56693db6446 100644 --- a/pytorch_lightning/loops/__init__.py +++ b/pytorch_lightning/loops/__init__.py @@ -14,7 +14,7 @@ from pytorch_lightning.loops.base import Loop # noqa: F401 from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop # noqa: F401 -from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationDataLoaderLoop # noqa: F401 +from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationLoop # noqa: F401 from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401 from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop # noqa: F401 from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop # noqa: F401 diff --git a/pytorch_lightning/loops/dataloader/__init__.py b/pytorch_lightning/loops/dataloader/__init__.py index 47da26d0ba5a5..c77711f2f5c86 100644 --- a/pytorch_lightning/loops/dataloader/__init__.py +++ b/pytorch_lightning/loops/dataloader/__init__.py @@ -13,4 +13,4 @@ # limitations under 
the License. from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop # noqa: F401 -from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationDataLoaderLoop # noqa: F401 +from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationLoop # noqa: F401 diff --git a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py b/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py index e5565d6a8912b..ef94b89a11f4b 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py @@ -26,7 +26,7 @@ from pytorch_lightning.utilities.types import EPOCH_OUTPUT -class EvaluationDataLoaderLoop(DataLoaderLoop): +class EvaluationLoop(DataLoaderLoop): """Loops over all dataloaders for evaluation.""" def __init__(self): diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 1fd82b7c3e28c..2ddd43c789771 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -29,7 +29,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loggers.tensorboard import TensorBoardLogger -from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationDataLoaderLoop +from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationLoop from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector @@ -524,10 +524,10 @@ def min_steps(self) -> Optional[int]: @property def is_last_batch(self) -> bool: - return self.fit_loop.training_loop.is_last_batch + return self.fit_loop.epoch_loop.is_last_batch @property - def _active_loop(self) -> Optional[Union[FitLoop, EvaluationDataLoaderLoop]]: + def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop]]: if self.training: return self.fit_loop elif self.sanity_checking or self.evaluating: diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index 0d7584628b933..8f3cbaaa3cf00 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -22,7 +22,7 @@ @mock.patch( - "pytorch_lightning.loops.dataloader.evaluation_dataloader_loop.EvaluationDataLoaderLoop.on_evaluation_epoch_end" + "pytorch_lightning.loops.dataloader.evaluation_dataloader_loop.EvaluationLoop.on_evaluation_epoch_end" ) def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): """ From 20d835eecef869aca289170750d062710e7a107e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 22 Jun 2021 12:17:27 +0200 Subject: [PATCH 003/157] proposed rename files --- pytorch_lightning/loops/__init__.py | 8 ++++---- pytorch_lightning/loops/batch/__init__.py | 0 .../loops/{ => batch}/training_batch_loop.py | 0 pytorch_lightning/loops/dataloader/__init__.py | 16 ---------------- .../loops/{dataloader => }/dataloader_loop.py | 0 pytorch_lightning/loops/epoch/__init__.py | 0 .../loops/{ => epoch}/evaluation_epoch_loop.py | 0 .../loops/{ => epoch}/training_epoch_loop.py | 2 +- ...ion_dataloader_loop.py => evaluation_loop.py} | 4 ++-- pytorch_lightning/loops/fit_loop.py | 2 +- pytorch_lightning/trainer/properties.py | 2 +- 11 files changed, 9 insertions(+), 25 
deletions(-) create mode 100644 pytorch_lightning/loops/batch/__init__.py rename pytorch_lightning/loops/{ => batch}/training_batch_loop.py (100%) delete mode 100644 pytorch_lightning/loops/dataloader/__init__.py rename pytorch_lightning/loops/{dataloader => }/dataloader_loop.py (100%) create mode 100644 pytorch_lightning/loops/epoch/__init__.py rename pytorch_lightning/loops/{ => epoch}/evaluation_epoch_loop.py (100%) rename pytorch_lightning/loops/{ => epoch}/training_epoch_loop.py (99%) rename pytorch_lightning/loops/{dataloader/evaluation_dataloader_loop.py => evaluation_loop.py} (98%) diff --git a/pytorch_lightning/loops/__init__.py b/pytorch_lightning/loops/__init__.py index 2e56693db6446..06b566f9fdae4 100644 --- a/pytorch_lightning/loops/__init__.py +++ b/pytorch_lightning/loops/__init__.py @@ -13,8 +13,8 @@ # limitations under the License. from pytorch_lightning.loops.base import Loop # noqa: F401 -from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop # noqa: F401 -from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationLoop # noqa: F401 +from pytorch_lightning.loops.dataloader_loop import DataLoaderLoop # noqa: F401 +from pytorch_lightning.loops.evaluation_loop import EvaluationLoop # noqa: F401 from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401 -from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop # noqa: F401 -from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop # noqa: F401 +from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop # noqa: F401 +from pytorch_lightning.loops.epoch.training_epoch_loop import TrainingEpochLoop # noqa: F401 diff --git a/pytorch_lightning/loops/batch/__init__.py b/pytorch_lightning/loops/batch/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py similarity index 100% rename from pytorch_lightning/loops/training_batch_loop.py rename to pytorch_lightning/loops/batch/training_batch_loop.py diff --git a/pytorch_lightning/loops/dataloader/__init__.py b/pytorch_lightning/loops/dataloader/__init__.py deleted file mode 100644 index c77711f2f5c86..0000000000000 --- a/pytorch_lightning/loops/dataloader/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop # noqa: F401 -from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationLoop # noqa: F401 diff --git a/pytorch_lightning/loops/dataloader/dataloader_loop.py b/pytorch_lightning/loops/dataloader_loop.py similarity index 100% rename from pytorch_lightning/loops/dataloader/dataloader_loop.py rename to pytorch_lightning/loops/dataloader_loop.py diff --git a/pytorch_lightning/loops/epoch/__init__.py b/pytorch_lightning/loops/epoch/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pytorch_lightning/loops/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py similarity index 100% rename from pytorch_lightning/loops/evaluation_epoch_loop.py rename to pytorch_lightning/loops/epoch/evaluation_epoch_loop.py diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py similarity index 99% rename from pytorch_lightning/loops/training_epoch_loop.py rename to pytorch_lightning/loops/epoch/training_epoch_loop.py index a82f4b72e070b..f40276c3c535e 100644 --- a/pytorch_lightning/loops/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -18,7 +18,7 @@ import pytorch_lightning as pl from pytorch_lightning.loops.base import Loop -from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop +from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden diff --git a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py b/pytorch_lightning/loops/evaluation_loop.py similarity index 98% rename from pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py rename to pytorch_lightning/loops/evaluation_loop.py index ef94b89a11f4b..e4a71b9d4607b 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py +++ b/pytorch_lightning/loops/evaluation_loop.py @@ -18,8 +18,8 @@ from torch.utils.data.dataloader import DataLoader import pytorch_lightning as pl -from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop -from pytorch_lightning.loops.evaluation_epoch_loop import EvaluationEpochLoop +from pytorch_lightning.loops.dataloader_loop import DataLoaderLoop +from pytorch_lightning.loops.epoch.evaluation_epoch_loop import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.model_helpers import is_overridden diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index a0f9aed4068d3..ec07b0a21dffe 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -19,7 +19,7 @@ import pytorch_lightning as pl from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationDataLoaderLoop -from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop +from pytorch_lightning.loops.epoch.training_epoch_loop import TrainingEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.supporters import TensorRunningAccum 
from pytorch_lightning.utilities import rank_zero_info diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 2ddd43c789771..ee2ad3c00eaf6 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -29,7 +29,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loggers.tensorboard import TensorBoardLogger -from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationLoop +from pytorch_lightning.loops.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector From bb8a4de1e4952a19aa8d01ac9ad15c4bd6741bec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 22 Jun 2021 12:40:08 +0200 Subject: [PATCH 004/157] imports --- pytorch_lightning/loops/__init__.py | 6 +++--- pytorch_lightning/loops/batch/__init__.py | 15 +++++++++++++++ pytorch_lightning/loops/dataloader/__init__.py | 16 ++++++++++++++++ .../loops/{ => dataloader}/dataloader_loop.py | 0 .../loops/{ => dataloader}/evaluation_loop.py | 2 +- pytorch_lightning/loops/epoch/__init__.py | 16 ++++++++++++++++ pytorch_lightning/trainer/properties.py | 2 +- 7 files changed, 52 insertions(+), 5 deletions(-) create mode 100644 pytorch_lightning/loops/dataloader/__init__.py rename pytorch_lightning/loops/{ => dataloader}/dataloader_loop.py (100%) rename pytorch_lightning/loops/{ => dataloader}/evaluation_loop.py (99%) diff --git a/pytorch_lightning/loops/__init__.py b/pytorch_lightning/loops/__init__.py index 06b566f9fdae4..77ba43b5705a9 100644 --- a/pytorch_lightning/loops/__init__.py +++ b/pytorch_lightning/loops/__init__.py @@ -13,8 +13,8 @@ # limitations under the License. from pytorch_lightning.loops.base import Loop # noqa: F401 -from pytorch_lightning.loops.dataloader_loop import DataLoaderLoop # noqa: F401 -from pytorch_lightning.loops.evaluation_loop import EvaluationLoop # noqa: F401 -from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401 from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop # noqa: F401 +from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop # noqa: F401 +from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop # noqa: F401 from pytorch_lightning.loops.epoch.training_epoch_loop import TrainingEpochLoop # noqa: F401 +from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401 diff --git a/pytorch_lightning/loops/batch/__init__.py b/pytorch_lightning/loops/batch/__init__.py index e69de29bb2d1d..6e6522165404a 100644 --- a/pytorch_lightning/loops/batch/__init__.py +++ b/pytorch_lightning/loops/batch/__init__.py @@ -0,0 +1,15 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop # noqa: F401 diff --git a/pytorch_lightning/loops/dataloader/__init__.py b/pytorch_lightning/loops/dataloader/__init__.py new file mode 100644 index 0000000000000..437ddc7c75e9e --- /dev/null +++ b/pytorch_lightning/loops/dataloader/__init__.py @@ -0,0 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop # noqa: F401 +from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop # noqa: F401 diff --git a/pytorch_lightning/loops/dataloader_loop.py b/pytorch_lightning/loops/dataloader/dataloader_loop.py similarity index 100% rename from pytorch_lightning/loops/dataloader_loop.py rename to pytorch_lightning/loops/dataloader/dataloader_loop.py diff --git a/pytorch_lightning/loops/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py similarity index 99% rename from pytorch_lightning/loops/evaluation_loop.py rename to pytorch_lightning/loops/dataloader/evaluation_loop.py index e4a71b9d4607b..c01cbe55d72e4 100644 --- a/pytorch_lightning/loops/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -18,7 +18,7 @@ from torch.utils.data.dataloader import DataLoader import pytorch_lightning as pl -from pytorch_lightning.loops.dataloader_loop import DataLoaderLoop +from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop from pytorch_lightning.loops.epoch.evaluation_epoch_loop import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.states import TrainerFn diff --git a/pytorch_lightning/loops/epoch/__init__.py b/pytorch_lightning/loops/epoch/__init__.py index e69de29bb2d1d..08d0c6a63c342 100644 --- a/pytorch_lightning/loops/epoch/__init__.py +++ b/pytorch_lightning/loops/epoch/__init__.py @@ -0,0 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pytorch_lightning.loops.epoch.evaluation_epoch_loop import EvaluationEpochLoop # noqa: F401 +from pytorch_lightning.loops.epoch.training_epoch_loop import TrainingEpochLoop # noqa: F401 diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index ee2ad3c00eaf6..33082d8e92e05 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -29,7 +29,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loggers.tensorboard import TensorBoardLogger -from pytorch_lightning.loops.evaluation_loop import EvaluationLoop +from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector From a23eb529e82aacd7c4902bad24f4ba8f0b9eb592 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 22 Jun 2021 10:46:38 +0000 Subject: [PATCH 005/157] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/trainer/loops/test_evaluation_loop.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index 8f3cbaaa3cf00..62740d8a27d5c 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -21,9 +21,7 @@ from tests.helpers.runif import RunIf -@mock.patch( - "pytorch_lightning.loops.dataloader.evaluation_dataloader_loop.EvaluationLoop.on_evaluation_epoch_end" -) +@mock.patch("pytorch_lightning.loops.dataloader.evaluation_dataloader_loop.EvaluationLoop.on_evaluation_epoch_end") def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): """ Tests that `on_evaluation_epoch_end` is called From 7fa3f727d191a58bcd8e1c7c928f7d9c395898cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 23 Jun 2021 11:42:54 +0200 Subject: [PATCH 006/157] bad merge --- pytorch_lightning/loops/fit_loop.py | 7 +++---- .../connectors/logger_connector/logger_connector.py | 2 +- pytorch_lightning/trainer/properties.py | 6 +++--- pytorch_lightning/trainer/trainer.py | 6 +++--- tests/trainer/loops/test_evaluation_loop.py | 2 +- 5 files changed, 11 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index ec07b0a21dffe..8fda6bde5d9cc 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -18,7 +18,7 @@ import pytorch_lightning as pl from pytorch_lightning.loops.base import Loop -from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationDataLoaderLoop +from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.epoch.training_epoch_loop import TrainingEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.supporters import TensorRunningAccum @@ -52,13 +52,12 @@ def __init__( self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs self.epoch_loop = 
TrainingEpochLoop(min_steps, max_steps) - self.validation_loop = EvaluationDataLoaderLoop() - self.results = ResultCollection(training=True) + self.validation_loop = EvaluationLoop() @property def results(self) -> ResultCollection: if self.trainer.training: - return self.training_loop.results + return self.epoch_loop.results elif self.trainer.validating: return self.validation_loop.results raise RuntimeError("`FitLoop.results` property isn't defined. Accessed outside of scope") diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 25526a829c0a8..27407fb98c159 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -299,7 +299,7 @@ def progress_bar_metrics(self) -> Dict[str, float]: return self._progress_bar_metrics def teardown(self): - self.trainer.fit_loop.training_loop._results.cpu() + self.trainer.fit_loop.epoch_loop._results.cpu() self.trainer.fit_loop.validation_loop._results.cpu() self.trainer.validation_loop._results.cpu() self.trainer.test_loop._results.cpu() diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 33082d8e92e05..5becc9d78c2aa 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -63,8 +63,8 @@ class TrainerProperties(ABC): logger_connector: LoggerConnector state: TrainerState fit_loop: FitLoop - validation_loop: EvaluationDataLoaderLoop - test_loop: EvaluationDataLoaderLoop + validation_loop: EvaluationLoop + test_loop: EvaluationLoop """ Accelerator properties """ @@ -489,7 +489,7 @@ def sanity_checking(self, val: bool) -> None: """ @property - def evaluation_loop(self) -> EvaluationDataLoaderLoop: + def evaluation_loop(self) -> EvaluationLoop: if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): return self.fit_loop.validation_loop elif self.state.fn == TrainerFn.VALIDATING: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index dd201b49e427b..759906b89deee 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -27,7 +27,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationDataLoaderLoop +from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.dataloader.prediction_dataloader_loop import PredictionDataLoaderLoop from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.plugins import Plugin @@ -343,8 +343,8 @@ def __init__( self.tuner = Tuner(self) self.fit_loop = FitLoop(min_epochs, max_epochs, min_steps, max_steps) - self.validation_loop = EvaluationDataLoaderLoop() - self.test_loop = EvaluationDataLoaderLoop() + self.validation_loop = EvaluationLoop() + self.test_loop = EvaluationLoop() self.predict_loop = PredictionDataLoaderLoop() self.fit_loop.connect(self) self.validation_loop.connect(self) diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index 62740d8a27d5c..2a0f95a19209b 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -21,7 +21,7 @@ from tests.helpers.runif import 
RunIf -@mock.patch("pytorch_lightning.loops.dataloader.evaluation_dataloader_loop.EvaluationLoop.on_evaluation_epoch_end") +@mock.patch("pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop.on_evaluation_epoch_end") def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): """ Tests that `on_evaluation_epoch_end` is called From 4657935b7c93d4997699e83720787d4e33204eb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 23 Jun 2021 11:46:29 +0200 Subject: [PATCH 007/157] prediction loop renaming --- .../{prediction_dataloader_loop.py => prediction_loop.py} | 4 ++-- pytorch_lightning/loops/{ => epoch}/prediction_epoch_loop.py | 0 pytorch_lightning/trainer/trainer.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) rename pytorch_lightning/loops/dataloader/{prediction_dataloader_loop.py => prediction_loop.py} (97%) rename pytorch_lightning/loops/{ => epoch}/prediction_epoch_loop.py (100%) diff --git a/pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py similarity index 97% rename from pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py rename to pytorch_lightning/loops/dataloader/prediction_loop.py index 80077e1e2aaae..542f94fdb087e 100644 --- a/pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -5,13 +5,13 @@ import pytorch_lightning as pl from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop -from pytorch_lightning.loops.prediction_epoch_loop import PredictionEpochLoop +from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop from pytorch_lightning.plugins import DDPSpawnPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import _PREDICT_OUTPUT -class PredictionDataLoaderLoop(DataLoaderLoop): +class PredictionLoop(DataLoaderLoop): """Loop to run over dataloaders for prediction""" def __init__(self): diff --git a/pytorch_lightning/loops/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py similarity index 100% rename from pytorch_lightning/loops/prediction_epoch_loop.py rename to pytorch_lightning/loops/epoch/prediction_epoch_loop.py diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 759906b89deee..c5ee90cd126ce 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -28,7 +28,7 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop -from pytorch_lightning.loops.dataloader.prediction_dataloader_loop import PredictionDataLoaderLoop +from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.plugins import Plugin from pytorch_lightning.plugins.environments import ClusterEnvironment @@ -345,7 +345,7 @@ def __init__( self.fit_loop = FitLoop(min_epochs, max_epochs, min_steps, max_steps) self.validation_loop = EvaluationLoop() self.test_loop = EvaluationLoop() - self.predict_loop = PredictionDataLoaderLoop() + self.predict_loop = PredictionLoop() self.fit_loop.connect(self) self.validation_loop.connect(self) self.test_loop.connect(self) From 3b7eaecd471700cbe23137afc8a7cba965e34975 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 23 Jun 2021 11:46:39 +0200 Subject: [PATCH 008/157] update changelog --- CHANGELOG.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d1bfed8ab2ceb..2d93c86957d50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -146,13 +146,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Simplified "should run validation" logic ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682)) * Simplified logic for updating the learning rate for schedulers ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682)) * Removed the `on_epoch` guard from the "should stop" validation check ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701)) - * Refactored internal loop interface; added new classes `FitLoop`, `TrainingEpochLoop`, `TrainingBatchLoop` ([#7871](https://github.com/PyTorchLightning/pytorch-lightning/pull/7871)) + * Refactored internal loop interface; added new classes `FitLoop`, `TrainingEpochLoop`, `TrainingBatchLoop` ([#7871](https://github.com/PyTorchLightning/pytorch-lightning/pull/7871), [#8077](https://github.com/PyTorchLightning/pytorch-lightning/pull/8077)) * Removed `pytorch_lightning/trainer/training_loop.py` ([#7985](https://github.com/PyTorchLightning/pytorch-lightning/pull/7985)) - * Refactored evaluation loop interface; added new classes `DataLoaderLoop`, `EvaluationLoop`, `EvaluationEpochLoop` ([#7990](https://github.com/PyTorchLightning/pytorch-lightning/pull/7990)) + * Refactored evaluation loop interface; added new classes `DataLoaderLoop`, `EvaluationLoop`, `EvaluationEpochLoop` ([#7990](https://github.com/PyTorchLightning/pytorch-lightning/pull/7990), [#8077](https://github.com/PyTorchLightning/pytorch-lightning/pull/8077)) * Removed `pytorch_lightning/trainer/evaluation_loop.py` ([#8056](https://github.com/PyTorchLightning/pytorch-lightning/pull/8056)) * Restricted public access to several internal functions ([#8024](https://github.com/PyTorchLightning/pytorch-lightning/pull/8024)) * Refactored trainer `_run_*` functions and separate evaluation loops ([#8065](https://github.com/PyTorchLightning/pytorch-lightning/pull/8065)) - * Refactored prediction loop interface; added new classes `PredictionDataLoaderLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700)) + * Refactored prediction loop interface; added new classes `PredictionLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700), [#8077](https://github.com/PyTorchLightning/pytorch-lightning/pull/8077)) * Removed `pytorch_lightning/trainer/predict_loop.py` ([#8094](https://github.com/PyTorchLightning/pytorch-lightning/pull/8094)) From 9538c659085160b1e387bbe3e52ef2d68644e41e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 23 Jun 2021 11:58:44 +0200 Subject: [PATCH 009/157] update init files --- pytorch_lightning/loops/__init__.py | 3 +++ pytorch_lightning/loops/dataloader/__init__.py | 1 + pytorch_lightning/loops/epoch/__init__.py | 1 + 3 files changed, 5 insertions(+) diff --git a/pytorch_lightning/loops/__init__.py b/pytorch_lightning/loops/__init__.py index 77ba43b5705a9..fa15a0513ae5d 100644 --- a/pytorch_lightning/loops/__init__.py +++ b/pytorch_lightning/loops/__init__.py @@ -16,5 +16,8 @@ from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop # noqa: F401 from 
pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop # noqa: F401 from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop # noqa: F401 +from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop # noqa: F401 +from pytorch_lightning.loops.epoch.evaluation_epoch_loop import EvaluationEpochLoop # noqa: F401 +from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop # noqa: F401 from pytorch_lightning.loops.epoch.training_epoch_loop import TrainingEpochLoop # noqa: F401 from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401 diff --git a/pytorch_lightning/loops/dataloader/__init__.py b/pytorch_lightning/loops/dataloader/__init__.py index 437ddc7c75e9e..db2b2f7926d50 100644 --- a/pytorch_lightning/loops/dataloader/__init__.py +++ b/pytorch_lightning/loops/dataloader/__init__.py @@ -14,3 +14,4 @@ from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop # noqa: F401 from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop # noqa: F401 +from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop # noqa: F401 diff --git a/pytorch_lightning/loops/epoch/__init__.py b/pytorch_lightning/loops/epoch/__init__.py index 08d0c6a63c342..789953937a6b4 100644 --- a/pytorch_lightning/loops/epoch/__init__.py +++ b/pytorch_lightning/loops/epoch/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. from pytorch_lightning.loops.epoch.evaluation_epoch_loop import EvaluationEpochLoop # noqa: F401 +from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop # noqa: F401 from pytorch_lightning.loops.epoch.training_epoch_loop import TrainingEpochLoop # noqa: F401 From 5b1367722b6bd0120d4cbc6a882bb71786deeaf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 23 Jun 2021 15:44:39 +0200 Subject: [PATCH 010/157] fix bad merge --- pytorch_lightning/callbacks/finetuning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index a65dce9144fa9..cac4e4c9c857e 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -285,7 +285,7 @@ def _store( def on_train_epoch_start(self, trainer, pl_module): """Called when the epoch begins.""" - for opt_idx, optimizer in trainer.fit_loop.training_loop.batch_loop.get_active_optimizers(): + for opt_idx, optimizer in trainer.fit_loop.epoch_loop.batch_loop.get_active_optimizers(): num_param_groups = len(optimizer.param_groups) self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx) current_param_groups = optimizer.param_groups From 2edb154a94e5bc27b0e6258ac3c8f387797c491c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 23 Jun 2021 23:37:19 +0200 Subject: [PATCH 011/157] glue imports together --- pytorch_lightning/loops/__init__.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/__init__.py b/pytorch_lightning/loops/__init__.py index fa15a0513ae5d..b7eb47167d26f 100644 --- a/pytorch_lightning/loops/__init__.py +++ b/pytorch_lightning/loops/__init__.py @@ -13,11 +13,7 @@ # limitations under the License. 
from pytorch_lightning.loops.base import Loop # noqa: F401 -from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop # noqa: F401 -from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop # noqa: F401 -from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop # noqa: F401 -from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop # noqa: F401 -from pytorch_lightning.loops.epoch.evaluation_epoch_loop import EvaluationEpochLoop # noqa: F401 -from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop # noqa: F401 -from pytorch_lightning.loops.epoch.training_epoch_loop import TrainingEpochLoop # noqa: F401 +from pytorch_lightning.loops.batch import TrainingBatchLoop # noqa: F401 +from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401 +from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401 from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401 From 6f2733881b2658b091a0ed70472683d59e75d437 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 24 Jun 2021 00:13:46 +0200 Subject: [PATCH 012/157] connect logic for the fit loop --- pytorch_lightning/loops/fit_loop.py | 29 +++++++++++----------------- pytorch_lightning/trainer/trainer.py | 25 +++++++++++++++++++----- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 8fda6bde5d9cc..ac62691f8001c 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -33,26 +33,19 @@ class FitLoop(Loop): Args: min_epochs: The minimum number of epochs max_epochs: The maximum number of epochs - min_steps: The minimum number of steps - max_steps: The maximum number of epoch .. note:: If neither the minimum epochs nor steps are specified the minimum number of epochs is set to 1 and if neither the maximum steps nor epochs are specified, the maximum epochs are set to 1000. 
""" - def __init__( - self, - min_epochs: Optional[int] = None, - max_epochs: Optional[int] = None, - min_steps: Optional[int] = None, - max_steps: Optional[int] = None - ): + # FIXME: update the note above + def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None): super().__init__() - self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs - self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs - self.epoch_loop = TrainingEpochLoop(min_steps, max_steps) - self.validation_loop = EvaluationLoop() + self.min_epochs = min_epochs + self.max_epochs = max_epochs + self.epoch_loop: Optional[TrainingEpochLoop] = None + self.validation_loop: Optional[EvaluationLoop] = None @property def results(self) -> ResultCollection: @@ -162,11 +155,11 @@ def skip(self) -> bool: """Whether we should skip the training and immediately return from the call to :meth:`run`.""" return self.done or self.trainer.num_training_batches == 0 - def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: - """Connects the loop with necessary arguments like the trainer""" - super().connect(trainer, *args, **kwargs) - self.epoch_loop.connect(trainer) - self.validation_loop.connect(trainer) + def connect(self, trainer: 'pl.Trainer', epoch_loop: TrainingEpochLoop, validation_loop: EvaluationLoop) -> None: + """Connects the loop with a trainer and two other loops: a training epoch loop and a validation loop.""" + super().connect(trainer) + self.epoch_loop = epoch_loop + self.validation_loop = validation_loop def reset(self) -> None: """Resets the internal state of this loop""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c5ee90cd126ce..922926720e4a9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -27,6 +27,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.loops import TrainingEpochLoop from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop from pytorch_lightning.loops.fit_loop import FitLoop @@ -342,14 +343,28 @@ def __init__( self.slurm_connector = SLURMConnector(self) self.tuner = Tuner(self) - self.fit_loop = FitLoop(min_epochs, max_epochs, min_steps, max_steps) + # .fit() loop + self.fit_loop = FitLoop( + min_epochs=(1 if (min_epochs is None and min_steps is None) else min_epochs), + max_epochs=(1000 if (max_epochs is None and max_steps is None) else max_epochs), + ) + training_epoch_loop = TrainingEpochLoop(min_steps, max_steps) + validation_epoch_loop = EvaluationLoop() + training_epoch_loop.connect(trainer=self) + validation_epoch_loop.connect(trainer=self) + self.fit_loop.connect(trainer=self, epoch_loop=training_epoch_loop, validation_loop=validation_epoch_loop) + + # .validate() loop self.validation_loop = EvaluationLoop() + self.validation_loop.connect(trainer=self) + + # .test() loop self.test_loop = EvaluationLoop() + self.test_loop.connect(trainer=self) + + # .predict() loop self.predict_loop = PredictionLoop() - self.fit_loop.connect(self) - self.validation_loop.connect(self) - self.test_loop.connect(self) - self.predict_loop.connect(self) + self.predict_loop.connect(trainer=self) # training state if weights_summary is not None and weights_summary not in 
ModelSummary.MODES: From 7196ca24e4813a9dfc4e3635dd631cf66ae972e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 24 Jun 2021 00:24:23 +0200 Subject: [PATCH 013/157] connect batch loop --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 7 +++---- pytorch_lightning/trainer/trainer.py | 6 ++++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index f40276c3c535e..0eb37714aaecf 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -72,11 +72,10 @@ def done(self) -> bool: max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) - def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: + def connect(self, trainer: 'pl.Trainer', batch_loop: TrainingBatchLoop) -> None: """Connects the loop with all necessary parts like trainer and accelerators""" - super().connect(trainer, *args, **kwargs) - self.batch_loop = TrainingBatchLoop() - self.batch_loop.connect(trainer) + super().connect(trainer) + self.batch_loop = batch_loop def reset(self) -> None: """Resets the internal state of the loop for a new run""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 922926720e4a9..4348f5ec7b10c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -27,7 +27,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.loops import TrainingEpochLoop +from pytorch_lightning.loops import TrainingEpochLoop, TrainingBatchLoop from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop from pytorch_lightning.loops.fit_loop import FitLoop @@ -349,8 +349,10 @@ def __init__( max_epochs=(1000 if (max_epochs is None and max_steps is None) else max_epochs), ) training_epoch_loop = TrainingEpochLoop(min_steps, max_steps) + training_batch_loop = TrainingBatchLoop() validation_epoch_loop = EvaluationLoop() - training_epoch_loop.connect(trainer=self) + training_epoch_loop.connect(trainer=self, batch_loop=training_batch_loop) + training_batch_loop.connect(trainer=self) validation_epoch_loop.connect(trainer=self) self.fit_loop.connect(trainer=self, epoch_loop=training_epoch_loop, validation_loop=validation_epoch_loop) From f837389ce68364a14c6759d59501c25f06a3d683 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 30 Jun 2021 16:04:28 +0200 Subject: [PATCH 014/157] merge --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 11 +++++------ pytorch_lightning/loops/fit_loop.py | 12 ++++++------ pytorch_lightning/trainer/trainer.py | 4 ++-- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index cd8b992b09d45..061fc1d49fe80 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -77,13 +77,12 @@ def done(self) -> bool: max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps return 
max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) - def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: + def connect(self, trainer: 'pl.Trainer', batch_loop, val_loop) -> None: """Connects the loop with all necessary parts like trainer and accelerators""" - super().connect(trainer, *args, **kwargs) - self.batch_loop = TrainingBatchLoop() - self.batch_loop.connect(trainer) - self.val_loop = loops.EvaluationLoop() - self.val_loop.connect(trainer) + super().connect(trainer) + self.batch_loop = batch_loop# or TrainingBatchLoop() + self.val_loop = val_loop #or loops.EvaluationLoop() + # self.val_loop.connect(trainer) def reset(self) -> None: """Resets the internal state of the loop for a new run""" diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index a7ebfbf3b4a41..1fa584b2d4c52 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -41,9 +41,9 @@ class FitLoop(Loop): # FIXME: update the note above def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None): super().__init__() - self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs - self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs - self.epoch_loop = TrainingEpochLoop(min_steps, max_steps) + self.max_epochs = min_epochs + self.min_epochs = max_epochs + self.epoch_loop = None @property def results(self) -> ResultCollection: @@ -149,10 +149,10 @@ def skip(self) -> bool: """Whether we should skip the training and immediately return from the call to :meth:`run`.""" return self.done or self.trainer.num_training_batches == 0 - def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: + def connect(self, trainer: 'pl.Trainer', epoch_loop) -> None: """Connects the loop with necessary arguments like the trainer""" - super().connect(trainer, *args, **kwargs) - self.epoch_loop.connect(trainer) + super().connect(trainer) + self.epoch_loop = epoch_loop def reset(self) -> None: """Resets the internal state of this loop""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 869bc74af3523..2f6d57adc4be2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -354,10 +354,10 @@ def __init__( training_epoch_loop = TrainingEpochLoop(min_steps, max_steps) training_batch_loop = TrainingBatchLoop() validation_epoch_loop = EvaluationLoop() - training_epoch_loop.connect(trainer=self, batch_loop=training_batch_loop) + training_epoch_loop.connect(trainer=self, batch_loop=training_batch_loop, val_loop=validation_epoch_loop) training_batch_loop.connect(trainer=self) validation_epoch_loop.connect(trainer=self) - self.fit_loop.connect(trainer=self, epoch_loop=training_epoch_loop, validation_loop=validation_epoch_loop) + self.fit_loop.connect(trainer=self, epoch_loop=training_epoch_loop) # .validate() loop self.validation_loop = EvaluationLoop() From 7f9d75c38a13ecac79dd6a708d4d4eb2de879f13 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 30 Jun 2021 14:10:10 +0000 Subject: [PATCH 015/157] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git 
a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 061fc1d49fe80..fe9856a7557fe 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -80,8 +80,8 @@ def done(self) -> bool: def connect(self, trainer: 'pl.Trainer', batch_loop, val_loop) -> None: """Connects the loop with all necessary parts like trainer and accelerators""" super().connect(trainer) - self.batch_loop = batch_loop# or TrainingBatchLoop() - self.val_loop = val_loop #or loops.EvaluationLoop() + self.batch_loop = batch_loop # or TrainingBatchLoop() + self.val_loop = val_loop #or loops.EvaluationLoop() # self.val_loop.connect(trainer) def reset(self) -> None: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2f6d57adc4be2..be466aa705a14 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -28,7 +28,7 @@ from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.loops import TrainingEpochLoop, TrainingBatchLoop +from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop from pytorch_lightning.loops.fit_loop import FitLoop From 4079283fda9450407ac2083f95e8894bf2957000 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 6 Jul 2021 16:11:22 +0200 Subject: [PATCH 016/157] wip --- pytorch_lightning/loops/fit_loop.py | 12 ++++++++---- pytorch_lightning/trainer/trainer.py | 19 ++++++++++++++----- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 84e934ceaf6c3..1b1e11efa3774 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -45,6 +45,7 @@ def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = self.max_epochs = min_epochs self.min_epochs = max_epochs self.epoch_loop = None + self.progress: Optional[FitLoopProgress] = None @property def current_epoch(self) -> int: @@ -161,16 +162,19 @@ def skip(self) -> bool: return self.done or self.trainer.num_training_batches == 0 def connect( - self, trainer: 'pl.Trainer', *args: Any, progress: Optional[FitLoopProgress] = None, **kwargs: Any + self, + trainer: 'pl.Trainer', + epoch_loop, + *args: Any, + progress: Optional[FitLoopProgress] = None, + **kwargs: Any, ) -> None: - def connect(self, trainer: 'pl.Trainer', epoch_loop) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) if progress is not None: self.progress = progress - self.epoch_loop.connect(trainer, progress=self.progress.epoch) - super().connect(trainer) self.epoch_loop = epoch_loop + self.epoch_loop.connect(trainer, progress=self.progress.epoch) def reset(self) -> None: """Resets the internal state of this loop""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d58bbebe2e711..24a835baaf8ae 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -361,14 +361,17 @@ def __init__( self.fit_loop.connect(trainer=self, epoch_loop=training_epoch_loop) # .validate() loop + # TODO: 
connect progress self.validate_loop = EvaluationLoop() - self.validate.connect(trainer=self) + self.validate_loop.connect(trainer=self) # .test() loop + # TODO: connect progress self.test_loop = EvaluationLoop() self.test_loop.connect(trainer=self) # .predict() loop + # TODO: connect progress self.predict_loop = PredictionLoop() self.predict_loop.connect(trainer=self) @@ -1016,11 +1019,17 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT: assert self.evaluating + if self.validating or self.sanity_checking: + loop = self.validate_loop + else: + assert self.testing + loop = self.test_loop + # reload dataloaders - self._evaluation_loop.reload_evaluation_dataloaders() + loop.reload_evaluation_dataloaders() with self.profiler.profile(f"run_{self.state.stage}_evaluation"), torch.no_grad(): - eval_loop_results = self._evaluation_loop.run() + eval_loop_results = loop.run() # remove the tensors from the eval results for i, result in enumerate(eval_loop_results): @@ -1050,11 +1059,11 @@ def _run_sanity_check(self, ref_model): self.on_sanity_check_start() # reload dataloaders - self._evaluation_loop.reload_evaluation_dataloaders() + self.validate_loop.reload_evaluation_dataloaders() # run eval step with torch.no_grad(): - self._evaluation_loop.run() + self.validate_loop.run() self.on_sanity_check_end() From 198fd2a880eeac172d2edd26ca4c36707176db32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 6 Jul 2021 16:16:27 +0200 Subject: [PATCH 017/157] undo --- pytorch_lightning/trainer/trainer.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 24a835baaf8ae..587a0e7121773 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -355,20 +355,21 @@ def __init__( training_epoch_loop = TrainingEpochLoop(min_steps, max_steps) training_batch_loop = TrainingBatchLoop() validation_epoch_loop = EvaluationLoop() + training_epoch_loop.connect(trainer=self, batch_loop=training_batch_loop, val_loop=validation_epoch_loop) training_batch_loop.connect(trainer=self) validation_epoch_loop.connect(trainer=self) - self.fit_loop.connect(trainer=self, epoch_loop=training_epoch_loop) + self.fit_loop.connect(trainer=self, epoch_loop=training_epoch_loop, progress=FitLoopProgress()) # .validate() loop # TODO: connect progress self.validate_loop = EvaluationLoop() - self.validate_loop.connect(trainer=self) + self.validate_loop.connect(trainer=self, progress=EpochLoopProgress()) # .test() loop # TODO: connect progress self.test_loop = EvaluationLoop() - self.test_loop.connect(trainer=self) + self.test_loop.connect(trainer=self, progress=EpochLoopProgress()) # .predict() loop # TODO: connect progress @@ -1019,17 +1020,11 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT: assert self.evaluating - if self.validating or self.sanity_checking: - loop = self.validate_loop - else: - assert self.testing - loop = self.test_loop - # reload dataloaders - loop.reload_evaluation_dataloaders() + self._evaluation_loop.reload_evaluation_dataloaders() with self.profiler.profile(f"run_{self.state.stage}_evaluation"), torch.no_grad(): - eval_loop_results = loop.run() + eval_loop_results = self._evaluation_loop.run() # remove the tensors from the eval results for i, result in enumerate(eval_loop_results): @@ -1059,11 +1054,11 @@ def _run_sanity_check(self, ref_model): self.on_sanity_check_start() # reload dataloaders - self.validate_loop.reload_evaluation_dataloaders() + 
self._evaluation_loop.reload_evaluation_dataloaders() # run eval step with torch.no_grad(): - self.validate_loop.run() + self._evaluation_loop.run() self.on_sanity_check_end() From aa85ce4022396ab9d299820accfd12ba2a69a006 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 7 Jul 2021 22:50:44 +0200 Subject: [PATCH 018/157] conflict --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 662ac1dd596a8..c0eb4862961d4 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -72,22 +72,20 @@ def done(self) -> bool: def connect( self, trainer: 'pl.Trainer', + batch_loop, + val_loop, *args: Any, progress: Optional[TrainingEpochProgress] = None, **kwargs: Any ) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) + self.batch_loop = batch_loop # or TrainingBatchLoop() + self.val_loop = val_loop # or loops.EvaluationLoop() if progress is not None: self.progress = progress self.batch_loop.connect(trainer, progress=self.progress.batch, optim_progress=self.progress.optim) self.val_loop.connect(trainer, progress=self.progress.val) - def connect(self, trainer: 'pl.Trainer', batch_loop, val_loop) -> None: - """Connects the loop with all necessary parts like trainer and accelerators""" - super().connect(trainer) - self.batch_loop = batch_loop # or TrainingBatchLoop() - self.val_loop = val_loop #or loops.EvaluationLoop() - # self.val_loop.connect(trainer) def reset(self) -> None: """Resets the internal state of the loop for a new run""" From 8be321772a00bf15bc8367bef3bca049eac492d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 8 Jul 2021 10:00:46 +0200 Subject: [PATCH 019/157] link loops --- pl_examples/bug_report_model.py | 17 ++++++ pytorch_lightning/loops/base.py | 6 +- .../loops/batch/training_batch_loop.py | 2 +- .../loops/dataloader/evaluation_loop.py | 7 ++- .../loops/dataloader/prediction_loop.py | 6 +- .../loops/epoch/evaluation_epoch_loop.py | 2 +- .../loops/epoch/prediction_epoch_loop.py | 2 +- .../loops/epoch/training_epoch_loop.py | 11 ++-- pytorch_lightning/loops/fit_loop.py | 8 ++- pytorch_lightning/trainer/trainer.py | 58 ++++++++++++++----- 10 files changed, 89 insertions(+), 30 deletions(-) diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index f906ab9bde77c..649e5844ff969 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -4,6 +4,8 @@ from torch.utils.data import DataLoader, Dataset from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.loops import FitLoop, TrainingEpochLoop, EvaluationLoop, TrainingBatchLoop +from pytorch_lightning.trainer.progress import FitLoopProgress class RandomDataset(Dataset): @@ -51,6 +53,7 @@ def run(): test_data = DataLoader(RandomDataset(32, 64), batch_size=2) model = BoringModel() + trainer = Trainer( default_root_dir=os.getcwd(), limit_train_batches=1, @@ -59,6 +62,20 @@ def run(): max_epochs=1, weights_summary=None, ) + + # construct loops + fit_loop = FitLoop() + train_epoch_loop = TrainingEpochLoop(min_steps=0, max_steps=2) + train_batch_loop = TrainingBatchLoop() + val_loop = EvaluationLoop() + + # link loops + train_epoch_loop.link(batch_loop=train_batch_loop, val_loop=val_loop) + 
fit_loop.link(epoch_loop=train_epoch_loop) + + # connect fit loop to trainer + trainer.fit_loop = fit_loop + trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) trainer.test(model, dataloaders=test_data) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 3293b3eba29ab..564699f9c92bb 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -58,8 +58,12 @@ def skip(self) -> bool: """Determine whether to return immediately from the call to :meth:`run`.""" return False + def link(self, **kwargs: "Loop"): + """Optionally link one or multiple loops to this one. Linked loops should form a tree.""" + raise NotImplementedError(f"{self.__class__.__name__} does link any child loops.") + def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: - """Connects Loop with all the necessary things like connectors and accelerators.""" + """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" # TODO(@justusschock): Make the trainer a weakref/proxy if not isinstance(trainer, pl.Trainer): raise MisconfigurationException( diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 9b803a2790d9d..b6d9366b6db50 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -67,7 +67,7 @@ def connect( optim_progress: Optional[OptimizationProgress] = None, **kwargs: Any ) -> None: - """Connects the loop with necessary arguments like the trainer""" + """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" super().connect(trainer, *args, **kwargs) if progress is not None: self.progress = progress diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 2f6e14b93b767..e278897995453 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -34,7 +34,6 @@ def __init__(self): super().__init__() self.outputs = [] self.progress = EpochLoopProgress() - self.epoch_loop = EvaluationEpochLoop() self._results = ResultCollection(training=False) @@ -66,10 +65,14 @@ def predictions(self): """Returns the predictions from all dataloaders""" return self.epoch_loop.predictions + def link(self, epoch_loop: EvaluationEpochLoop): + """Links the evaluation epoch loop with this loop.""" + self.epoch_loop = epoch_loop + def connect( self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any ) -> None: - """Connects the loop with necessary arguments like the trainer""" + """Called by the Trainer. 
Connects a Loop with all the necessary components like progress, etc.""" super().connect(trainer, *args, **kwargs) if progress is not None: self.progress = progress diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index 55647e5d7f2a3..c1f7c0cf7d464 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -20,7 +20,6 @@ def __init__(self): self.predictions: Optional[List[List[Any]]] = None self.epoch_batch_indices: Optional[List[List[int]]] = None self.progress = EpochLoopProgress() - self.epoch_loop = PredictionEpochLoop() self._results = None # for `trainer._results` access @@ -76,10 +75,13 @@ def done(self) -> bool: def skip(self) -> bool: return sum(self.max_batches) == 0 + def link(self, epoch_loop: PredictionEpochLoop): + self.epoch_loop = epoch_loop + def connect( self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any ) -> None: - """Connects the loop with necessary arguments like the trainer""" + """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" super().connect(trainer, *args, **kwargs) if progress is not None: self.progress = progress diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index c01b20a5f84e2..9e0b6a409f53b 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -46,7 +46,7 @@ def __init__(self) -> None: def connect( self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochProgress] = None, **kwargs: Any ) -> None: - """Connects the loop with necessary arguments like the trainer""" + """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" super().connect(trainer, *args, **kwargs) if progress is not None: self.progress = progress diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index ea03be5ef0096..e4a3f894644e3 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -28,7 +28,7 @@ def __init__(self) -> None: def connect( self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochProgress] = None, **kwargs: Any ) -> None: - """Connects the loop with necessary arguments like the trainer""" + """Called by the Trainer. 
Connects a Loop with all the necessary components like progress, etc.""" super().connect(trainer, *args, **kwargs) if progress is not None: self.progress = progress diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index c0eb4862961d4..21fbcd92f0331 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -69,19 +69,20 @@ def done(self) -> bool: max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) + def link(self, batch_loop, val_loop): + """Links a batch loop and a validation loop to this training epoch loop.""" + self.batch_loop = batch_loop + self.val_loop = val_loop + def connect( self, trainer: 'pl.Trainer', - batch_loop, - val_loop, *args: Any, progress: Optional[TrainingEpochProgress] = None, **kwargs: Any ) -> None: - """Connects the loop with necessary arguments like the trainer""" + """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" super().connect(trainer, *args, **kwargs) - self.batch_loop = batch_loop # or TrainingBatchLoop() - self.val_loop = val_loop # or loops.EvaluationLoop() if progress is not None: self.progress = progress self.batch_loop.connect(trainer, progress=self.progress.batch, optim_progress=self.progress.optim) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index feb5d72917f95..3adc787bf3efc 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -161,19 +161,21 @@ def skip(self) -> bool: """Whether we should skip the training and immediately return from the call to :meth:`run`.""" return self.done or self.trainer.num_training_batches == 0 + def link(self, epoch_loop: TrainingEpochLoop): + """Links a training epoch loop to this fit loop.""" + self.epoch_loop = epoch_loop + def connect( self, trainer: 'pl.Trainer', - epoch_loop, *args: Any, progress: Optional[FitLoopProgress] = None, **kwargs: Any, ) -> None: - """Connects the loop with necessary arguments like the trainer""" + """Called by the Trainer. 
Connects a Loop with all the necessary components like progress, etc.""" super().connect(trainer, *args, **kwargs) if progress is not None: self.progress = progress - self.epoch_loop = epoch_loop self.epoch_loop.connect(trainer, progress=self.progress.epoch) def reset(self) -> None: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a1b869e16b7a9..22a977fb48248 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -359,33 +359,27 @@ def __init__( self.tuner = Tuner(self) # .fit() loop - self.fit_loop = FitLoop( + fit_loop = FitLoop( min_epochs=(1 if (min_epochs is None and min_steps is None) else min_epochs), max_epochs=(1000 if (max_epochs is None and max_steps is None) else max_epochs), ) training_epoch_loop = TrainingEpochLoop(min_steps, max_steps) training_batch_loop = TrainingBatchLoop() validation_epoch_loop = EvaluationLoop() + training_epoch_loop.link(batch_loop=training_batch_loop, val_loop=validation_epoch_loop) + fit_loop.link(epoch_loop=training_epoch_loop) - training_epoch_loop.connect(trainer=self, batch_loop=training_batch_loop, val_loop=validation_epoch_loop) - training_batch_loop.connect(trainer=self) - validation_epoch_loop.connect(trainer=self) - self.fit_loop.connect(trainer=self, epoch_loop=training_epoch_loop, progress=FitLoopProgress()) + # default .fit() loop + self.fit_loop = fit_loop - # .validate() loop - # TODO: connect progress + # default .validate() loop self.validate_loop = EvaluationLoop() - self.validate_loop.connect(trainer=self, progress=EpochLoopProgress()) - # .test() loop - # TODO: connect progress + # default .test() loop self.test_loop = EvaluationLoop() - self.test_loop.connect(trainer=self, progress=EpochLoopProgress()) - # .predict() loop - # TODO: connect progress + # default .predict() loop self.predict_loop = PredictionLoop() - self.predict_loop.connect(trainer=self) # training state if weights_summary is not None and weights_summary not in ModelSummary.MODES: @@ -459,6 +453,42 @@ def __init__( # Callback system self.on_init_end() + @property + def fit_loop(self): + return self._fit_loop + + @fit_loop.setter + def fit_loop(self, loop: FitLoop): + self._fit_loop = loop + self._fit_loop.connect(self, progress=FitLoopProgress()) + + @property + def validate_loop(self): + return self._validate_loop + + @validate_loop.setter + def validate_loop(self, loop: EvaluationLoop): + self._validate_loop = loop + self._validate_loop.connect(self, progress=EpochLoopProgress()) + + @property + def test_loop(self): + return self._test_loop + + @test_loop.setter + def test_loop(self, loop: EvaluationLoop): + self._test_loop = loop + self._test_loop.connect(self, progress=EpochLoopProgress()) + + @property + def predict_loop(self): + return self._predict_loop + + @predict_loop.setter + def predict_loop(self, loop: PredictionLoop): + self._predict_loop = loop + self._predict_loop.connect(self, progress=EpochLoopProgress()) + def _setup_on_init( self, num_sanity_val_steps: int, From caa1caac1d39119527ec2cb4cd37c6dd7a7b3c16 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jul 2021 08:02:02 +0000 Subject: [PATCH 020/157] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pl_examples/bug_report_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index 
649e5844ff969..56a4e65a2bd20 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -4,7 +4,7 @@ from torch.utils.data import DataLoader, Dataset from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.loops import FitLoop, TrainingEpochLoop, EvaluationLoop, TrainingBatchLoop +from pytorch_lightning.loops import EvaluationLoop, FitLoop, TrainingBatchLoop, TrainingEpochLoop from pytorch_lightning.trainer.progress import FitLoopProgress From 0b81e02c450e2628a0fa729ccbdec2f331ab855a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 8 Jul 2021 10:52:22 +0200 Subject: [PATCH 021/157] examples --- pl_examples/bug_report_model.py | 17 --- pl_examples/example1.py | 84 ++++++++++++++ pl_examples/example2.py | 200 ++++++++++++++++++++++++++++++++ 3 files changed, 284 insertions(+), 17 deletions(-) create mode 100644 pl_examples/example1.py create mode 100644 pl_examples/example2.py diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index 56a4e65a2bd20..f906ab9bde77c 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -4,8 +4,6 @@ from torch.utils.data import DataLoader, Dataset from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.loops import EvaluationLoop, FitLoop, TrainingBatchLoop, TrainingEpochLoop -from pytorch_lightning.trainer.progress import FitLoopProgress class RandomDataset(Dataset): @@ -53,7 +51,6 @@ def run(): test_data = DataLoader(RandomDataset(32, 64), batch_size=2) model = BoringModel() - trainer = Trainer( default_root_dir=os.getcwd(), limit_train_batches=1, @@ -62,20 +59,6 @@ def run(): max_epochs=1, weights_summary=None, ) - - # construct loops - fit_loop = FitLoop() - train_epoch_loop = TrainingEpochLoop(min_steps=0, max_steps=2) - train_batch_loop = TrainingBatchLoop() - val_loop = EvaluationLoop() - - # link loops - train_epoch_loop.link(batch_loop=train_batch_loop, val_loop=val_loop) - fit_loop.link(epoch_loop=train_epoch_loop) - - # connect fit loop to trainer - trainer.fit_loop = fit_loop - trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) trainer.test(model, dataloaders=test_data) diff --git a/pl_examples/example1.py b/pl_examples/example1.py new file mode 100644 index 0000000000000..56a4e65a2bd20 --- /dev/null +++ b/pl_examples/example1.py @@ -0,0 +1,84 @@ +import os + +import torch +from torch.utils.data import DataLoader, Dataset + +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.loops import EvaluationLoop, FitLoop, TrainingBatchLoop, TrainingEpochLoop +from pytorch_lightning.trainer.progress import FitLoopProgress + + +class RandomDataset(Dataset): + + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return self.len + + +class BoringModel(LightningModule): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + def training_step(self, batch, batch_idx): + loss = self(batch).sum() + self.log("train_loss", loss) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + loss = self(batch).sum() + self.log("valid_loss", loss) + + def test_step(self, batch, batch_idx): + loss = self(batch).sum() + self.log("test_loss", loss) + + def configure_optimizers(self): + return torch.optim.SGD(self.layer.parameters(), lr=0.1) + + +def 
run(): + train_data = DataLoader(RandomDataset(32, 64), batch_size=2) + val_data = DataLoader(RandomDataset(32, 64), batch_size=2) + test_data = DataLoader(RandomDataset(32, 64), batch_size=2) + + model = BoringModel() + + trainer = Trainer( + default_root_dir=os.getcwd(), + limit_train_batches=1, + limit_val_batches=1, + num_sanity_val_steps=0, + max_epochs=1, + weights_summary=None, + ) + + # construct loops + fit_loop = FitLoop() + train_epoch_loop = TrainingEpochLoop(min_steps=0, max_steps=2) + train_batch_loop = TrainingBatchLoop() + val_loop = EvaluationLoop() + + # link loops + train_epoch_loop.link(batch_loop=train_batch_loop, val_loop=val_loop) + fit_loop.link(epoch_loop=train_epoch_loop) + + # connect fit loop to trainer + trainer.fit_loop = fit_loop + + trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) + trainer.test(model, dataloaders=test_data) + + +if __name__ == '__main__': + run() diff --git a/pl_examples/example2.py b/pl_examples/example2.py new file mode 100644 index 0000000000000..fc95e5e4de27c --- /dev/null +++ b/pl_examples/example2.py @@ -0,0 +1,200 @@ +from collections import OrderedDict +from typing import Any, Union, Dict, Optional, Iterator, Tuple + +import torch +from torch import Tensor +from torch.optim import Optimizer + +from pytorch_lightning.loops import Loop +from pytorch_lightning.trainer.connectors.logger_connector.result import ( + ResultCollection, +) + +import os + +import torch +from torch.utils.data import DataLoader, Dataset + +from pytorch_lightning import LightningModule, Trainer + + + + +class SimpleLoop(Loop): + """This loop is for demonstration purposes only.""" + + def __init__(self, num_iterations: int = float("inf")): + super().__init__() + self.num_iterations = num_iterations + self.train_dataloader: Optional[Iterator] = None + + # required for trainer and logger connector + self._results = ResultCollection(training=True) + + @property + def global_step(self) -> int: + return self.iteration_count + + @property + def batch_idx(self) -> int: + # required by progress bar + return self.iteration_count + + @property + def running_loss(self) -> Tensor: + # required by progress bar + return torch.tensor(123.) 
+ + @property + def current_epoch(self) -> int: + return 0 + + @property + def skip(self) -> bool: + return self.done or self.trainer.num_training_batches == 0 + + @property + def done(self) -> bool: + return self.iteration_count >= self.num_iterations + + def reset(self) -> None: + self.iteration_count = 0 + + def on_run_start(self) -> None: + self.train_dataloader = iter( + self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) + ) + self.trainer.call_hook("on_train_start") + + def advance(self) -> None: + batch = next(self.train_dataloader) + + opt_idx = 0 + optimizer = self.trainer.optimizers[opt_idx] + + self.trainer.call_hook("on_train_batch_start", batch, self.iteration_count, dataloader_idx=0) + + output = self._run_optimization(batch, self.iteration_count, optimizer) + + # hook + self.trainer.call_hook( + "on_train_batch_end", output, batch, self.iteration_count, dataloader_idx=0 + ) + self.trainer.call_hook("on_batch_end") + + def on_run_end(self) -> None: + self.trainer.call_hook("on_train_end") + self.trainer.accelerator.on_train_end() + self.trainer._running_stage = None + + def _run_optimization(self, batch: Any, batch_idx: int, optimizer: Optimizer): + lightning_module = self.trainer.lightning_module + + # lightning module training_step + step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)]) + lightning_module._current_fx_name = "training_step" + training_step_output = self.trainer.accelerator.training_step(step_kwargs) + self.trainer.accelerator.post_training_step() + + training_step_output = self.trainer.call_hook( + "training_step_end", training_step_output + ) + loss, extra = self._process_training_step_output(training_step_output) + + # backward pass (single optimizer, no accumulation supported) + self.trainer.accelerator.backward( + loss, optimizer, optimizer_idx=0, should_accumulate=False + ) + + # optimizer step (no closures supported) + lightning_module.optimizer_step(optimizer=optimizer) + + output = extra + output["loss"] = loss.detach() + return output + + @staticmethod + def _process_training_step_output( + training_step_output: Union[Dict, Tensor] + ) -> Tuple[Tensor, Dict]: + loss = None + extra = {} + + if isinstance(training_step_output, dict): + loss = training_step_output.pop("loss") + extra = training_step_output + + elif isinstance(training_step_output, Tensor): + loss = training_step_output + + return loss, extra + + +class RandomDataset(Dataset): + + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return self.len + + +class BoringModel(LightningModule): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + self.print("batch start:", batch_idx) + + def training_step(self, batch, batch_idx): + self.print("training_step:", batch_idx) + loss = self(batch).sum() + self.log("train_loss", loss) + return {"loss": loss} + + def backward(self, loss, *args, **kwargs): + self.print("backward:", loss) + return super().backward(loss, *args, **kwargs) + + def optimizer_step(self, *args, **kwargs): + self.print("optimizer_step") + return super().optimizer_step(*args, **kwargs) + + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.print("batch end:", batch_idx) + + def configure_optimizers(self): + return 
torch.optim.SGD(self.layer.parameters(), lr=0.1) + + +def run(): + train_data = DataLoader(RandomDataset(32, 64), batch_size=2) + + model = BoringModel() + trainer = Trainer( + default_root_dir=os.getcwd(), + limit_train_batches=1, + limit_val_batches=1, + num_sanity_val_steps=0, + max_epochs=1, + weights_summary=None, + progress_bar_refresh_rate=1, + ) + + simple_loop = SimpleLoop(num_iterations=1000) + trainer.fit_loop = simple_loop + + trainer.fit(model, train_dataloader=train_data) + + +if __name__ == '__main__': + run() \ No newline at end of file From 1279cb95c7d2de2f3f0b19c41e35187f85da764f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jul 2021 08:53:34 +0000 Subject: [PATCH 022/157] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pl_examples/example2.py | 38 ++++++++++---------------------------- 1 file changed, 10 insertions(+), 28 deletions(-) diff --git a/pl_examples/example2.py b/pl_examples/example2.py index fc95e5e4de27c..498c11700cfe2 100644 --- a/pl_examples/example2.py +++ b/pl_examples/example2.py @@ -1,23 +1,15 @@ +import os from collections import OrderedDict -from typing import Any, Union, Dict, Optional, Iterator, Tuple +from typing import Any, Dict, Iterator, Optional, Tuple, Union import torch from torch import Tensor from torch.optim import Optimizer - -from pytorch_lightning.loops import Loop -from pytorch_lightning.trainer.connectors.logger_connector.result import ( - ResultCollection, -) - -import os - -import torch from torch.utils.data import DataLoader, Dataset from pytorch_lightning import LightningModule, Trainer - - +from pytorch_lightning.loops import Loop +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection class SimpleLoop(Loop): @@ -61,9 +53,7 @@ def reset(self) -> None: self.iteration_count = 0 def on_run_start(self) -> None: - self.train_dataloader = iter( - self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) - ) + self.train_dataloader = iter(self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)) self.trainer.call_hook("on_train_start") def advance(self) -> None: @@ -77,9 +67,7 @@ def advance(self) -> None: output = self._run_optimization(batch, self.iteration_count, optimizer) # hook - self.trainer.call_hook( - "on_train_batch_end", output, batch, self.iteration_count, dataloader_idx=0 - ) + self.trainer.call_hook("on_train_batch_end", output, batch, self.iteration_count, dataloader_idx=0) self.trainer.call_hook("on_batch_end") def on_run_end(self) -> None: @@ -96,15 +84,11 @@ def _run_optimization(self, batch: Any, batch_idx: int, optimizer: Optimizer): training_step_output = self.trainer.accelerator.training_step(step_kwargs) self.trainer.accelerator.post_training_step() - training_step_output = self.trainer.call_hook( - "training_step_end", training_step_output - ) + training_step_output = self.trainer.call_hook("training_step_end", training_step_output) loss, extra = self._process_training_step_output(training_step_output) # backward pass (single optimizer, no accumulation supported) - self.trainer.accelerator.backward( - loss, optimizer, optimizer_idx=0, should_accumulate=False - ) + self.trainer.accelerator.backward(loss, optimizer, optimizer_idx=0, should_accumulate=False) # optimizer step (no closures supported) lightning_module.optimizer_step(optimizer=optimizer) @@ -114,9 +98,7 @@ def _run_optimization(self, batch: Any, 
batch_idx: int, optimizer: Optimizer): return output @staticmethod - def _process_training_step_output( - training_step_output: Union[Dict, Tensor] - ) -> Tuple[Tensor, Dict]: + def _process_training_step_output(training_step_output: Union[Dict, Tensor]) -> Tuple[Tensor, Dict]: loss = None extra = {} @@ -197,4 +179,4 @@ def run(): if __name__ == '__main__': - run() \ No newline at end of file + run() From 3c578a5592400bcb581960a074fb7e48be53c6c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 8 Jul 2021 11:13:48 +0200 Subject: [PATCH 023/157] rename --- pytorch_lightning/trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 22a977fb48248..5b3977a409ce2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -365,8 +365,8 @@ def __init__( ) training_epoch_loop = TrainingEpochLoop(min_steps, max_steps) training_batch_loop = TrainingBatchLoop() - validation_epoch_loop = EvaluationLoop() - training_epoch_loop.link(batch_loop=training_batch_loop, val_loop=validation_epoch_loop) + training_validation_loop = EvaluationLoop() + training_epoch_loop.link(batch_loop=training_batch_loop, val_loop=training_validation_loop) fit_loop.link(epoch_loop=training_epoch_loop) # default .fit() loop From 4e869c93582c53b1ce094251fdce14db9a5b3393 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 8 Jul 2021 11:24:34 +0200 Subject: [PATCH 024/157] fix bug --- pytorch_lightning/loops/fit_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 3adc787bf3efc..a9b91eeeecffc 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -42,8 +42,8 @@ class FitLoop(Loop): # FIXME: update the note above def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None): super().__init__() - self.max_epochs = min_epochs - self.min_epochs = max_epochs + self.max_epochs = max_epochs + self.min_epochs = min_epochs self.epoch_loop = None self.progress: Optional[FitLoopProgress] = None From c5389f7f8943cc2e62bfcc50179ca8a45acf3426 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 8 Jul 2021 11:32:36 +0200 Subject: [PATCH 025/157] reset _notebooks --- _notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_notebooks b/_notebooks index 3321b468e7816..29aea106edefc 160000 --- a/_notebooks +++ b/_notebooks @@ -1 +1 @@ -Subproject commit 3321b468e78167aaf056894e92ed6d649c76e89e +Subproject commit 29aea106edefc9d1904c0c17223a8ac2b15c48e7 From fe55d6e5bedd371de5b900947ae1cbb6ac6bfd0b Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 8 Jul 2021 12:41:27 +0200 Subject: [PATCH 026/157] resolve issues --- pytorch_lightning/loops/base.py | 90 +++++++++++++++++++++-- pytorch_lightning/trainer/progress.py | 14 ++-- tests/loops/test_loops.py | 100 +++++++++++++++++++++++++- 3 files changed, 189 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 3293b3eba29ab..61fef73982e20 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -13,12 +13,17 @@ # limitations under the License. 
from abc import ABC, abstractmethod -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, OrderedDict from deprecate import void import pytorch_lightning as pl +from pytorch_lightning.trainer.progress import BaseProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.warnings import WarningCache + +warning_cache = WarningCache() class Loop(ABC): @@ -47,6 +52,40 @@ def __init__(self) -> None: self.iteration_count: int = 0 self.trainer: Optional['pl.Trainer'] = None self.restarting = False + self._loops = OrderedDict() + self._progress = OrderedDict() + + def __setattr__(self, name: str, value: Any) -> None: + if isinstance(value, Loop): + self._loops[name] = value + elif isinstance(value, BaseProgress): + self._progress[name] = value + else: + object.__setattr__(self, name, value) + + def __getattr__(self, name) -> Any: + loops = self.__dict__.get('_loops') + + if loops is not None and name in loops: + return loops[name] + + progress = self.__dict__.get('_progress') + + if progress is not None and name in progress: + return progress[name] + + if name not in self.__dict__: + raise AttributeError(f"{self.__class__.__name__} Loop doesn't have attribute {name}.") + + return self.__dict__[name] + + def __delattr__(self, name) -> None: + if name in self._loops: + del self._loops[name] + elif name in self._progress: + del self._progress[name] + else: + object.__delattr__(self, name) @property @abstractmethod @@ -89,6 +128,8 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: return self.on_skip() if self.restarting: + if not is_overridden("restore", self, Loop): + warning_cache.warn(f"{self.__class__.__name__} Loop doesn't override the restore function.") self.restore() self.restarting = False else: @@ -108,7 +149,7 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: output = self.on_run_end() return output - def restore(self) -> None: + def restore(self, state: Optional[Dict] = None) -> None: """Restore the internal state of the loop the beginning of run if restarting is ``True``.""" @abstractmethod @@ -142,9 +183,46 @@ def on_run_end(self) -> Any: def teardown(self) -> None: """Use to release memory etc.""" - def load_state_dict(self, state_dict: Dict) -> None: - """Restore the loop state from the provided state_dict.""" - def state_dict(self) -> Dict: - """Return the loop current states.""" + """Current Loop state""" return {} + + def load_state_dict(self, state_dict: Dict) -> None: + """Reload Loop state""" + + def get_state_dict(self, destination: Optional[OrderedDict] = None, prefix: Optional[str] = '') -> OrderedDict: + if destination is None: + destination = OrderedDict() + + destination[prefix + "state_dict"] = self.state_dict() + + for name, progress in self._progress.items(): + destination[prefix + name] = progress.state_dict() + + for name, loop in self._loops.items(): + loop.get_state_dict(destination, prefix + name + '.') + return destination + + def _load_from_state_dict(self, state_dict, prefix, strict, missing_keys, unexpected_keys, error_msgs): + self.load_state_dict(state_dict[prefix + "state_dict"]) + + for name, progress in self._progress.items(): + progress.load_state_dict(state_dict[prefix + name]) + + def _load_state_dict(self, state_dict: Dict, strict: bool = True): + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + + state_dict = state_dict.copy() + + def load(loop, prefix=''): + 
loop._load_from_state_dict(state_dict, prefix, True, missing_keys, unexpected_keys, error_msgs) + loop.restarting = True + for name, loop_children in loop._loops.items(): + if loop_children is not None: + load(loop_children, prefix + name + '.') + + load(self) + load = None # break load->load reference cycle diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 25f76ad085cc6..3acae2485cea0 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -16,7 +16,7 @@ @dataclass -class _DataclassStateDictMixin: +class BaseProgress: def state_dict(self) -> dict: return asdict(self) @@ -25,14 +25,14 @@ def load_state_dict(self, state_dict: dict) -> None: self.__dict__.update(state_dict) @classmethod - def from_state_dict(cls, state_dict: dict) -> "_DataclassStateDictMixin": + def from_state_dict(cls, state_dict: dict) -> "BaseProgress": obj = cls() obj.load_state_dict(state_dict) return obj @dataclass -class Tracker(_DataclassStateDictMixin): +class Tracker(BaseProgress): """ Track an event's progress. @@ -72,7 +72,7 @@ def __repr__(self): @dataclass -class Progress(_DataclassStateDictMixin): +class Progress(BaseProgress): """ Track aggregated and current progress. @@ -150,7 +150,7 @@ def load_state_dict(self, state_dict: dict) -> None: @dataclass -class OptimizerProgress(_DataclassStateDictMixin): +class OptimizerProgress(BaseProgress): """ Track optimizer progress. @@ -172,7 +172,7 @@ def load_state_dict(self, state_dict: dict) -> None: @dataclass -class OptimizationProgress(_DataclassStateDictMixin): +class OptimizationProgress(BaseProgress): """ Track optimization progress. @@ -203,7 +203,7 @@ def load_state_dict(self, state_dict: dict) -> None: @dataclass -class EpochLoopProgress(_DataclassStateDictMixin): +class EpochLoopProgress(BaseProgress): """ Tracks epoch loop progress. These counters are local to a trainer rank. By default, they are not globally synced across all ranks. diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index af5801d2b4552..9f553247ab9a3 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -from typing import Dict, Iterator +from collections import OrderedDict +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Dict, Iterator from pytorch_lightning.loops.base import Loop +from pytorch_lightning.trainer.progress import BaseProgress def test_loop_restore(): @@ -72,3 +75,96 @@ def load_state_dict(self, state_dict: Dict) -> None: assert not loop.restarting assert loop.outputs == list(range(10)) + + +def test_loop_hierarchy(): + + @dataclass + class SimpleProgress(BaseProgress): + + increment: int = 0 + + def state_dict(self): + return {"increment": self.increment} + + def load_state_dict(self, state_dict): + self.increment = state_dict["increment"] + + class Simple(Loop): + + def __init__(self, a): + super().__init__() + self.a = a + self.progress = SimpleProgress() + + def advance(self, *args: Any, **kwargs: Any) -> None: + for loop in self._loops.values(): + loop.run() + self.progress.increment += 1 + self.progress.increment += 1 + + @property + def done(self) -> bool: + return self.iteration_count > 0 + + def reset(self) -> None: + pass + + def restore(self) -> None: + pass + + def state_dict(self) -> Dict: + return {"a": self.a} + + def load_state_dict(self, state_dict: Dict) -> None: + self.a = state_dict["a"] + + loop_parent = Simple(1) + loop_child = Simple(2) + loop_parent.loop_child = loop_child + state_dict = loop_parent.get_state_dict() + assert state_dict == OrderedDict([('state_dict', { + 'a': 1 + }), ('progress', { + 'increment': 0 + }), ('loop_child.state_dict', { + 'a': 2 + }), ('loop_child.progress', { + 'increment': 0 + })]) + + state_dict["loop_child.state_dict"]["a"] = 3 + loop_parent._load_state_dict(state_dict) + assert loop_parent.restarting + + loop_parent.run() + + loop_parent_copy = deepcopy(loop_parent) + assert loop_parent_copy.get_state_dict() == loop_parent.get_state_dict() + + assert loop_parent_copy.state_dict() == {'a': 1} + assert loop_parent_copy.loop_child.state_dict() == {'a': 3} + + assert not loop_parent.restarting + + state_dict = loop_parent.get_state_dict() + assert state_dict == OrderedDict([('state_dict', { + 'a': 1 + }), ('progress', { + 'increment': 2 + }), ('loop_child.state_dict', { + 'a': 3 + }), ('loop_child.progress', { + 'increment': 1 + })]) + + loop_parent = Simple(1) + loop_child = Simple(2) + loop_parent.loop_child = loop_child + loop_parent._load_state_dict(state_dict) + assert loop_parent.progress.increment == 2 + assert loop_parent.loop_child.progress.increment == 1 + + del loop_parent.loop_child + state_dict = loop_parent.get_state_dict() + assert state_dict == OrderedDict([('state_dict', {'a': 1}), ('progress', {'increment': 2})]) From 4ee4a7301b0b18dca7434ea5a050b49b78d9cf73 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 8 Jul 2021 15:02:20 +0200 Subject: [PATCH 027/157] update --- pytorch_lightning/loops/base.py | 24 +++++++++++++++++++++++- pytorch_lightning/trainer/progress.py | 11 ++++++++++- tests/loops/test_loops.py | 24 +++++++++++++++++++++++- 3 files changed, 56 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 61fef73982e20..59101e4b98ee8 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -18,7 +18,7 @@ from deprecate import void import pytorch_lightning as pl -from pytorch_lightning.trainer.progress import BaseProgress +from pytorch_lightning.trainer.progress import BaseProgress, ProgressDict from pytorch_lightning.utilities.exceptions import MisconfigurationException 
from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.warnings import WarningCache @@ -55,8 +55,30 @@ def __init__(self) -> None: self._loops = OrderedDict() self._progress = OrderedDict() + @property + def is_leaf(self) -> bool: + loops = self.__dict__.get('_loops') + return len(loops) == 0 + + @property + def loop_progress(self) -> Dict[str, Any]: + progress = {} + for n, p in self.__dict__.get('_progress').items(): + progress[n] = p + + loops = self.__dict__.get('_loops') + + if loops is not None: + for name, loop in loops.items(): + progress[name] = ProgressDict(**loop.loop_progress) + return ProgressDict(**progress) + def __setattr__(self, name: str, value: Any) -> None: if isinstance(value, Loop): + if getattr(self, "__children__loops__", None) is not None and name not in self.__children__loops__: + raise MisconfigurationException( + f"The current loop accept only {self.__children__loops__} as children attribute names." + ) self._loops[name] = value elif isinstance(value, BaseProgress): self._progress[name] = value diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 3acae2485cea0..54b85273d9c0a 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -12,7 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import asdict, dataclass, field -from typing import Optional +from typing import Dict, Optional + + +class ProgressDict(Dict): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + for k, v in kwargs.items(): + setattr(self, k, v) @dataclass diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 9f553247ab9a3..1061d7e2d0cee 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -16,8 +16,11 @@ from dataclasses import dataclass from typing import Any, Dict, Iterator +import pytest + from pytorch_lightning.loops.base import Loop -from pytorch_lightning.trainer.progress import BaseProgress +from pytorch_lightning.trainer.progress import BaseProgress, ProgressDict +from pytorch_lightning.utilities.exceptions import MisconfigurationException def test_loop_restore(): @@ -92,6 +95,8 @@ def load_state_dict(self, state_dict): class Simple(Loop): + __children__loops__ = ("loop_child") + def __init__(self, a): super().__init__() self.a = a @@ -123,6 +128,21 @@ def load_state_dict(self, state_dict: Dict) -> None: loop_child = Simple(2) loop_parent.loop_child = loop_child state_dict = loop_parent.get_state_dict() + + with pytest.raises(MisconfigurationException, match="The current loop accept only loop_child"): + loop_parent.wrong_name = loop_child + + loop_progress: ProgressDict = loop_parent.loop_progress + assert loop_progress["progress"] == loop_parent.progress + assert loop_progress["loop_child"]["progress"] == loop_child.progress + + assert loop_progress.progress == loop_parent.progress + assert loop_progress.loop_child.progress == loop_child.progress + + loop_progress = loop_child.loop_progress + assert loop_progress["progress"] == loop_child.progress + assert loop_progress.progress == loop_child.progress + assert state_dict == OrderedDict([('state_dict', { 'a': 1 }), ('progress', { @@ -133,6 +153,8 @@ def load_state_dict(self, state_dict: Dict) -> None: 'increment': 0 })]) + loop_parent.progress + state_dict["loop_child.state_dict"]["a"] = 3 loop_parent._load_state_dict(state_dict) assert loop_parent.restarting From 
12914183842760407c054a648f06834fef25b3b5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 8 Jul 2021 15:03:35 +0200 Subject: [PATCH 028/157] update --- pytorch_lightning/loops/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 59101e4b98ee8..9c8f6365685ac 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -13,7 +13,8 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, OrderedDict +from collections import OrderedDict +from typing import Any, Dict, Optional from deprecate import void From fe8ba38a78ed68d695b51cba90cc5505f5631263 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 8 Jul 2021 15:21:25 +0200 Subject: [PATCH 029/157] change connect method --- pl_examples/example1.py | 8 ++++-- pytorch_lightning/loops/base.py | 26 ++++++++--------- .../loops/batch/training_batch_loop.py | 28 +++++++++---------- .../loops/dataloader/evaluation_loop.py | 20 ++++++------- .../loops/dataloader/prediction_loop.py | 18 ++++++------ .../loops/epoch/evaluation_epoch_loop.py | 14 +++++----- .../loops/epoch/prediction_epoch_loop.py | 14 +++++----- .../loops/epoch/training_epoch_loop.py | 24 ++++------------ pytorch_lightning/loops/fit_loop.py | 17 ++--------- pytorch_lightning/trainer/trainer.py | 13 ++++----- 10 files changed, 80 insertions(+), 102 deletions(-) diff --git a/pl_examples/example1.py b/pl_examples/example1.py index 56a4e65a2bd20..60eb25d33f8e2 100644 --- a/pl_examples/example1.py +++ b/pl_examples/example1.py @@ -65,17 +65,21 @@ def run(): # construct loops fit_loop = FitLoop() + fit_loop.any = TrainingEpochLoop() + train_epoch_loop = TrainingEpochLoop(min_steps=0, max_steps=2) train_batch_loop = TrainingBatchLoop() val_loop = EvaluationLoop() # link loops - train_epoch_loop.link(batch_loop=train_batch_loop, val_loop=val_loop) - fit_loop.link(epoch_loop=train_epoch_loop) + train_epoch_loop.connect(batch_loop=train_batch_loop, val_loop=val_loop) + fit_loop.connect(epoch_loop=train_epoch_loop) # connect fit loop to trainer trainer.fit_loop = fit_loop + fit_loop.connect(epoch_loop=TrainingEpochLoop()) + trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) trainer.test(model, dataloaders=test_data) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 564699f9c92bb..fdb3a81ad2b9b 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -16,6 +16,7 @@ from typing import Any, Dict, Optional from deprecate import void +from onnx.backend.test.case.node.loop import Loop import pytorch_lightning as pl from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -44,10 +45,14 @@ class Loop(ABC): """ def __init__(self) -> None: + self.parent: Optional["Loop"] = None self.iteration_count: int = 0 - self.trainer: Optional['pl.Trainer'] = None self.restarting = False + @property + def trainer(self) -> "pl.Trainer": + return self.parent.trainer + @property @abstractmethod def done(self) -> bool: @@ -58,18 +63,13 @@ def skip(self) -> bool: """Determine whether to return immediately from the call to :meth:`run`.""" return False - def link(self, **kwargs: "Loop"): - """Optionally link one or multiple loops to this one. 
Linked loops should form a tree.""" - raise NotImplementedError(f"{self.__class__.__name__} does link any child loops.") - - def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: - """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" - # TODO(@justusschock): Make the trainer a weakref/proxy - if not isinstance(trainer, pl.Trainer): - raise MisconfigurationException( - f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}." - ) - self.trainer = trainer + def connect(self, **kwargs: "Loop") -> None: + """Optionally connect one or multiple loops to this one. Linked loops should form a tree.""" + # if not isinstance(trainer, pl.Trainer): + # raise MisconfigurationException( + # f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}." + # ) + # raise NotImplementedError(f"{self.__class__.__name__} does connect any child loops.") def on_skip(self) -> Optional[Any]: """ diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index b6d9366b6db50..eb64d3e53a643 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -59,20 +59,20 @@ def __init__(self) -> None: self._remaining_splits: Optional[List[Any]] = None self._skip_backward: bool = False - def connect( - self, - trainer: 'pl.Trainer', - *args: Any, - progress: Optional[BatchProgress] = None, - optim_progress: Optional[OptimizationProgress] = None, - **kwargs: Any - ) -> None: - """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" - super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - if optim_progress is not None: - self.optim_progress = optim_progress + # def connect( + # self, + # trainer: 'pl.Trainer', + # *args: Any, + # progress: Optional[BatchProgress] = None, + # optim_progress: Optional[OptimizationProgress] = None, + # **kwargs: Any + # ) -> None: + # """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" + # super().connect(trainer, *args, **kwargs) + # if progress is not None: + # self.progress = progress + # if optim_progress is not None: + # self.optim_progress = optim_progress @property def done(self) -> bool: diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index e278897995453..7484b3bc2b67c 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -65,18 +65,18 @@ def predictions(self): """Returns the predictions from all dataloaders""" return self.epoch_loop.predictions - def link(self, epoch_loop: EvaluationEpochLoop): - """Links the evaluation epoch loop with this loop.""" + def connect(self, epoch_loop: EvaluationEpochLoop): + """Connect the evaluation epoch loop with this loop.""" self.epoch_loop = epoch_loop - def connect( - self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any - ) -> None: - """Called by the Trainer. 
Connects a Loop with all the necessary components like progress, etc.""" - super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - self.epoch_loop.connect(trainer, progress=self.progress.epoch) + # def connect( + # self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any + # ) -> None: + # """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" + # super().connect(trainer, *args, **kwargs) + # if progress is not None: + # self.progress = progress + # self.epoch_loop.connect(trainer, progress=self.progress.epoch) @property def done(self) -> bool: diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index c1f7c0cf7d464..3a5f5a5d59081 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -75,17 +75,17 @@ def done(self) -> bool: def skip(self) -> bool: return sum(self.max_batches) == 0 - def link(self, epoch_loop: PredictionEpochLoop): + def connect(self, epoch_loop: PredictionEpochLoop): self.epoch_loop = epoch_loop - def connect( - self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any - ) -> None: - """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" - super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - self.epoch_loop.connect(trainer, progress=self.progress.epoch) + # def connect( + # self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any + # ) -> None: + # """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" + # super().connect(trainer, *args, **kwargs) + # if progress is not None: + # self.progress = progress + # self.epoch_loop.connect(trainer, progress=self.progress.epoch) def reset(self) -> None: """Resets the internal state of the loop for a new run""" diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 9e0b6a409f53b..5e723a1ba3b81 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -43,13 +43,13 @@ def __init__(self) -> None: self.outputs: List[STEP_OUTPUT] = [] self.progress = EpochProgress() - def connect( - self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochProgress] = None, **kwargs: Any - ) -> None: - """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" - super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress + # def connect( + # self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochProgress] = None, **kwargs: Any + # ) -> None: + # """Called by the Trainer. 
Connects a Loop with all the necessary components like progress, etc.""" + # super().connect(trainer, *args, **kwargs) + # if progress is not None: + # self.progress = progress @property def done(self) -> bool: diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index e4a3f894644e3..e009a53bfa4a6 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -25,13 +25,13 @@ def __init__(self) -> None: self._warning_cache = WarningCache() self._all_batch_indices: List[int] = [] - def connect( - self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochProgress] = None, **kwargs: Any - ) -> None: - """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" - super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress + # def connect( + # self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochProgress] = None, **kwargs: Any + # ) -> None: + # """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" + # super().connect(trainer, *args, **kwargs) + # if progress is not None: + # self.progress = progress @property def done(self) -> bool: diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 21fbcd92f0331..97741be881e78 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -47,8 +47,8 @@ def __init__(self, min_steps: int, max_steps: int): self.is_last_batch: Optional[bool] = None self.progress = TrainingEpochProgress() - self.batch_loop = TrainingBatchLoop() - self.val_loop = loops.EvaluationLoop() + self.batch_loop = None + self.val_loop = None self._results = ResultCollection(training=True) self._dataloader_idx: Optional[int] = None @@ -69,24 +69,12 @@ def done(self) -> bool: max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) - def link(self, batch_loop, val_loop): - """Links a batch loop and a validation loop to this training epoch loop.""" + def connect(self, batch_loop: TrainingBatchLoop, val_loop: "loops.EvaluationLoop") -> None: + """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" self.batch_loop = batch_loop self.val_loop = val_loop - - def connect( - self, - trainer: 'pl.Trainer', - *args: Any, - progress: Optional[TrainingEpochProgress] = None, - **kwargs: Any - ) -> None: - """Called by the Trainer. 
Connects a Loop with all the necessary components like progress, etc.""" - super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - self.batch_loop.connect(trainer, progress=self.progress.batch, optim_progress=self.progress.optim) - self.val_loop.connect(trainer, progress=self.progress.val) + # if self.trainer is not None: + # self.batch_loop.connect() def reset(self) -> None: """Resets the internal state of the loop for a new run""" diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index a9b91eeeecffc..3f99da9c68a08 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -161,23 +161,10 @@ def skip(self) -> bool: """Whether we should skip the training and immediately return from the call to :meth:`run`.""" return self.done or self.trainer.num_training_batches == 0 - def link(self, epoch_loop: TrainingEpochLoop): - """Links a training epoch loop to this fit loop.""" + def connect(self, epoch_loop: TrainingEpochLoop): + """Connects a training epoch loop to this fit loop.""" self.epoch_loop = epoch_loop - def connect( - self, - trainer: 'pl.Trainer', - *args: Any, - progress: Optional[FitLoopProgress] = None, - **kwargs: Any, - ) -> None: - """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" - super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - self.epoch_loop.connect(trainer, progress=self.progress.epoch) - def reset(self) -> None: """Resets the internal state of this loop""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5b3977a409ce2..066f36644df3d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -358,7 +358,6 @@ def __init__( self.slurm_connector = SLURMConnector(self) self.tuner = Tuner(self) - # .fit() loop fit_loop = FitLoop( min_epochs=(1 if (min_epochs is None and min_steps is None) else min_epochs), max_epochs=(1000 if (max_epochs is None and max_steps is None) else max_epochs), @@ -366,8 +365,8 @@ def __init__( training_epoch_loop = TrainingEpochLoop(min_steps, max_steps) training_batch_loop = TrainingBatchLoop() training_validation_loop = EvaluationLoop() - training_epoch_loop.link(batch_loop=training_batch_loop, val_loop=training_validation_loop) - fit_loop.link(epoch_loop=training_epoch_loop) + training_epoch_loop.connect(batch_loop=training_batch_loop, val_loop=training_validation_loop) + fit_loop.connect(epoch_loop=training_epoch_loop) # default .fit() loop self.fit_loop = fit_loop @@ -460,7 +459,7 @@ def fit_loop(self): @fit_loop.setter def fit_loop(self, loop: FitLoop): self._fit_loop = loop - self._fit_loop.connect(self, progress=FitLoopProgress()) + self._fit_loop.trainer = self @property def validate_loop(self): @@ -469,7 +468,7 @@ def validate_loop(self): @validate_loop.setter def validate_loop(self, loop: EvaluationLoop): self._validate_loop = loop - self._validate_loop.connect(self, progress=EpochLoopProgress()) + self._validate_loop.trainer = self @property def test_loop(self): @@ -478,7 +477,7 @@ def test_loop(self): @test_loop.setter def test_loop(self, loop: EvaluationLoop): self._test_loop = loop - self._test_loop.connect(self, progress=EpochLoopProgress()) + self._test_loop.trainer = self @property def predict_loop(self): @@ -487,7 +486,7 @@ def predict_loop(self): @predict_loop.setter def predict_loop(self, loop: PredictionLoop): self._predict_loop = loop - 
self._predict_loop.connect(self, progress=EpochLoopProgress()) + self._predict_loop.trainer = self def _setup_on_init( self, From 5bfed2f8a96c72e875efd32dfe0ed2e51bad0f18 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jul 2021 13:22:38 +0000 Subject: [PATCH 030/157] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/epoch/prediction_epoch_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index e009a53bfa4a6..43ccbc74209c8 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -29,9 +29,9 @@ def __init__(self) -> None: # self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochProgress] = None, **kwargs: Any # ) -> None: # """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" - # super().connect(trainer, *args, **kwargs) - # if progress is not None: - # self.progress = progress + # super().connect(trainer, *args, **kwargs) + # if progress is not None: + # self.progress = progress @property def done(self) -> bool: From 0e69bea98a779fc66a26cabea6c811a234036e90 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 8 Jul 2021 15:25:46 +0200 Subject: [PATCH 031/157] update --- pytorch_lightning/loops/base.py | 30 ++++++++++++++++++++++++++---- tests/loops/test_loops.py | 17 +++++++++++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 9c8f6365685ac..5ae5bfe244de1 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -55,11 +55,20 @@ def __init__(self) -> None: self.restarting = False self._loops = OrderedDict() self._progress = OrderedDict() + self._num_parents: int = 0 @property - def is_leaf(self) -> bool: + def has_parent(self) -> Optional[bool]: + return self._num_parents > 0 + + @property + def has_children(self) -> bool: loops = self.__dict__.get('_loops') - return len(loops) == 0 + return len(loops) > 0 + + @property + def is_leaf(self) -> bool: + return not self.has_children and self.has_parent @property def loop_progress(self) -> Dict[str, Any]: @@ -75,12 +84,21 @@ def loop_progress(self) -> Dict[str, Any]: return ProgressDict(**progress) def __setattr__(self, name: str, value: Any) -> None: - if isinstance(value, Loop): + if isinstance(value, pl.Trainer): + # when assigning a Trainer to a loop, it will assign to its children too. + object.__setattr__(self, name, value) + for loop in self._loops.values(): + object.__setattr__(loop, name, value) + elif isinstance(value, Loop): if getattr(self, "__children__loops__", None) is not None and name not in self.__children__loops__: raise MisconfigurationException( f"The current loop accept only {self.__children__loops__} as children attribute names." 
) - self._loops[name] = value + if value not in self._loops.values(): + self._loops[name] = value + value._num_parents += 1 + else: + raise MisconfigurationException("This loop has already been assigned.") elif isinstance(value, BaseProgress): self._progress[name] = value else: @@ -104,6 +122,7 @@ def __getattr__(self, name) -> Any: def __delattr__(self, name) -> None: if name in self._loops: + self._loops[name]._num_parents -= 1 del self._loops[name] elif name in self._progress: del self._progress[name] @@ -147,6 +166,9 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: Returns: the output of :attr:`on_run_end` (often outputs collected from each step of the loop) """ + if self.trainer is None: + raise MisconfigurationException(f"The {self.__class__.__name__} Loop hasn't been attached to any Trainer.") + if self.skip: return self.on_skip() diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 1061d7e2d0cee..23e62faddaa55 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -20,6 +20,7 @@ from pytorch_lightning.loops.base import Loop from pytorch_lightning.trainer.progress import BaseProgress, ProgressDict +from pytorch_lightning.trainer.trainer import Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -124,9 +125,21 @@ def state_dict(self) -> Dict: def load_state_dict(self, state_dict: Dict) -> None: self.a = state_dict["a"] + grand_loop_parent = Simple(0) loop_parent = Simple(1) loop_child = Simple(2) loop_parent.loop_child = loop_child + + with pytest.raises(MisconfigurationException, match="Loop hasn't been attached to any Trainer."): + grand_loop_parent.run() + + grand_loop_parent.loop_child = loop_child + assert loop_child._num_parents == 2 + del grand_loop_parent.loop_child + assert loop_child._num_parents == 1 + assert loop_child.has_parent + assert loop_parent.has_children + state_dict = loop_parent.get_state_dict() with pytest.raises(MisconfigurationException, match="The current loop accept only loop_child"): @@ -143,6 +156,9 @@ def load_state_dict(self, state_dict: Dict) -> None: assert loop_progress["progress"] == loop_child.progress assert loop_progress.progress == loop_child.progress + loop_parent.trainer = Trainer() + assert loop_child.trainer == loop_parent.trainer + assert state_dict == OrderedDict([('state_dict', { 'a': 1 }), ('progress', { @@ -188,5 +204,6 @@ def load_state_dict(self, state_dict: Dict) -> None: assert loop_parent.loop_child.progress.increment == 1 del loop_parent.loop_child + assert loop_child._num_parents == 0 state_dict = loop_parent.get_state_dict() assert state_dict == OrderedDict([('state_dict', {'a': 1}), ('progress', {'increment': 2})]) From 368e17916ac60bd9ba78ad044d69c9505e5fca13 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 8 Jul 2021 15:36:17 +0200 Subject: [PATCH 032/157] add more exceptions --- pytorch_lightning/loops/base.py | 17 ++++++++++++----- tests/loops/test_loops.py | 7 +++++-- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 5ae5bfe244de1..5a9e3d6e80346 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -59,19 +59,23 @@ def __init__(self) -> None: @property def has_parent(self) -> Optional[bool]: + """Whether the number of loop parents is not null""" return self._num_parents > 0 @property def has_children(self) -> bool: + """Whether this loop has any children""" loops = self.__dict__.get('_loops') return len(loops) > 0 
@property def is_leaf(self) -> bool: + """Whether this loop is a children and has no children itself.""" return not self.has_children and self.has_parent @property def loop_progress(self) -> Dict[str, Any]: + """Return the progress for the current loop and children loop.""" progress = {} for n, p in self.__dict__.get('_progress').items(): progress[n] = p @@ -94,11 +98,14 @@ def __setattr__(self, name: str, value: Any) -> None: raise MisconfigurationException( f"The current loop accept only {self.__children__loops__} as children attribute names." ) - if value not in self._loops.values(): - self._loops[name] = value - value._num_parents += 1 - else: - raise MisconfigurationException("This loop has already been assigned.") + for loop_name, loop in self._loops.items(): + if loop == value and name != loop_name: + raise MisconfigurationException( + f"The {self.__class__.__name__} already contains the provided loop " + f"{loop} under the attribute_name {loop_name}." + ) + self._loops[name] = value + value._num_parents += 1 elif isinstance(value, BaseProgress): self._progress[name] = value else: diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 23e62faddaa55..b19c387a501a1 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -96,7 +96,7 @@ def load_state_dict(self, state_dict): class Simple(Loop): - __children__loops__ = ("loop_child") + __children__loops__ = ("loop_child", "something") def __init__(self, a): super().__init__() @@ -130,6 +130,9 @@ def load_state_dict(self, state_dict: Dict) -> None: loop_child = Simple(2) loop_parent.loop_child = loop_child + with pytest.raises(MisconfigurationException, match="The Simple already contains the provided loop"): + loop_parent.something = loop_child + with pytest.raises(MisconfigurationException, match="Loop hasn't been attached to any Trainer."): grand_loop_parent.run() @@ -142,7 +145,7 @@ def load_state_dict(self, state_dict: Dict) -> None: state_dict = loop_parent.get_state_dict() - with pytest.raises(MisconfigurationException, match="The current loop accept only loop_child"): + with pytest.raises(MisconfigurationException, match="The current loop accept only"): loop_parent.wrong_name = loop_child loop_progress: ProgressDict = loop_parent.loop_progress From eb4475cf222715fa140b633959abb1eacb142a91 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 8 Jul 2021 15:38:04 +0200 Subject: [PATCH 033/157] resolve bug --- pytorch_lightning/loops/base.py | 18 +++++++++++------- tests/loops/test_loops.py | 3 +++ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 5a9e3d6e80346..bf3534f397f8f 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -98,14 +98,18 @@ def __setattr__(self, name: str, value: Any) -> None: raise MisconfigurationException( f"The current loop accept only {self.__children__loops__} as children attribute names." ) + is_contained = False for loop_name, loop in self._loops.items(): - if loop == value and name != loop_name: - raise MisconfigurationException( - f"The {self.__class__.__name__} already contains the provided loop " - f"{loop} under the attribute_name {loop_name}." - ) - self._loops[name] = value - value._num_parents += 1 + if loop == value: + is_contained = True + if name != loop_name: + raise MisconfigurationException( + f"The {self.__class__.__name__} already contains the provided loop " + f"{loop} under the attribute_name {loop_name}." 
+ ) + if not is_contained: + self._loops[name] = value + value._num_parents += 1 elif isinstance(value, BaseProgress): self._progress[name] = value else: diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index b19c387a501a1..e4c832eeb7f89 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -129,6 +129,9 @@ def load_state_dict(self, state_dict: Dict) -> None: loop_parent = Simple(1) loop_child = Simple(2) loop_parent.loop_child = loop_child + assert loop_child._num_parents == 1 + loop_parent.loop_child = loop_child + assert loop_child._num_parents == 1 with pytest.raises(MisconfigurationException, match="The Simple already contains the provided loop"): loop_parent.something = loop_child From 449ca622f0250f24b4b83f28f08348bb7441787d Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 8 Jul 2021 16:36:29 +0200 Subject: [PATCH 034/157] update --- pytorch_lightning/loops/base.py | 37 ++++++++++++++------------------ tests/loops/test_loops.py | 38 +++++++++++++++++++++++++-------- 2 files changed, 45 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index bf3534f397f8f..2b822911411d6 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -55,12 +55,13 @@ def __init__(self) -> None: self.restarting = False self._loops = OrderedDict() self._progress = OrderedDict() - self._num_parents: int = 0 + self._has_parent: bool = False + self.__parent_loop: Optional['Loop'] = None @property def has_parent(self) -> Optional[bool]: - """Whether the number of loop parents is not null""" - return self._num_parents > 0 + """Whether this loop has been attached to another loop""" + return self._has_parent @property def has_children(self) -> bool: @@ -94,22 +95,18 @@ def __setattr__(self, name: str, value: Any) -> None: for loop in self._loops.values(): object.__setattr__(loop, name, value) elif isinstance(value, Loop): + if name == "_Loop__parent_loop": + object.__setattr__(self, name, value) + return if getattr(self, "__children__loops__", None) is not None and name not in self.__children__loops__: raise MisconfigurationException( - f"The current loop accept only {self.__children__loops__} as children attribute names." + f"The current loop accept only {self.__children__loops__} as children attribute names. Found {name}" ) - is_contained = False - for loop_name, loop in self._loops.items(): - if loop == value: - is_contained = True - if name != loop_name: - raise MisconfigurationException( - f"The {self.__class__.__name__} already contains the provided loop " - f"{loop} under the attribute_name {loop_name}." - ) - if not is_contained: - self._loops[name] = value - value._num_parents += 1 + if value._has_parent: + raise MisconfigurationException(f"This provided loop {value} already has a parent. 
") + self._loops[name] = value + value._has_parent = True + value.__parent_loop = self elif isinstance(value, BaseProgress): self._progress[name] = value else: @@ -126,14 +123,12 @@ def __getattr__(self, name) -> Any: if progress is not None and name in progress: return progress[name] - if name not in self.__dict__: - raise AttributeError(f"{self.__class__.__name__} Loop doesn't have attribute {name}.") - - return self.__dict__[name] + return object.__getattribute__(self, name) def __delattr__(self, name) -> None: if name in self._loops: - self._loops[name]._num_parents -= 1 + self._loops[name]._has_parent = False + self._loops[name]._Loop__parent_loop = None del self._loops[name] elif name in self._progress: del self._progress[name] diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index e4c832eeb7f89..da9708f525feb 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -35,6 +35,10 @@ def __init__(self, dataset: Iterator): super().__init__() self.dataset = dataset + @property + def skip(self) -> bool: + return False + def restore(self) -> None: self.iter_dataset = iter(self.dataset) for _ in range(self.iteration_count): @@ -64,8 +68,11 @@ def load_state_dict(self, state_dict: Dict) -> None: self.iteration_count = state_dict["iteration_count"] self.outputs = state_dict["outputs"] + trainer = Trainer() + data = range(10) loop = Simple(data) + loop.trainer = trainer try: loop.run() state_dict = {} @@ -73,6 +80,7 @@ def load_state_dict(self, state_dict: Dict) -> None: state_dict = loop.state_dict() loop = Simple(data) + loop.trainer = trainer loop.load_state_dict(state_dict) loop.restarting = True loop.run() @@ -109,6 +117,10 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.progress.increment += 1 self.progress.increment += 1 + @property + def skip(self) -> bool: + return False + @property def done(self) -> bool: return self.iteration_count > 0 @@ -128,21 +140,28 @@ def load_state_dict(self, state_dict: Dict) -> None: grand_loop_parent = Simple(0) loop_parent = Simple(1) loop_child = Simple(2) + + assert not loop_child.has_parent loop_parent.loop_child = loop_child - assert loop_child._num_parents == 1 - loop_parent.loop_child = loop_child - assert loop_child._num_parents == 1 - with pytest.raises(MisconfigurationException, match="The Simple already contains the provided loop"): + assert loop_child._Loop__parent_loop == loop_parent + + assert loop_child.has_parent + + with pytest.raises(MisconfigurationException, match="already has a parent"): + loop_parent.loop_child = loop_child + + assert not loop_parent.skip + + with pytest.raises(MisconfigurationException, match="already has a parent"): loop_parent.something = loop_child with pytest.raises(MisconfigurationException, match="Loop hasn't been attached to any Trainer."): grand_loop_parent.run() - grand_loop_parent.loop_child = loop_child - assert loop_child._num_parents == 2 - del grand_loop_parent.loop_child - assert loop_child._num_parents == 1 + with pytest.raises(MisconfigurationException, match="already has a parent"): + grand_loop_parent.loop_child = loop_child + assert loop_child.has_parent assert loop_parent.has_children @@ -210,6 +229,7 @@ def load_state_dict(self, state_dict: Dict) -> None: assert loop_parent.loop_child.progress.increment == 1 del loop_parent.loop_child - assert loop_child._num_parents == 0 + assert not loop_child.has_parent + assert loop_child._Loop__parent_loop is None state_dict = loop_parent.get_state_dict() assert state_dict == OrderedDict([('state_dict', {'a': 
1}), ('progress', {'increment': 2})]) From cdf38f0b79d094998c1e499c9d40f14dd3d56831 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 8 Jul 2021 16:38:28 +0200 Subject: [PATCH 035/157] update --- pytorch_lightning/loops/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 2b822911411d6..e9a0d0affd2de 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -71,8 +71,8 @@ def has_children(self) -> bool: @property def is_leaf(self) -> bool: - """Whether this loop is a children and has no children itself.""" - return not self.has_children and self.has_parent + """This loop is a leaf if it doesn't possess any loops.""" + return not self.has_children @property def loop_progress(self) -> Dict[str, Any]: From 88bafafa1a5d4177a16f3655e0c82b3dcde68a8f Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 8 Jul 2021 16:39:49 +0200 Subject: [PATCH 036/157] update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f344f490de6c1..002e01098ab16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -256,6 +256,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `Trainer(resume_from_checkpoint=...)` now restores the model directly after `LightningModule.setup()`, which is before `LightningModule.configure_sharded_model()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652)) +- Improve `Loop` API to better handle children `state_dict` and `progress` ([#8334](https://github.com/PyTorchLightning/pytorch-lightning/pull/8334)) + + ### Deprecated From 0981e949c38ab1fd8ea6f74b8b2bbe90116adf85 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 8 Jul 2021 17:12:31 +0200 Subject: [PATCH 037/157] resolve bug --- pytorch_lightning/loops/base.py | 2 +- tests/loops/test_loops.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index e9a0d0affd2de..3f98cf4214523 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -93,7 +93,7 @@ def __setattr__(self, name: str, value: Any) -> None: # when assigning a Trainer to a loop, it will assign to its children too. 
object.__setattr__(self, name, value) for loop in self._loops.values(): - object.__setattr__(loop, name, value) + loop.__setattr__(name, value) elif isinstance(value, Loop): if name == "_Loop__parent_loop": object.__setattr__(self, name, value) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index da9708f525feb..5b1b12ebdf054 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -233,3 +233,12 @@ def load_state_dict(self, state_dict: Dict) -> None: assert loop_child._Loop__parent_loop is None state_dict = loop_parent.get_state_dict() assert state_dict == OrderedDict([('state_dict', {'a': 1}), ('progress', {'increment': 2})]) + + grand_loop_parent = Simple(0) + loop_parent = Simple(1) + loop_child = Simple(2) + grand_loop_parent.loop_child = loop_parent + loop_parent.loop_child = loop_child + + grand_loop_parent.trainer = Trainer() + assert loop_child.trainer is not None From e429eba8d261e30a85f7a25d5bb845c9074a201b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 8 Jul 2021 21:00:18 +0200 Subject: [PATCH 038/157] add setter --- pytorch_lightning/loops/base.py | 8 ++++++-- pytorch_lightning/loops/dataloader/evaluation_loop.py | 5 +++++ pytorch_lightning/loops/dataloader/prediction_loop.py | 5 +++++ pytorch_lightning/loops/epoch/training_epoch_loop.py | 6 ++++++ pytorch_lightning/loops/fit_loop.py | 5 +++++ pytorch_lightning/trainer/trainer.py | 8 ++++---- 6 files changed, 31 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index fdb3a81ad2b9b..99521a7557efe 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -45,13 +45,17 @@ class Loop(ABC): """ def __init__(self) -> None: - self.parent: Optional["Loop"] = None self.iteration_count: int = 0 self.restarting = False + self._trainer = None @property def trainer(self) -> "pl.Trainer": - return self.parent.trainer + return self._trainer + + @trainer.setter + def trainer(self, trainer: "pl.Trainer"): + self._trainer = trainer @property @abstractmethod diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 7484b3bc2b67c..00f1d1783ddf4 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -40,6 +40,11 @@ def __init__(self): self._max_batches: Optional[Union[int, Sequence[int]]] = None self._has_run: bool = False + @DataLoaderLoop.trainer.setter + def trainer(self, trainer): + self._trainer = trainer + self.epoch_loop.trainer = trainer + @property def num_dataloaders(self) -> int: """Returns the total number of dataloaders""" diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index 3a5f5a5d59081..d05450f974daa 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -25,6 +25,11 @@ def __init__(self): self._results = None # for `trainer._results` access self._return_predictions: bool = False + @DataLoaderLoop.trainer.setter + def trainer(self, trainer): + self._trainer = trainer + self.epoch_loop.trainer = trainer + @property def return_predictions(self) -> bool: """Whether to return the predictions or not""" diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 97741be881e78..0feb824c26916 100644 --- 
a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -55,6 +55,12 @@ def __init__(self, min_steps: int, max_steps: int): self._warning_cache: WarningCache = WarningCache() self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None + @loops.Loop.trainer.setter + def trainer(self, trainer): + self._trainer = trainer + self.batch_loop.trainer = trainer + self.val_loop.trainer = trainer + @property def batch_idx(self) -> int: """Returns the current batch index (within this epoch)""" diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 3f99da9c68a08..3b8dff653cb75 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -47,6 +47,11 @@ def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = self.epoch_loop = None self.progress: Optional[FitLoopProgress] = None + @Loop.trainer.setter + def trainer(self, trainer): + self._trainer = trainer + self.epoch_loop.trainer = trainer + @property def current_epoch(self) -> int: """Return the current epoch""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 066f36644df3d..97af997aa4e45 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -458,8 +458,8 @@ def fit_loop(self): @fit_loop.setter def fit_loop(self, loop: FitLoop): + loop.trainer = self self._fit_loop = loop - self._fit_loop.trainer = self @property def validate_loop(self): @@ -467,8 +467,8 @@ def validate_loop(self): @validate_loop.setter def validate_loop(self, loop: EvaluationLoop): + loop.trainer = self self._validate_loop = loop - self._validate_loop.trainer = self @property def test_loop(self): @@ -476,8 +476,8 @@ def test_loop(self): @test_loop.setter def test_loop(self, loop: EvaluationLoop): + loop.trainer = self self._test_loop = loop - self._test_loop.trainer = self @property def predict_loop(self): @@ -485,8 +485,8 @@ def predict_loop(self): @predict_loop.setter def predict_loop(self, loop: PredictionLoop): + loop.trainer = self self._predict_loop = loop - self._predict_loop.trainer = self def _setup_on_init( self, From 8906eb0a907ee2a6a6d56378a43144e69bc0fe1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 8 Jul 2021 21:10:14 +0200 Subject: [PATCH 039/157] update example --- pl_examples/example1.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pl_examples/example1.py b/pl_examples/example1.py index 60eb25d33f8e2..714e160380711 100644 --- a/pl_examples/example1.py +++ b/pl_examples/example1.py @@ -64,9 +64,7 @@ def run(): ) # construct loops - fit_loop = FitLoop() - fit_loop.any = TrainingEpochLoop() - + fit_loop = FitLoop(max_epochs=2) train_epoch_loop = TrainingEpochLoop(min_steps=0, max_steps=2) train_batch_loop = TrainingBatchLoop() val_loop = EvaluationLoop() @@ -78,8 +76,6 @@ def run(): # connect fit loop to trainer trainer.fit_loop = fit_loop - fit_loop.connect(epoch_loop=TrainingEpochLoop()) - trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) trainer.test(model, dataloaders=test_data) From c10dfdf0de572968d79e70858cc32186606ec16b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 8 Jul 2021 21:33:17 +0200 Subject: [PATCH 040/157] connect trainer --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 2 ++ pytorch_lightning/loops/dataloader/prediction_loop.py | 2 ++ pytorch_lightning/loops/epoch/training_epoch_loop.py | 5 
+++-- pytorch_lightning/loops/fit_loop.py | 2 ++ 4 files changed, 9 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 00f1d1783ddf4..933890c26d6c6 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -73,6 +73,8 @@ def predictions(self): def connect(self, epoch_loop: EvaluationEpochLoop): """Connect the evaluation epoch loop with this loop.""" self.epoch_loop = epoch_loop + if self.trainer is not None: + self.epoch_loop.trainer = self.trainer # def connect( # self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index d05450f974daa..4d19d62f536a9 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -82,6 +82,8 @@ def skip(self) -> bool: def connect(self, epoch_loop: PredictionEpochLoop): self.epoch_loop = epoch_loop + if self.trainer is not None: + self.epoch_loop.trainer = self.trainer # def connect( # self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 0feb824c26916..c987080647df7 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -79,8 +79,9 @@ def connect(self, batch_loop: TrainingBatchLoop, val_loop: "loops.EvaluationLoop """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" self.batch_loop = batch_loop self.val_loop = val_loop - # if self.trainer is not None: - # self.batch_loop.connect() + if self.trainer is not None: + self.batch_loop.trainer = self.trainer + self.val_loop.trainer = self.trainer def reset(self) -> None: """Resets the internal state of the loop for a new run""" diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 3b8dff653cb75..2af5cc73986d5 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -169,6 +169,8 @@ def skip(self) -> bool: def connect(self, epoch_loop: TrainingEpochLoop): """Connects a training epoch loop to this fit loop.""" self.epoch_loop = epoch_loop + if self.trainer is not None: + self.epoch_loop.trainer = self.trainer def reset(self) -> None: """Resets the internal state of this loop""" From d532faed72d1bc6b3a87a563ff786c75496ab366 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 8 Jul 2021 21:33:23 +0200 Subject: [PATCH 041/157] refine examples --- pl_examples/example1.py | 84 ------------------- pl_examples/loop_examples/__init__.py | 0 pl_examples/loop_examples/example1.py | 57 +++++++++++++ pl_examples/loop_examples/example2.py | 40 +++++++++ pl_examples/loop_examples/example3.py | 48 +++++++++++ .../simple_loop.py} | 83 ++---------------- 6 files changed, 153 insertions(+), 159 deletions(-) delete mode 100644 pl_examples/example1.py create mode 100644 pl_examples/loop_examples/__init__.py create mode 100644 pl_examples/loop_examples/example1.py create mode 100644 pl_examples/loop_examples/example2.py create mode 100644 pl_examples/loop_examples/example3.py rename pl_examples/{example2.py => loop_examples/simple_loop.py} 
(64%) diff --git a/pl_examples/example1.py b/pl_examples/example1.py deleted file mode 100644 index 714e160380711..0000000000000 --- a/pl_examples/example1.py +++ /dev/null @@ -1,84 +0,0 @@ -import os - -import torch -from torch.utils.data import DataLoader, Dataset - -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.loops import EvaluationLoop, FitLoop, TrainingBatchLoop, TrainingEpochLoop -from pytorch_lightning.trainer.progress import FitLoopProgress - - -class RandomDataset(Dataset): - - def __init__(self, size, length): - self.len = length - self.data = torch.randn(length, size) - - def __getitem__(self, index): - return self.data[index] - - def __len__(self): - return self.len - - -class BoringModel(LightningModule): - - def __init__(self): - super().__init__() - self.layer = torch.nn.Linear(32, 2) - - def forward(self, x): - return self.layer(x) - - def training_step(self, batch, batch_idx): - loss = self(batch).sum() - self.log("train_loss", loss) - return {"loss": loss} - - def validation_step(self, batch, batch_idx): - loss = self(batch).sum() - self.log("valid_loss", loss) - - def test_step(self, batch, batch_idx): - loss = self(batch).sum() - self.log("test_loss", loss) - - def configure_optimizers(self): - return torch.optim.SGD(self.layer.parameters(), lr=0.1) - - -def run(): - train_data = DataLoader(RandomDataset(32, 64), batch_size=2) - val_data = DataLoader(RandomDataset(32, 64), batch_size=2) - test_data = DataLoader(RandomDataset(32, 64), batch_size=2) - - model = BoringModel() - - trainer = Trainer( - default_root_dir=os.getcwd(), - limit_train_batches=1, - limit_val_batches=1, - num_sanity_val_steps=0, - max_epochs=1, - weights_summary=None, - ) - - # construct loops - fit_loop = FitLoop(max_epochs=2) - train_epoch_loop = TrainingEpochLoop(min_steps=0, max_steps=2) - train_batch_loop = TrainingBatchLoop() - val_loop = EvaluationLoop() - - # link loops - train_epoch_loop.connect(batch_loop=train_batch_loop, val_loop=val_loop) - fit_loop.connect(epoch_loop=train_epoch_loop) - - # connect fit loop to trainer - trainer.fit_loop = fit_loop - - trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) - trainer.test(model, dataloaders=test_data) - - -if __name__ == '__main__': - run() diff --git a/pl_examples/loop_examples/__init__.py b/pl_examples/loop_examples/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pl_examples/loop_examples/example1.py b/pl_examples/loop_examples/example1.py new file mode 100644 index 0000000000000..d86cdec74c0d1 --- /dev/null +++ b/pl_examples/loop_examples/example1.py @@ -0,0 +1,57 @@ +import os + +from torch.utils.data import DataLoader + +from pl_examples.bug_report_model import RandomDataset, BoringModel +from pytorch_lightning import Trainer +from pytorch_lightning.loops import EvaluationLoop, FitLoop, TrainingBatchLoop, TrainingEpochLoop + + +def run(): + """ + This example demonstrates how loops are linked together. 
+ Here we form a simple tree structure of three basic loops that make up the FitLoop: + + - Trainer + - fit_loop: FitLoop + - epoch_loop: TrainingEpochLoop + - batch_loop: TrainingBatchLoop + - val_loop: EvaluationLoop + """ + train_data = DataLoader(RandomDataset(32, 64), batch_size=2) + val_data = DataLoader(RandomDataset(32, 64), batch_size=2) + test_data = DataLoader(RandomDataset(32, 64), batch_size=2) + + model = BoringModel() + + trainer = Trainer( + default_root_dir=os.getcwd(), + limit_train_batches=1, + limit_val_batches=1, + num_sanity_val_steps=0, + max_epochs=1, + weights_summary=None, + ) + + # construct loops + fit_loop = FitLoop(max_epochs=2) + train_epoch_loop = TrainingEpochLoop(min_steps=0, max_steps=2) + train_batch_loop = TrainingBatchLoop() + val_loop = EvaluationLoop() + + # connect loops together + train_epoch_loop.connect(batch_loop=train_batch_loop, val_loop=val_loop) + fit_loop.connect(epoch_loop=train_epoch_loop) + + # connect fit loop to trainer (main entry point for the call in trainer.fit()) + trainer.fit_loop = fit_loop + + # this will use the newly constructed loop! + trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) + + # this will still use the default test loop + trainer.test(model, dataloaders=test_data) + + +if __name__ == '__main__': + run() diff --git a/pl_examples/loop_examples/example2.py b/pl_examples/loop_examples/example2.py new file mode 100644 index 0000000000000..4e38c1c830544 --- /dev/null +++ b/pl_examples/loop_examples/example2.py @@ -0,0 +1,40 @@ +import os +from torch.utils.data import DataLoader + +from pl_examples.bug_report_model import RandomDataset +from pl_examples.loop_examples.example1 import BoringModel +from pl_examples.loop_examples.simple_loop import SimpleLoop +from pytorch_lightning import Trainer + + +def run(): + """ + This example shows how to replace the FitLoop on the Trainer with a very simple, custom iteration-based + training loop. + """ + train_data = DataLoader(RandomDataset(32, 64), batch_size=2) + + model = BoringModel() + trainer = Trainer( + default_root_dir=os.getcwd(), + limit_train_batches=1, + limit_val_batches=1, + num_sanity_val_steps=0, + max_epochs=1, + weights_summary=None, + progress_bar_refresh_rate=1, + ) + + # instantiate the new loop + simple_loop = SimpleLoop(num_iterations=1000) + + # replace the fit loop + # the trainer reference will be set internally + trainer.fit_loop = simple_loop + + # fit using the new loop! + trainer.fit(model, train_dataloader=train_data) + + +if __name__ == '__main__': + run() diff --git a/pl_examples/loop_examples/example3.py b/pl_examples/loop_examples/example3.py new file mode 100644 index 0000000000000..e13490f414ebb --- /dev/null +++ b/pl_examples/loop_examples/example3.py @@ -0,0 +1,48 @@ +import os +from torch.utils.data import DataLoader + +from pl_examples.bug_report_model import RandomDataset, BoringModel +from pytorch_lightning import Trainer +from pytorch_lightning.loops import EvaluationLoop, TrainingBatchLoop + + +def run(): + """ + This example shows how to switch out an individual loop. + Here, we want to take the default FitLoop from Lightning but switch out + + 1. the batch_loop inside the training epoch loop + 2. 
the val_loop inside the training epoch loop
+
+    """
+    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
+    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
+
+    model = BoringModel()
+
+    trainer = Trainer(
+        default_root_dir=os.getcwd(),
+        limit_train_batches=1,
+        limit_val_batches=1,
+        num_sanity_val_steps=0,
+        max_epochs=1,
+        weights_summary=None,
+    )
+
+    # instantiate the new batch- and validation loop
+    new_batch_loop = TrainingBatchLoop()
+    new_val_loop = EvaluationLoop()
+
+    # call connect on the existing, default fit_loop.epoch_loop
+    trainer.fit_loop.epoch_loop.connect(batch_loop=new_batch_loop, val_loop=new_val_loop)
+
+    # the new batch loop is registered and the trainer got linked internally
+    assert trainer.fit_loop.epoch_loop.batch_loop == new_batch_loop
+    assert trainer.fit_loop.epoch_loop.batch_loop.trainer == trainer
+
+    # this uses the new custom batch loop
+    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
+
+
+if __name__ == '__main__':
+    run()
diff --git a/pl_examples/example2.py b/pl_examples/loop_examples/simple_loop.py
similarity index 64%
rename from pl_examples/example2.py
rename to pl_examples/loop_examples/simple_loop.py
index 498c11700cfe2..f02eca00d88b7 100644
--- a/pl_examples/example2.py
+++ b/pl_examples/loop_examples/simple_loop.py
@@ -1,19 +1,22 @@
-import os
 from collections import OrderedDict
 from typing import Any, Dict, Iterator, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader, Dataset
-
-from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.loops import Loop
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 
 
 class SimpleLoop(Loop):
-    """This loop is for demonstration purposes only."""
+    """
+    This loop is for demonstration purposes only.
+    It implements a purely iteration-based loop with a bare minimum of functionality.
+ - 1 optimizer + - no logging + - no grad accumulation + - no epoch hooks calling + """ def __init__(self, num_iterations: int = float("inf")): super().__init__() @@ -110,73 +113,3 @@ def _process_training_step_output(training_step_output: Union[Dict, Tensor]) -> loss = training_step_output return loss, extra - - -class RandomDataset(Dataset): - - def __init__(self, size, length): - self.len = length - self.data = torch.randn(length, size) - - def __getitem__(self, index): - return self.data[index] - - def __len__(self): - return self.len - - -class BoringModel(LightningModule): - - def __init__(self): - super().__init__() - self.layer = torch.nn.Linear(32, 2) - - def forward(self, x): - return self.layer(x) - - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): - self.print("batch start:", batch_idx) - - def training_step(self, batch, batch_idx): - self.print("training_step:", batch_idx) - loss = self(batch).sum() - self.log("train_loss", loss) - return {"loss": loss} - - def backward(self, loss, *args, **kwargs): - self.print("backward:", loss) - return super().backward(loss, *args, **kwargs) - - def optimizer_step(self, *args, **kwargs): - self.print("optimizer_step") - return super().optimizer_step(*args, **kwargs) - - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.print("batch end:", batch_idx) - - def configure_optimizers(self): - return torch.optim.SGD(self.layer.parameters(), lr=0.1) - - -def run(): - train_data = DataLoader(RandomDataset(32, 64), batch_size=2) - - model = BoringModel() - trainer = Trainer( - default_root_dir=os.getcwd(), - limit_train_batches=1, - limit_val_batches=1, - num_sanity_val_steps=0, - max_epochs=1, - weights_summary=None, - progress_bar_refresh_rate=1, - ) - - simple_loop = SimpleLoop(num_iterations=1000) - trainer.fit_loop = simple_loop - - trainer.fit(model, train_dataloader=train_data) - - -if __name__ == '__main__': - run() From ec1d96055e364ac6d7101951dd63b576538577a4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jul 2021 19:34:35 +0000 Subject: [PATCH 042/157] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pl_examples/loop_examples/example1.py | 2 +- pl_examples/loop_examples/example2.py | 1 + pl_examples/loop_examples/example3.py | 3 ++- pl_examples/loop_examples/simple_loop.py | 1 + 4 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pl_examples/loop_examples/example1.py b/pl_examples/loop_examples/example1.py index d86cdec74c0d1..6da59586dad2a 100644 --- a/pl_examples/loop_examples/example1.py +++ b/pl_examples/loop_examples/example1.py @@ -2,7 +2,7 @@ from torch.utils.data import DataLoader -from pl_examples.bug_report_model import RandomDataset, BoringModel +from pl_examples.bug_report_model import BoringModel, RandomDataset from pytorch_lightning import Trainer from pytorch_lightning.loops import EvaluationLoop, FitLoop, TrainingBatchLoop, TrainingEpochLoop diff --git a/pl_examples/loop_examples/example2.py b/pl_examples/loop_examples/example2.py index 4e38c1c830544..4a347a1777909 100644 --- a/pl_examples/loop_examples/example2.py +++ b/pl_examples/loop_examples/example2.py @@ -1,4 +1,5 @@ import os + from torch.utils.data import DataLoader from pl_examples.bug_report_model import RandomDataset diff --git a/pl_examples/loop_examples/example3.py b/pl_examples/loop_examples/example3.py index e13490f414ebb..317293c362746 100644 --- 
a/pl_examples/loop_examples/example3.py +++ b/pl_examples/loop_examples/example3.py @@ -1,7 +1,8 @@ import os + from torch.utils.data import DataLoader -from pl_examples.bug_report_model import RandomDataset, BoringModel +from pl_examples.bug_report_model import BoringModel, RandomDataset from pytorch_lightning import Trainer from pytorch_lightning.loops import EvaluationLoop, TrainingBatchLoop diff --git a/pl_examples/loop_examples/simple_loop.py b/pl_examples/loop_examples/simple_loop.py index f02eca00d88b7..84ea0d6c3821c 100644 --- a/pl_examples/loop_examples/simple_loop.py +++ b/pl_examples/loop_examples/simple_loop.py @@ -4,6 +4,7 @@ import torch from torch import Tensor from torch.optim import Optimizer + from pytorch_lightning.loops import Loop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection From 7f8000f0bc453761d9e4c6fa092853b4f696dd34 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 10 Jul 2021 11:59:30 +0200 Subject: [PATCH 043/157] resolve comments --- pytorch_lightning/loops/base.py | 187 +++++++++++++------------------- tests/loops/test_loops.py | 88 +++++---------- 2 files changed, 106 insertions(+), 169 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 3f98cf4214523..2405ec49ed704 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -17,11 +17,11 @@ from typing import Any, Dict, Optional from deprecate import void +from torch.nn.modules.module import _IncompatibleKeys import pytorch_lightning as pl from pytorch_lightning.trainer.progress import BaseProgress, ProgressDict from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.warnings import WarningCache warning_cache = WarningCache() @@ -51,89 +51,35 @@ class Loop(ABC): def __init__(self) -> None: self.iteration_count: int = 0 - self.trainer: Optional['pl.Trainer'] = None + self._trainer: Optional['pl.Trainer'] = None self.restarting = False - self._loops = OrderedDict() - self._progress = OrderedDict() - self._has_parent: bool = False - self.__parent_loop: Optional['Loop'] = None - - @property - def has_parent(self) -> Optional[bool]: - """Whether this loop has been attached to another loop""" - return self._has_parent - - @property - def has_children(self) -> bool: - """Whether this loop has any children""" - loops = self.__dict__.get('_loops') - return len(loops) > 0 - - @property - def is_leaf(self) -> bool: - """This loop is a leaf if it doesn't possess any loops.""" - return not self.has_children @property def loop_progress(self) -> Dict[str, Any]: """Return the progress for the current loop and children loop.""" progress = {} - for n, p in self.__dict__.get('_progress').items(): - progress[n] = p - - loops = self.__dict__.get('_loops') - - if loops is not None: - for name, loop in loops.items(): - progress[name] = ProgressDict(**loop.loop_progress) + for k, v in self.__dict__.items(): + if isinstance(v, BaseProgress): + progress[k] = v + elif isinstance(v, Loop): + progress[k] = ProgressDict(**v.loop_progress) return ProgressDict(**progress) - def __setattr__(self, name: str, value: Any) -> None: - if isinstance(value, pl.Trainer): - # when assigning a Trainer to a loop, it will assign to its children too. 
- object.__setattr__(self, name, value) - for loop in self._loops.values(): - loop.__setattr__(name, value) - elif isinstance(value, Loop): - if name == "_Loop__parent_loop": - object.__setattr__(self, name, value) - return - if getattr(self, "__children__loops__", None) is not None and name not in self.__children__loops__: - raise MisconfigurationException( - f"The current loop accept only {self.__children__loops__} as children attribute names. Found {name}" - ) - if value._has_parent: - raise MisconfigurationException(f"This provided loop {value} already has a parent. ") - self._loops[name] = value - value._has_parent = True - value.__parent_loop = self - elif isinstance(value, BaseProgress): - self._progress[name] = value - else: - object.__setattr__(self, name, value) - - def __getattr__(self, name) -> Any: - loops = self.__dict__.get('_loops') - - if loops is not None and name in loops: - return loops[name] - - progress = self.__dict__.get('_progress') - - if progress is not None and name in progress: - return progress[name] - - return object.__getattribute__(self, name) - - def __delattr__(self, name) -> None: - if name in self._loops: - self._loops[name]._has_parent = False - self._loops[name]._Loop__parent_loop = None - del self._loops[name] - elif name in self._progress: - del self._progress[name] - else: - object.__delattr__(self, name) + @property + def trainer(self) -> Optional['pl.Trainer']: + return self._trainer + + @trainer.setter + def trainer(self, trainer: 'pl.Trainer'): + """Connect the Trainer to itself and all sub-children loops""" + if not isinstance(trainer, pl.Trainer): + raise MisconfigurationException( + f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}." + ) + self._trainer = trainer + for v in self.__dict__.values(): + if isinstance(v, Loop): + v.trainer = trainer @property @abstractmethod @@ -148,10 +94,6 @@ def skip(self) -> bool: def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: """Connects Loop with all the necessary things like connectors and accelerators.""" # TODO(@justusschock): Make the trainer a weakref/proxy - if not isinstance(trainer, pl.Trainer): - raise MisconfigurationException( - f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}." - ) self.trainer = trainer def on_skip(self) -> Optional[Any]: @@ -178,13 +120,7 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: if self.skip: return self.on_skip() - if self.restarting: - if not is_overridden("restore", self, Loop): - warning_cache.warn(f"{self.__class__.__name__} Loop doesn't override the restore function.") - self.restore() - self.restarting = False - else: - self.reset() + self.reset() self.on_run_start(*args, **kwargs) @@ -200,9 +136,6 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: output = self.on_run_end() return output - def restore(self, state: Optional[Dict] = None) -> None: - """Restore the internal state of the loop the beginning of run if restarting is ``True``.""" - @abstractmethod def reset(self) -> None: """Resets the internal state of the loop at the beginning of each call to :attr:`run`.""" @@ -234,33 +167,46 @@ def on_run_end(self) -> Any: def teardown(self) -> None: """Use to release memory etc.""" - def state_dict(self) -> Dict: - """Current Loop state""" + def on_save_checkpoint(self) -> Dict: + """ + Called when saving a model checkpoint, use to persist loop state. + + Returns: + The current loop state. 
+ """ return {} - def load_state_dict(self, state_dict: Dict) -> None: - """Reload Loop state""" + def on_load_checkpoint(self, state_dict: Dict): + """Called when loading a model checkpoint, use to reload loop state.""" - def get_state_dict(self, destination: Optional[OrderedDict] = None, prefix: Optional[str] = '') -> OrderedDict: + def state_dict(self, destination: Optional[OrderedDict] = None, prefix: Optional[str] = '') -> Dict: if destination is None: destination = OrderedDict() - destination[prefix + "state_dict"] = self.state_dict() + destination[prefix + "state_dict"] = self.on_save_checkpoint() - for name, progress in self._progress.items(): - destination[prefix + name] = progress.state_dict() + for k, v in self.__dict__.items(): + if isinstance(v, BaseProgress): + destination[prefix + k] = v.state_dict() + elif isinstance(v, Loop): + v.state_dict(destination, prefix + k + '.') - for name, loop in self._loops.items(): - loop.get_state_dict(destination, prefix + name + '.') return destination - def _load_from_state_dict(self, state_dict, prefix, strict, missing_keys, unexpected_keys, error_msgs): - self.load_state_dict(state_dict[prefix + "state_dict"]) + def _load_from_state_dict( + self, state_dict, prefix, strict, restart_progress, missing_keys, unexpected_keys, error_msgs + ): + print(state_dict, prefix) + for k, v in self.__dict__.items(): + if isinstance(v, BaseProgress): + v.load_state_dict(state_dict[prefix + k]) - for name, progress in self._progress.items(): - progress.load_state_dict(state_dict[prefix + name]) + self.on_load_checkpoint(state_dict[prefix + "state_dict"]) - def _load_state_dict(self, state_dict: Dict, strict: bool = True): + def load_state_dict(self, state_dict: Dict, restart_progress: bool = True, strict: bool = True): + """ + This function is highly inspired from ``PyTorch nn.Module``. + """ missing_keys = [] unexpected_keys = [] @@ -269,11 +215,32 @@ def _load_state_dict(self, state_dict: Dict, strict: bool = True): state_dict = state_dict.copy() def load(loop, prefix=''): - loop._load_from_state_dict(state_dict, prefix, True, missing_keys, unexpected_keys, error_msgs) + if loop.restarting: + return + loop._load_from_state_dict( + state_dict, prefix, True, restart_progress, missing_keys, unexpected_keys, error_msgs + ) loop.restarting = True - for name, loop_children in loop._loops.items(): - if loop_children is not None: - load(loop_children, prefix + name + '.') + for k, v in self.__dict__.items(): + if isinstance(v, Loop): + load(v, prefix + k + '.') load(self) - load = None # break load->load reference cycle + + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in unexpected_keys) + ) + ) + if len(missing_keys) > 0: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. 
'.format(', '.join('"{}"'.format(k) for k in missing_keys)) + ) + + if len(error_msgs) > 0: + raise RuntimeError( + 'Error(s) in loading state_dict for {}:\n\t{}'.format(self.__class__.__name__, "\n\t".join(error_msgs)) + ) + return _IncompatibleKeys(missing_keys, unexpected_keys) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 5b1b12ebdf054..5ce182c59a476 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -16,12 +16,9 @@ from dataclasses import dataclass from typing import Any, Dict, Iterator -import pytest - from pytorch_lightning.loops.base import Loop from pytorch_lightning.trainer.progress import BaseProgress, ProgressDict from pytorch_lightning.trainer.trainer import Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException def test_loop_restore(): @@ -39,19 +36,20 @@ def __init__(self, dataset: Iterator): def skip(self) -> bool: return False - def restore(self) -> None: - self.iter_dataset = iter(self.dataset) - for _ in range(self.iteration_count): - next(self.iter_dataset) - self.iteration_count += 1 - @property def done(self) -> bool: return self.iteration_count > len(self.dataset) def reset(self) -> None: self.iter_dataset = iter(self.dataset) - self.outputs = [] + + if self.restarting: + for _ in range(self.iteration_count): + next(self.iter_dataset) + self.iteration_count += 1 + self.restarting = False + else: + self.outputs = [] def advance(self) -> None: value = next(self.iter_dataset) @@ -104,17 +102,16 @@ def load_state_dict(self, state_dict): class Simple(Loop): - __children__loops__ = ("loop_child", "something") - def __init__(self, a): super().__init__() self.a = a self.progress = SimpleProgress() def advance(self, *args: Any, **kwargs: Any) -> None: - for loop in self._loops.values(): - loop.run() - self.progress.increment += 1 + loop = getattr(self, "loop_child", None) + if not loop: + return + loop.run() self.progress.increment += 1 @property @@ -126,49 +123,23 @@ def done(self) -> bool: return self.iteration_count > 0 def reset(self) -> None: - pass + self.restarting = False - def restore(self) -> None: - pass - - def state_dict(self) -> Dict: + def on_save_checkpoint(self) -> Dict: return {"a": self.a} - def load_state_dict(self, state_dict: Dict) -> None: + def on_load_checkpoint(self, state_dict: Dict) -> None: self.a = state_dict["a"] grand_loop_parent = Simple(0) loop_parent = Simple(1) loop_child = Simple(2) - assert not loop_child.has_parent loop_parent.loop_child = loop_child - assert loop_child._Loop__parent_loop == loop_parent - - assert loop_child.has_parent - - with pytest.raises(MisconfigurationException, match="already has a parent"): - loop_parent.loop_child = loop_child - assert not loop_parent.skip - with pytest.raises(MisconfigurationException, match="already has a parent"): - loop_parent.something = loop_child - - with pytest.raises(MisconfigurationException, match="Loop hasn't been attached to any Trainer."): - grand_loop_parent.run() - - with pytest.raises(MisconfigurationException, match="already has a parent"): - grand_loop_parent.loop_child = loop_child - - assert loop_child.has_parent - assert loop_parent.has_children - - state_dict = loop_parent.get_state_dict() - - with pytest.raises(MisconfigurationException, match="The current loop accept only"): - loop_parent.wrong_name = loop_child + state_dict = loop_parent.state_dict() loop_progress: ProgressDict = loop_parent.loop_progress assert loop_progress["progress"] == loop_parent.progress @@ -197,42 +168,41 @@ def 
load_state_dict(self, state_dict: Dict) -> None: loop_parent.progress state_dict["loop_child.state_dict"]["a"] = 3 - loop_parent._load_state_dict(state_dict) + + loop_parent.load_state_dict(state_dict) assert loop_parent.restarting loop_parent.run() loop_parent_copy = deepcopy(loop_parent) - assert loop_parent_copy.get_state_dict() == loop_parent.get_state_dict() + assert loop_parent_copy.state_dict() == loop_parent.state_dict() - assert loop_parent_copy.state_dict() == {'a': 1} - assert loop_parent_copy.loop_child.state_dict() == {'a': 3} + assert loop_parent_copy.on_save_checkpoint() == {'a': 1} + assert loop_parent_copy.loop_child.on_save_checkpoint() == {'a': 3} assert not loop_parent.restarting - state_dict = loop_parent.get_state_dict() + state_dict = loop_parent.state_dict() assert state_dict == OrderedDict([('state_dict', { 'a': 1 }), ('progress', { - 'increment': 2 + 'increment': 1 }), ('loop_child.state_dict', { 'a': 3 }), ('loop_child.progress', { - 'increment': 1 + 'increment': 0 })]) loop_parent = Simple(1) loop_child = Simple(2) loop_parent.loop_child = loop_child - loop_parent._load_state_dict(state_dict) - assert loop_parent.progress.increment == 2 - assert loop_parent.loop_child.progress.increment == 1 + loop_parent.load_state_dict(state_dict) + assert loop_parent.progress.increment == 1 + assert loop_parent.loop_child.progress.increment == 0 del loop_parent.loop_child - assert not loop_child.has_parent - assert loop_child._Loop__parent_loop is None - state_dict = loop_parent.get_state_dict() - assert state_dict == OrderedDict([('state_dict', {'a': 1}), ('progress', {'increment': 2})]) + state_dict = loop_parent.state_dict() + assert state_dict == OrderedDict([('state_dict', {'a': 1}), ('progress', {'increment': 1})]) grand_loop_parent = Simple(0) loop_parent = Simple(1) From 4153b8106a46d26261bebc7558e5d3b7cb1cfc47 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 10 Jul 2021 12:07:07 +0200 Subject: [PATCH 044/157] update --- pytorch_lightning/loops/base.py | 7 +++---- pytorch_lightning/trainer/progress.py | 11 +---------- tests/loops/test_loops.py | 8 ++------ 3 files changed, 6 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 2405ec49ed704..629e443b788a7 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -20,7 +20,7 @@ from torch.nn.modules.module import _IncompatibleKeys import pytorch_lightning as pl -from pytorch_lightning.trainer.progress import BaseProgress, ProgressDict +from pytorch_lightning.trainer.progress import BaseProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import WarningCache @@ -62,8 +62,8 @@ def loop_progress(self) -> Dict[str, Any]: if isinstance(v, BaseProgress): progress[k] = v elif isinstance(v, Loop): - progress[k] = ProgressDict(**v.loop_progress) - return ProgressDict(**progress) + progress[k] = v.loop_progress + return progress @property def trainer(self) -> Optional['pl.Trainer']: @@ -196,7 +196,6 @@ def state_dict(self, destination: Optional[OrderedDict] = None, prefix: Optional def _load_from_state_dict( self, state_dict, prefix, strict, restart_progress, missing_keys, unexpected_keys, error_msgs ): - print(state_dict, prefix) for k, v in self.__dict__.items(): if isinstance(v, BaseProgress): v.load_state_dict(state_dict[prefix + k]) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 54b85273d9c0a..3acae2485cea0 
100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -12,16 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import asdict, dataclass, field -from typing import Dict, Optional - - -class ProgressDict(Dict): - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - for k, v in kwargs.items(): - setattr(self, k, v) +from typing import Optional @dataclass diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 5ce182c59a476..03418d5e70430 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -17,7 +17,7 @@ from typing import Any, Dict, Iterator from pytorch_lightning.loops.base import Loop -from pytorch_lightning.trainer.progress import BaseProgress, ProgressDict +from pytorch_lightning.trainer.progress import BaseProgress from pytorch_lightning.trainer.trainer import Trainer @@ -141,16 +141,12 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: state_dict = loop_parent.state_dict() - loop_progress: ProgressDict = loop_parent.loop_progress + loop_progress = loop_parent.loop_progress assert loop_progress["progress"] == loop_parent.progress assert loop_progress["loop_child"]["progress"] == loop_child.progress - assert loop_progress.progress == loop_parent.progress - assert loop_progress.loop_child.progress == loop_child.progress - loop_progress = loop_child.loop_progress assert loop_progress["progress"] == loop_child.progress - assert loop_progress.progress == loop_child.progress loop_parent.trainer = Trainer() assert loop_child.trainer == loop_parent.trainer From d6280e0677aa3e975ed6a8ae97caf2742a92a08f Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 10 Jul 2021 12:08:16 +0200 Subject: [PATCH 045/157] update --- pytorch_lightning/loops/base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 629e443b788a7..709556356701c 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -22,9 +22,6 @@ import pytorch_lightning as pl from pytorch_lightning.trainer.progress import BaseProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.warnings import WarningCache - -warning_cache = WarningCache() class Loop(ABC): From c499c241335a249314cb43b3cb9e79c42ac6a94b Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 10 Jul 2021 12:10:01 +0200 Subject: [PATCH 046/157] update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 002e01098ab16..42dbdb2f8d277 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -342,6 +342,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed deprecated `optimizer` argument in `LightningModule.manual_backward()`; Toggling optimizers in manual optimization should be done using `LightningModule.{un}toggle_optimizer()` ([#8287](https://github.com/PyTorchLightning/pytorch-lightning/pull/8287)) +- Removed `Loop restore` function to give more control for loop restart ([#8334](https://github.com/PyTorchLightning/pytorch-lightning/pull/8334)) + + ### Fixed - Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877)) @@ -401,6 +404,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed missing call to `LightningModule.untoggle_optimizer` in training loop when running gradient accumulation with multiple optimizers ([#8284](https://github.com/PyTorchLightning/pytorch-lightning/pull/8284))
 
+
 ## [1.3.8] - 2021-07-01
 
 ### Fixed
 

From 3cb6df2a4cc87386dccc4855de1d546d0c067d6d Mon Sep 17 00:00:00 2001
From: tchaton
Date: Sat, 10 Jul 2021 12:45:09 +0200
Subject: [PATCH 047/157] update

---
 pytorch_lightning/loops/base.py     |   5 +-
 tests/loops/test_loop_state_dict.py | 206 +++++++++++++++++++++++++++-
 2 files changed, 202 insertions(+), 9 deletions(-)

diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py
index 709556356701c..731192858a308 100644
--- a/pytorch_lightning/loops/base.py
+++ b/pytorch_lightning/loops/base.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from collections import OrderedDict
 from typing import Any, Dict, Optional
 
 from deprecate import void
@@ -176,9 +175,9 @@ def on_save_checkpoint(self) -> Dict:
     def on_load_checkpoint(self, state_dict: Dict):
         """Called when loading a model checkpoint, use to reload loop state."""
 
-    def state_dict(self, destination: Optional[OrderedDict] = None, prefix: Optional[str] = '') -> Dict:
+    def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = '') -> Dict:
         if destination is None:
-            destination = OrderedDict()
+            destination = {}
 
         destination[prefix + "state_dict"] = self.on_save_checkpoint()
 
diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py
index 1930dc46566fd..eed23a89a8b36 100644
--- a/tests/loops/test_loop_state_dict.py
+++ b/tests/loops/test_loop_state_dict.py
@@ -40,15 +40,209 @@ def test_loops_state_dict_structure():
         "test_loop": trainer.test_loop.state_dict(),
         "predict_loop": trainer.predict_loop.state_dict(),
     }
+    # todo (tchaton) Update this once new progress has been added.
+ # yapf: disable expected = { "fit_loop": { - 'epoch_loop': { - 'batch_loop': {}, - 'val_loop': {}, + "epoch_loop": { + "batch_loop": { + "state_dict": {}, + "progress": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": { + "ready": 0, + "started": 0, + "processed": 0, + "completed": 0, + }, + }, + "optim_progress": { + "optimizer": { + "step": { + "total": { + "ready": 0, + "started": 0, + "processed": None, + "completed": 0, + }, + "current": { + "ready": 0, + "started": 0, + "processed": None, + "completed": 0, + }, + }, + "zero_grad": { + "total": { + "ready": 0, + "started": 0, + "processed": None, + "completed": 0, + }, + "current": { + "ready": 0, + "started": 0, + "processed": None, + "completed": 0, + }, + }, + }, + "scheduler": { + "total": { + "ready": 0, + "started": None, + "processed": None, + "completed": 0, + }, + "current": { + "ready": 0, + "started": None, + "processed": None, + "completed": 0, + }, + }, + }, + }, + "val_loop": { + "state_dict": {}, + "progress": { + "epoch": { + "total": { + "ready": 0, + "started": 0, + "processed": 0, + "completed": 0, + }, + "current": { + "ready": 0, + "started": 0, + "processed": 0, + "completed": 0, + }, + "batch": { + "total": { + "ready": 0, + "started": 0, + "processed": 0, + "completed": 0, + }, + "current": { + "ready": 0, + "started": 0, + "processed": 0, + "completed": 0, + }, + }, + } + }, + "epoch_loop.state_dict": {}, + "epoch_loop.progress": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": { + "ready": 0, + "started": 0, + "processed": 0, + "completed": 0, + }, + "batch": { + "total": { + "ready": 0, + "started": 0, + "processed": 0, + "completed": 0, + }, + "current": { + "ready": 0, + "started": 0, + "processed": 0, + "completed": 0, + }, + }, + }, + }, } }, - "validate_loop": {}, - "test_loop": {}, - "predict_loop": {}, + "validate_loop": { + "state_dict": {}, + "progress": { + "epoch": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "batch": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": { + "ready": 0, + "started": 0, + "processed": 0, + "completed": 0, + }, + }, + } + }, + "epoch_loop.state_dict": {}, + "epoch_loop.progress": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "batch": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, + }, + }, + "test_loop": { + "state_dict": {}, + "progress": { + "epoch": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "batch": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": { + "ready": 0, + "started": 0, + "processed": 0, + "completed": 0, + }, + }, + } + }, + "epoch_loop.state_dict": {}, + "epoch_loop.progress": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "batch": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, + }, + }, + "predict_loop": { + "state_dict": {}, + "progress": { + "epoch": { + "total": {"ready": 0, "started": 0, 
"processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "batch": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": { + "ready": 0, + "started": 0, + "processed": 0, + "completed": 0, + }, + }, + } + }, + "epoch_loop.state_dict": {}, + "epoch_loop.progress": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "batch": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, + }, + }, } + # yapf: enable assert state_dict == expected From e8c12e95d64447919a17cb4e246b4ce35a7ab7eb Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 10 Jul 2021 13:19:05 +0200 Subject: [PATCH 048/157] update --- pytorch_lightning/loops/base.py | 60 +++++++-------------------- pytorch_lightning/trainer/progress.py | 13 ++++++ 2 files changed, 29 insertions(+), 44 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 731192858a308..0f718ff2ce1be 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -16,10 +16,10 @@ from typing import Any, Dict, Optional from deprecate import void -from torch.nn.modules.module import _IncompatibleKeys import pytorch_lightning as pl -from pytorch_lightning.trainer.progress import BaseProgress +from pytorch_lightning.trainer.progress import BaseProgress, Tracker +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -189,53 +189,25 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = return destination - def _load_from_state_dict( - self, state_dict, prefix, strict, restart_progress, missing_keys, unexpected_keys, error_msgs - ): + def _load_from_state_dict(self, state_dict, prefix, restart_progress): for k, v in self.__dict__.items(): if isinstance(v, BaseProgress): v.load_state_dict(state_dict[prefix + k]) + if restart_progress: - self.on_load_checkpoint(state_dict[prefix + "state_dict"]) + def restart(v: Tracker): + v.reset_on_restart() - def load_state_dict(self, state_dict: Dict, restart_progress: bool = True, strict: bool = True): - """ - This function is highly inspired from ``PyTorch nn.Module``. - """ + apply_to_collection(v, Tracker, restart) - missing_keys = [] - unexpected_keys = [] - error_msgs = [] + self.on_load_checkpoint(state_dict[prefix + "state_dict"]) + self.restarting = True - state_dict = state_dict.copy() + def __load(self, state_dict, restart_progress, prefix=''): + self._load_from_state_dict(state_dict, prefix, restart_progress) + for k, v in self.__dict__.items(): + if isinstance(v, Loop): + v.__load(state_dict.copy(), restart_progress, prefix + k + '.') - def load(loop, prefix=''): - if loop.restarting: - return - loop._load_from_state_dict( - state_dict, prefix, True, restart_progress, missing_keys, unexpected_keys, error_msgs - ) - loop.restarting = True - for k, v in self.__dict__.items(): - if isinstance(v, Loop): - load(v, prefix + k + '.') - - load(self) - - if strict: - if len(unexpected_keys) > 0: - error_msgs.insert( - 0, 'Unexpected key(s) in state_dict: {}. '.format( - ', '.join('"{}"'.format(k) for k in unexpected_keys) - ) - ) - if len(missing_keys) > 0: - error_msgs.insert( - 0, 'Missing key(s) in state_dict: {}. 
'.format(', '.join('"{}"'.format(k) for k in missing_keys)) - ) - - if len(error_msgs) > 0: - raise RuntimeError( - 'Error(s) in loading state_dict for {}:\n\t{}'.format(self.__class__.__name__, "\n\t".join(error_msgs)) - ) - return _IncompatibleKeys(missing_keys, unexpected_keys) + def load_state_dict(self, state_dict: Dict, restart_progress: bool = True): + self.__load(state_dict.copy(), restart_progress) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 3acae2485cea0..1098957033855 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -70,6 +70,19 @@ def __repr__(self): args = [f"{k}={v}" for k, v in self.__dict__.items() if v is not None] return f"{self.__class__.__name__}({', '.join(args)})" + def reset_on_restart(self): + """Reset the progress on restart""" + value = self.completed if self.processed is None else self.processed + + if self.ready is not None: + self.ready = value + if self.started is not None: + self.started = value + if self.processed is not None: + self.processed = value + if self.completed is not None: + self.completed = value + @dataclass class Progress(BaseProgress): From df4b1ba3b5b8be6ef8efc528c61cda9524439f82 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 10 Jul 2021 13:19:54 +0200 Subject: [PATCH 049/157] remove space --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 42dbdb2f8d277..14db839d0d6e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -404,7 +404,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed missing call to `LightningModule.untoggle_optimizer` in training loop when running gradient accumulation with multiple optimizers ([#8284](https://github.com/PyTorchLightning/pytorch-lightning/pull/8284)) - ## [1.3.8] - 2021-07-01 ### Fixed From ee8d9b80368eacd48d1f10ec99df635efdb9ea44 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 10 Jul 2021 13:20:57 +0200 Subject: [PATCH 050/157] update --- pytorch_lightning/loops/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 0f718ff2ce1be..98248a5e631de 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -67,7 +67,7 @@ def trainer(self) -> Optional['pl.Trainer']: @trainer.setter def trainer(self, trainer: 'pl.Trainer'): - """Connect the Trainer to itself and all sub-children loops""" + """Connect the Trainer to itself and all its children loops""" if not isinstance(trainer, pl.Trainer): raise MisconfigurationException( f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}." 
From 65540a88624106bb070b6353913668edc3331ad2 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 10 Jul 2021 14:14:42 +0200 Subject: [PATCH 051/157] add progress tracking to loops --- pytorch_lightning/core/optimizer.py | 1 + pytorch_lightning/loops/base.py | 1 - .../loops/batch/training_batch_loop.py | 58 +++-- .../loops/dataloader/evaluation_loop.py | 33 ++- .../loops/dataloader/prediction_loop.py | 2 - .../loops/epoch/evaluation_epoch_loop.py | 25 +- .../loops/epoch/training_epoch_loop.py | 43 ++-- pytorch_lightning/loops/fit_loop.py | 19 +- .../connectors/checkpoint_connector.py | 26 ++ .../trainer/connectors/optimizer_connector.py | 7 + pytorch_lightning/trainer/progress.py | 77 +++--- pytorch_lightning/trainer/trainer.py | 9 +- pytorch_lightning/utilities/imports.py | 5 + tests/trainer/test_progress.py | 241 +++++++++++++----- 14 files changed, 356 insertions(+), 191 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 25e4519eb39fc..33b44a35d31f1 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -207,6 +207,7 @@ def closure_dis(): profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}" self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs) + self._trainer.fit_loop.epoch_loop.batch_loop.optim_progress.optimizer.step.increment_processed() self._total_optimizer_step_calls += 1 def __repr__(self): diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 98248a5e631de..0b5df30003ba8 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -186,7 +186,6 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = destination[prefix + k] = v.state_dict() elif isinstance(v, Loop): v.state_dict(destination, prefix + k + '.') - return destination def _load_from_state_dict(self, state_dict, prefix, restart_progress): diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 41ad9280ffaf7..89ea775977817 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -28,7 +28,7 @@ from pytorch_lightning.loops.base import Loop from pytorch_lightning.plugins import ParallelPlugin from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import BatchProgress, OptimizationProgress +from pytorch_lightning.trainer.progress import OptimizationProgress, Progress from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -50,29 +50,17 @@ def __init__(self) -> None: self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20) self.batch_idx: int = 0 self.split_idx: Optional[int] = None - self.progress = BatchProgress() + self.progress = Progress() self.optim_progress = OptimizationProgress() - self._warning_cache: WarningCache = WarningCache() self._hiddens: Optional[Tensor] = None self._optimizer_freq_cumsum: Optional[int] = None self._remaining_splits: Optional[List[Any]] = None self._skip_backward: bool = False - def connect( - self, - trainer: 'pl.Trainer', - *args: Any, - progress: Optional[BatchProgress] = None, - optim_progress: Optional[OptimizationProgress] = None, - **kwargs: Any - ) -> None: + 
def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - if optim_progress is not None: - self.optim_progress = optim_progress @property def done(self) -> bool: @@ -98,6 +86,8 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict: self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") return AttributeDict(signal=0, training_step_output=[[]]) + self.progress.increment_ready() + # hook self.trainer.logger_connector.on_batch_start() response = self.trainer.call_hook("on_batch_start") @@ -120,6 +110,8 @@ def reset(self) -> None: self.batch_idx = 0 self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] + self.optim_progress.optimizer.reset_on_epoch() + def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int): """Splits the data into tbptt splits @@ -131,6 +123,10 @@ def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int): void(batch_idx, dataloader_idx) self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch))) + def on_advance_start(self, *args: Any, **kwargs: Any) -> None: + super().on_advance_start(*args, **kwargs) + self.progress.increment_started() + def advance(self, batch, batch_idx, dataloader_idx): """Runs the train step together with optimization (if necessary) on the current batch split @@ -148,7 +144,18 @@ def advance(self, batch, batch_idx, dataloader_idx): self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch) if self.trainer.lightning_module.automatic_optimization: - for opt_idx, optimizer in self.get_active_optimizers(batch_idx): + active_optimizers = self.get_active_optimizers(batch_idx) + for opt_idx, optimizer in active_optimizers: + + # handle optimization restart + if self.restarting: + if len(active_optimizers) > 1 and opt_idx < self.progress.current.completed: + continue + self.restarting = False + + # track optimizer_idx + self.optim_progress.optimizer_idx = opt_idx + result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer) if result: self.batch_outputs[opt_idx].append(result.training_step_output) @@ -158,6 +165,12 @@ def advance(self, batch, batch_idx, dataloader_idx): if result: self.batch_outputs[0].append(result.training_step_output) + self.progress.increment_processed() + + def on_advance_end(self) -> None: + super().on_advance_end() + self.progress.increment_completed() + def teardown(self) -> None: # release memory self._remaining_splits = None @@ -240,6 +253,11 @@ def _training_step_and_backward_closure( result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) if result is not None: return_result.update(result) + + # this should be done only if result.loss exists and ``optimizer step`` is being run + if not self.should_accumulate(): + self.optim_progress.optimizer.step.increment_started() + return return_result.loss def _make_closure(self, *closure_args: Any, **closure_kwargs: Any) -> Callable: @@ -409,6 +427,8 @@ def _optimizer_step( # wraps into LightningOptimizer only for running step optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx) + self.optim_progress.optimizer.step.increment_ready() + # model hook model_ref.optimizer_step( self.trainer.current_epoch, @@ -421,13 +441,17 @@ def _optimizer_step( 
using_lbfgs=is_lbfgs, ) + self.optim_progress.optimizer.step.increment_completed() + def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None: """Calls the ``on_before_zero_grad`` hook. Args: optimizer: the current optimizer """ + self.optim_progress.optimizer.zero_grad.increment_started() self.trainer.call_hook('on_before_zero_grad', optimizer) + self.optim_progress.optimizer.zero_grad.increment_ready() def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None: """Zeroes out all gradients of parameters optimized by the current optimizer. @@ -439,6 +463,8 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, """ self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) + self.optim_progress.optimizer.zero_grad.increment_completed() + def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]: """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer. diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 2f6e14b93b767..5dc0270f58774 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -21,7 +21,7 @@ from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import EpochLoopProgress +from pytorch_lightning.trainer.progress import DataLoaderProgress from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -33,10 +33,8 @@ class EvaluationLoop(DataLoaderLoop): def __init__(self): super().__init__() self.outputs = [] - self.progress = EpochLoopProgress() - + self.progress = DataLoaderProgress() self.epoch_loop = EvaluationEpochLoop() - self._results = ResultCollection(training=False) self._max_batches: Optional[Union[int, Sequence[int]]] = None self._has_run: bool = False @@ -66,14 +64,10 @@ def predictions(self): """Returns the predictions from all dataloaders""" return self.epoch_loop.predictions - def connect( - self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any - ) -> None: + def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - self.epoch_loop.connect(trainer, progress=self.progress.epoch) + self.epoch_loop.connect(trainer) @property def done(self) -> bool: @@ -96,18 +90,31 @@ def reset(self) -> None: if isinstance(self._max_batches, int): self._max_batches = [self._max_batches] * len(self.dataloaders) + if self.restarting: + self.iteration_count = self.progress.dataloader_idx + self.restarting = False + else: + self.iteration_count = 0 + # reset batch / epoch progress tracking + self.progress.current.reset() + def on_skip(self) -> List: return [] def on_run_start(self, *args: Any, **kwargs: Any) -> None: """Runs the ``on_evaluation_model_eval``, ``on_evaluation_start`` and ``on_evaluation_epoch_start`` hooks""" void(*args, **kwargs) + + self.progress.increment_started() + # hook 
self.on_evaluation_model_eval() self.trainer.lightning_module.zero_grad() self.on_evaluation_start() self.on_evaluation_epoch_start() + self.progress.increment_ready() + def advance(self, *args: Any, **kwargs: Any) -> None: """Performs evaluation on one single dataloader""" void(*args, **kwargs) @@ -115,6 +122,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None: dataloader_iter = enumerate(dataloader) dl_max_batches = self._max_batches[self.current_dataloader_idx] + self.progress.dataloader_idx = self.iteration_count + dl_outputs = self.epoch_loop.run( dataloader_iter, self.current_dataloader_idx, @@ -141,6 +150,8 @@ def on_run_end(self) -> Any: if len(outputs) > 0 and self.num_dataloaders == 1: outputs = outputs[0] + self.progress.increment_processed() + # lightning module method self.evaluation_epoch_end(outputs) @@ -159,6 +170,8 @@ def on_run_end(self) -> Any: # enable train mode again self.on_evaluation_model_train() + self.progress.increment_completed() + return eval_loop_results def get_max_batches(self) -> List[Union[int, float]]: diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index 55647e5d7f2a3..1bdd38ed950b0 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -20,9 +20,7 @@ def __init__(self): self.predictions: Optional[List[List[Any]]] = None self.epoch_batch_indices: Optional[List[List[int]]] = None self.progress = EpochLoopProgress() - self.epoch_loop = PredictionEpochLoop() - self._results = None # for `trainer._results` access self._return_predictions: bool = False diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index c01b20a5f84e2..c56b4a7f097d1 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -21,7 +21,7 @@ import pytorch_lightning as pl from pytorch_lightning.loops.base import Loop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import EpochProgress +from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -41,15 +41,11 @@ def __init__(self) -> None: self.dataloader_idx: Optional[int] = None self.num_dataloaders: Optional[int] = None self.outputs: List[STEP_OUTPUT] = [] - self.progress = EpochProgress() + self.progress = Progress() - def connect( - self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochProgress] = None, **kwargs: Any - ) -> None: + def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress @property def done(self) -> bool: @@ -65,6 +61,13 @@ def reset(self) -> None: self.num_dataloaders = None self.outputs = [] + if self.restarting: + self.iteration_count = self.progress.current.completed + self.restarting = False + else: + self.iteration_count = 0 + self.progress.current.reset() + def on_run_start( self, dataloader_iter: Iterator, @@ -114,9 +117,13 @@ def advance( with self.trainer.profiler.profile("evaluation_batch_to_device"): batch = 
self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx) + self.progress.increment_started() + # hook self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) + self.progress.increment_ready() + # lightning module methods with self.trainer.profiler.profile("evaluation_step_and_end"): output = self.evaluation_step(batch, batch_idx, dataloader_idx) @@ -131,6 +138,10 @@ def advance( # track epoch level outputs self.outputs = self._track_output_for_epoch_end(self.outputs, output) + self.progress.increment_processed() + + self.progress.increment_completed() + def on_run_end(self) -> List[STEP_OUTPUT]: """Returns the outputs of the whole run""" outputs = self.outputs diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index bc378c6bed0fb..af4e4fc52d63f 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -69,19 +69,11 @@ def done(self) -> bool: max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) - def connect( - self, - trainer: 'pl.Trainer', - *args: Any, - progress: Optional[TrainingEpochProgress] = None, - **kwargs: Any - ) -> None: + def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - self.batch_loop.connect(trainer, progress=self.progress.batch, optim_progress=self.progress.optim) - self.val_loop.connect(trainer, progress=self.progress.val) + self.batch_loop.connect(trainer) + self.val_loop.connect(trainer) def reset(self) -> None: """Resets the internal state of the loop for a new run""" @@ -93,12 +85,25 @@ def reset(self) -> None: # track epoch output self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))] + if self.restarting: + self.iteration_count = self.batch_loop.current_batch_completed + self.batches_seen = self.batch_loop.current_batch_completed + # restarting is finished. + self.restarting = False + else: + # todo (tchaton) the batch_loop should be responsible for that. + self.batch_loop.progress.current.reset() + def on_run_start(self, *args: Any, **kwargs: Any) -> None: + self.progress.increment_ready() + # hook self.trainer.logger_connector.on_epoch_start() self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") + self.progress.increment_started() + def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: """Runs a single training batch. 
@@ -158,7 +163,10 @@ def on_advance_end(self): # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- - should_check_val = self._should_check_val_fx(self.iteration_count, self.is_last_batch) + self.progress.should_check_val = should_check_val = self._should_check_val_fx( + self.iteration_count, self.is_last_batch + ) + if should_check_val: self.trainer.validating = True self._run_validation() @@ -216,11 +224,15 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]: 'HINT: remove the return statement in training_epoch_end' ) + self.progress.increment_processed() + # call train epoch end hooks self._on_train_epoch_end_hook(processed_outputs) self.trainer.call_hook('on_epoch_end') self.trainer.logger_connector.on_epoch_end() + self.progress.increment_completed() + epoch_output = self._epoch_output # free memory self._epoch_output = None @@ -430,10 +442,3 @@ def _save_loggers_on_train_batch_end(self) -> None: should_flush_logs = self.trainer.logger_connector.should_flush_logs if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() - - def state_dict(self) -> Dict: - return {"batch_loop": self.batch_loop.state_dict(), "val_loop": self.val_loop.state_dict()} - - def load_state_dict(self, state_dict: Dict) -> None: - self.batch_loop.load_state_dict(state_dict["batch_loop"]) - self.val_loop.load_state_dict(state_dict["val_loop"]) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index a8eb44923a241..6963f4b3f2c4a 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -14,13 +14,12 @@ import logging from contextlib import suppress -from typing import Any, Dict, Optional +from typing import Any, Optional import pytorch_lightning as pl from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import FitLoopProgress from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import rank_zero_info @@ -51,8 +50,6 @@ def __init__( super().__init__() self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs - self.progress = FitLoopProgress() - self.epoch_loop = TrainingEpochLoop(min_steps, max_steps) @property @@ -169,14 +166,10 @@ def skip(self) -> bool: """Whether we should skip the training and immediately return from the call to :meth:`run`.""" return self.done or self.trainer.num_training_batches == 0 - def connect( - self, trainer: 'pl.Trainer', *args: Any, progress: Optional[FitLoopProgress] = None, **kwargs: Any - ) -> None: + def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - self.epoch_loop.connect(trainer, progress=self.progress.epoch) + self.epoch_loop.connect(trainer) def reset(self) -> None: """Resets the internal state of this loop""" @@ -289,11 +282,5 @@ def _check_checkpoint_callback(self, should_update: bool, is_last: bool = False) for cb in callbacks: cb.on_validation_end(self.trainer, model) - def state_dict(self) -> Dict: - return {"epoch_loop": 
self.epoch_loop.state_dict()} - - def load_state_dict(self, state_dict: Dict) -> None: - self.epoch_loop.load_state_dict(state_dict["epoch_loop"]) - def teardown(self) -> None: self.epoch_loop.teardown() diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index ab74c3bccfc8d..df1328d668305 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -23,6 +23,7 @@ from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import fault_tolerant_enabled from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS if _OMEGACONF_AVAILABLE: @@ -165,6 +166,8 @@ def restore_training_state(self) -> None: self.restore_optimizers_and_schedulers() + self.restore_loops() + def restore_callbacks(self) -> None: """ Restores all callbacks from the pre-loaded checkpoint. """ if not self._loaded_checkpoint: @@ -249,6 +252,18 @@ def restore_lr_schedulers(self) -> None: for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers): scheduler['scheduler'].load_state_dict(lrs_state) + def restore_loops(self) -> None: + """ Calls hooks on the loops to give it a chance to restore its state from the checkpoint. """ + if not self._loaded_checkpoint: + return + + state_dict = self._loaded_checkpoint.get("loops", None) + if state_dict: + self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"]) + self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) + self.trainer.test_loop.load_state_dict(state_dict["test_loop"]) + self.trainer.predict_loop.load_state_dict(state_dict["predict_loop"]) + # ---------------------------------- # PRIVATE OPS # ---------------------------------- @@ -332,6 +347,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: 'state_dict': self.trainer.accelerator.lightning_module_state_dict(), } + if fault_tolerant_enabled(): + checkpoint.update({"loops": self.get_loops_state_dict()}) + if not weights_only: # dump callbacks checkpoint['callbacks'] = self.trainer.on_save_checkpoint(checkpoint) @@ -370,6 +388,14 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint + def get_loops_state_dict(self): + return { + "fit_loop": self.trainer.fit_loop.state_dict(), + "validate_loop": self.trainer.validate_loop.state_dict(), + "test_loop": self.trainer.test_loop.state_dict(), + "predict_loop": self.trainer.predict_loop.state_dict(), + } + def hpc_load(self, checkpoint_path: str) -> None: """ Attempts to restore the full training and model state from a HPC checkpoint file. 
diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index 06ae55a1ca672..a71356710b5a7 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -15,6 +15,7 @@ from weakref import proxy import pytorch_lightning as pl +from pytorch_lightning.trainer.progress import OptimizationProgress from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -48,6 +49,8 @@ def update_learning_rates( if opt_indices is None: opt_indices = [] + progress: OptimizationProgress = self.trainer.fit_loop.epoch_loop.batch_loop.optim_progress + for scheduler_idx, lr_scheduler in enumerate(self.trainer.lr_schedulers): if isinstance(lr_scheduler['opt_idx'], int) and lr_scheduler['opt_idx'] not in opt_indices: continue @@ -83,11 +86,15 @@ def update_learning_rates( # update LR old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] + progress.scheduler.increment_ready() + if lr_scheduler['reduce_on_plateau']: lr_scheduler['scheduler'].step(monitor_val) else: lr_scheduler['scheduler'].step() + progress.scheduler.increment_completed() + new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] if self.trainer.dev_debugger.enabled: diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 1098957033855..db2321365bfa6 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -35,13 +35,11 @@ def from_state_dict(cls, state_dict: dict) -> "BaseProgress": class Tracker(BaseProgress): """ Track an event's progress. - Args: ready: Intended to track the number of events ready to start. started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs). processed: Intended to be incremented after the event is processed. completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs). - Attributes set to ``None`` are treated as unused and are restricted. """ @@ -88,7 +86,6 @@ def reset_on_restart(self): class Progress(BaseProgress): """ Track aggregated and current progress. - Args: total: Intended to track the total progress of an event current: Intended to track the current progress of an event @@ -130,14 +127,39 @@ def load_state_dict(self, state_dict: dict) -> None: self.current.load_state_dict(state_dict["current"]) +@dataclass class BatchProgress(Progress): """ Tracks the batch progress + Args: + total: Tracks the total epoch progress + current: Tracks the current epoch progress + """ + +@dataclass +class TrainingEpochProgress(Progress): + """ + Tracks the batch progress Args: total: Tracks the total epoch progress current: Tracks the current epoch progress """ + should_check_val: bool = False + + def load_state_dict(self, state_dict: dict) -> None: + super().load_state_dict(state_dict) + self.should_check_val = state_dict["should_check_val"] + + +@dataclass +class DataLoaderProgress(Progress): + + dataloader_idx: int = 0 + + def load_state_dict(self, state_dict: dict) -> None: + super().load_state_dict(state_dict) + self.dataloader_idx = state_dict["dataloader_idx"] @dataclass @@ -145,13 +167,12 @@ class EpochProgress(Progress): """ Tracks the epoch progress These counters are local to a trainer rank. By default, they are not globally synced across all ranks. 
- Args: total: Tracks the total epoch progress current: Tracks the current epoch progress batch: Tracks batch progress. """ - + dataloader_idx: int = 0 batch: BatchProgress = field(default_factory=BatchProgress) def reset_on_epoch(self) -> None: @@ -160,13 +181,13 @@ def reset_on_epoch(self) -> None: def load_state_dict(self, state_dict: dict) -> None: super().load_state_dict(state_dict) self.batch.load_state_dict(state_dict["batch"]) + self.dataloader_idx = state_dict["dataloader_idx"] @dataclass class OptimizerProgress(BaseProgress): """ Track optimizer progress. - Args: step: Tracks ``optimizer.step`` calls. zero_grad: Tracks ``optimizer.zero_grad`` calls. @@ -188,13 +209,13 @@ def load_state_dict(self, state_dict: dict) -> None: class OptimizationProgress(BaseProgress): """ Track optimization progress. - Args: optimizer: Tracks optimizer progress. scheduler: Tracks scheduler progress. """ # TODO: support for multiple optimizers + optimizer_idx: int = 0 optimizer: OptimizerProgress = field(default_factory=OptimizerProgress) scheduler: Progress = field(default_factory=lambda: Progress.from_defaults(started=None, processed=None)) @@ -213,6 +234,7 @@ def reset_on_epoch(self) -> None: def load_state_dict(self, state_dict: dict) -> None: self.optimizer.load_state_dict(state_dict["optimizer"]) self.scheduler.load_state_dict(state_dict["scheduler"]) + self.optimizer_idx = state_dict["optimizer_idx"] @dataclass @@ -220,11 +242,9 @@ class EpochLoopProgress(BaseProgress): """ Tracks epoch loop progress. These counters are local to a trainer rank. By default, they are not globally synced across all ranks. - Args: epoch: Tracks epochs progress. """ - epoch: EpochProgress = field(default_factory=EpochProgress) def increment_epoch_completed(self) -> None: @@ -237,42 +257,3 @@ def reset_on_epoch(self) -> None: def load_state_dict(self, state_dict: dict) -> None: self.epoch.load_state_dict(state_dict["epoch"]) - - -@dataclass -class TrainingEpochProgress(EpochProgress): - """ - Extends ``EpochProgress`` with training specific attributes - - Args: - total: Tracks the total epoch progress. - current: Tracks the current epoch progress. - batch: Tracks batch progress. - optim: Tracks optimization progress. - val: Tracks val_loop progress. - """ - - optim: OptimizationProgress = field(default_factory=OptimizationProgress) - val: EpochLoopProgress = field(default_factory=EpochLoopProgress) - - def load_state_dict(self, state_dict: dict) -> None: - super().load_state_dict(state_dict) - self.optim.load_state_dict(state_dict["optim"]) - self.val.load_state_dict(state_dict["val"]) - - -@dataclass -class FitLoopProgress(EpochLoopProgress): - """ - Extends ``EpochLoopProgress`` with fit specific attributes - - Args: - epoch: Tracks epochs progress. 
- """ - - epoch: TrainingEpochProgress = field(default_factory=TrainingEpochProgress) - - def reset_on_epoch(self) -> None: - # do not reset `epoch.current` as it should track the number of epochs this `fit` call - self.epoch.reset_on_epoch() - self.epoch.optim.reset_on_epoch() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7475cd9c81326..32b61992166a0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -57,7 +57,6 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin -from pytorch_lightning.trainer.progress import EpochLoopProgress, FitLoopProgress from pytorch_lightning.trainer.properties import TrainerProperties from pytorch_lightning.trainer.states import TrainerFn, TrainerState, TrainerStatus from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin @@ -359,10 +358,10 @@ def __init__( self.validate_loop = EvaluationLoop() self.test_loop = EvaluationLoop() self.predict_loop = PredictionLoop() - self.fit_loop.connect(self, progress=FitLoopProgress()) - self.validate_loop.connect(self, progress=EpochLoopProgress()) - self.test_loop.connect(self, progress=EpochLoopProgress()) - self.predict_loop.connect(self, progress=EpochLoopProgress()) + self.fit_loop.connect(self) + self.validate_loop.connect(self) + self.test_loop.connect(self) + self.predict_loop.connect(self) # training state if weights_summary is not None and weights_summary not in ModelSummary.MODES: diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 3125a2d38f15e..fdd5382ca751d 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -14,6 +14,7 @@ """General utilities""" import importlib import operator +import os import platform import sys from importlib.util import find_spec @@ -101,3 +102,7 @@ def _compare_version(package: str, op, version) -> bool: _IPU_AVAILABLE = poptorch.ipuHardwareIsAvailable() else: _IPU_AVAILABLE = False + + +def fault_tolerant_enabled(): + return os.getenv("PL_FAULT_TOLERANT_TRAINING", "0") == "1" diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index a3bbd5a36a2c1..ec203ae7cac76 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -11,19 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os from copy import deepcopy +from unittest import mock import pytest +import torch +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.progress import ( BatchProgress, EpochLoopProgress, EpochProgress, - FitLoopProgress, + OptimizationProgress, OptimizerProgress, Progress, Tracker, ) +from tests.helpers import BoringModel + + +class CustomException(BaseException): + pass def test_progress_geattr_setattr(): @@ -135,74 +145,9 @@ def test_optimizer_progress_default_factory(): assert p2.step.total.completed == 0 -def test_fit_loop_progress_serialization(): - fit_loop = FitLoopProgress() - _ = deepcopy(fit_loop) - fit_loop.epoch.increment_completed() # check `TrainingEpochProgress.load_state_dict` calls `super` - - state_dict = fit_loop.state_dict() - # yapf: disable - assert state_dict == { - 'epoch': { - # number of epochs across `fit` calls - 'total': {'completed': 1, 'processed': 0, 'ready': 0, 'started': 0}, - # number of epochs this `fit` call - 'current': {'completed': 1, 'processed': 0, 'ready': 0, 'started': 0}, - 'batch': { - # number of batches across `fit` calls - 'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - # number of batches this epoch - 'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - }, - # `fit` optimization progress - 'optim': { - # optimizers progress - 'optimizer': { - 'step': { - # `optimizer.step` calls across `fit` calls - 'total': {'completed': 0, 'processed': None, 'ready': 0, 'started': 0}, - # `optimizer.step` calls this epoch - 'current': {'completed': 0, 'processed': None, 'ready': 0, 'started': 0}, - }, - 'zero_grad': { - # `optimizer.zero_grad` calls across `fit` calls - 'total': {'completed': 0, 'processed': None, 'ready': 0, 'started': 0}, - # `optimizer.zero_grad` calls this epoch - 'current': {'completed': 0, 'processed': None, 'ready': 0, 'started': 0}, - }, - }, - 'scheduler': { - # `scheduler.step` calls across `fit` calls - 'total': {'completed': 0, 'processed': None, 'ready': 0, 'started': None}, - # `scheduler.step` calls this epoch - 'current': {'completed': 0, 'processed': None, 'ready': 0, 'started': None}, - }, - }, - # `fit` validation progress - 'val': { - 'epoch': { - # number of `validation` calls across `fit` calls - 'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - # number of `validation` calls this `fit` call - 'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - 'batch': { - # number of batches across `fit` `validation` calls - 'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - # number of batches this `fit` `validation` call - 'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - }, - } - }, - } - } - # yapf: enable - - new_loop = FitLoopProgress.from_state_dict(state_dict) - assert fit_loop == new_loop - - def test_epoch_loop_progress_serialization(): loop = EpochLoopProgress() + loop.epoch.dataloader_idx = 1 _ = deepcopy(loop) state_dict = loop.state_dict() @@ -219,9 +164,171 @@ def test_epoch_loop_progress_serialization(): # number of batches this `validate` call 'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, }, + 'dataloader_idx': 1 } } # yapf: enable new_loop = EpochLoopProgress.from_state_dict(state_dict) assert loop == new_loop + + +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@pytest.mark.parametrize("use_multiple_optimizers", [False, True]) 
+@pytest.mark.parametrize("accumulate_grad_batches", [1, 2]) +def test_progress_tracking(use_multiple_optimizers, accumulate_grad_batches, tmpdir): + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + if use_multiple_optimizers: + self.configure_optimizers = self.configure_optimizers_3 + self.should_fail = True + + def training_step(self, batch, batch_idx, optimizer_idx: int = None): + # breaking on global_step 4 + if self.should_fail and self.trainer.current_epoch == 1 and batch_idx == 1 and optimizer_idx == ( + 1 if use_multiple_optimizers else None + ): + raise CustomException + return super().training_step(batch, batch_idx) + + def configure_optimizers_3(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + optimizer_1 = torch.optim.Adam(self.layer.parameters(), lr=0.1) + lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) + optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) + return [optimizer, optimizer_1, optimizer_2], \ + [lr_scheduler, {"scheduler": lr_scheduler_1, "interval": "step"}] + + model = TestModel() + model.training_epoch_end = None + + chk = ModelCheckpoint(dirpath=tmpdir, filename=str(use_multiple_optimizers), save_last=True) + chk.last_model_path = None + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=3, + limit_val_batches=0, + callbacks=chk, + accumulate_grad_batches=accumulate_grad_batches, + resume_from_checkpoint=None, + ) + + # simulate random failure in training_step + try: + trainer.fit(model) + except CustomException: + pass + + assert isinstance(trainer.fit_loop.epoch_loop.batch_loop.optim_progress, OptimizationProgress) + + pr = trainer.fit_loop.epoch_loop.progress + + assert pr.total == Tracker(ready=2, started=2, processed=1, completed=1) + assert pr.current == Tracker(ready=2, started=2, processed=1, completed=1) + + pr = trainer.fit_loop.epoch_loop.batch_loop.progress + + assert pr.total == Tracker(ready=5, started=5, processed=4, completed=4) + assert pr.current == Tracker(ready=2, started=2, processed=1, completed=1) + + num_optimizers = 3 if use_multiple_optimizers else 1 + + optim = trainer.fit_loop.epoch_loop.batch_loop.optim_progress + + # 4 optimizer steps because breaking on the second batch of the second epoch (3 + 1) + total = (4 * num_optimizers + (1 if use_multiple_optimizers else 0)) // accumulate_grad_batches + + # we raised expection on the first optimizer + current = (1 if use_multiple_optimizers else 0) + + if accumulate_grad_batches == 2 and use_multiple_optimizers: + total += 1 + + assert optim.optimizer.step.total == Tracker(ready=total + 1, started=total, processed=None, completed=total) + assert optim.optimizer.step.current == Tracker( + ready=current + 1, started=current, processed=None, completed=current + ) + + if accumulate_grad_batches == 2: + # that's weird ! todo (tchaton) investigate this + total = (9 if use_multiple_optimizers else 3) + current = 0 # same there. 
+ + assert optim.optimizer.zero_grad.total == Tracker(ready=total, started=total, processed=None, completed=total) + assert optim.optimizer.zero_grad.current == Tracker( + ready=current, started=current, processed=None, completed=current + ) + + # for multiple optimizers: 4 batches + 1 on epoch + total = (5 if use_multiple_optimizers else 1) // accumulate_grad_batches + + if accumulate_grad_batches == 2: + total += 1 + + assert optim.scheduler.total == Tracker(ready=total, started=None, processed=None, completed=total) + # assert optim.scheduler.current == Tracker(ready=0, started=None, processed=None, completed=0) + + assert optim.optimizer_idx == (1 if use_multiple_optimizers else 0) + + checkpoint = torch.load(trainer.checkpoint_callback.last_model_path) + assert "loops" in checkpoint + + +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_progress_tracking_validation_multiple_datasets(tmpdir): + + class ValidationModel(BoringModel): + + def __init__(self): + super().__init__() + + def validation_step(self, batch, batch_idx, dataloader_idx): + if self.trainer.fit_loop.epoch_loop.batch_idx == 3 and batch_idx == 1 and dataloader_idx == 1: + raise CustomException + return super().validation_step(batch, batch_idx) + + def val_dataloader(self): + return [super().val_dataloader(), super().val_dataloader(), super().val_dataloader()] + + model = ValidationModel() + model.validation_epoch_end = None + + chk = ModelCheckpoint(dirpath=tmpdir, save_last=True) + chk.last_model_path = None + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=5, + limit_val_batches=3, + callbacks=chk, + resume_from_checkpoint=None, + val_check_interval=2, + num_sanity_val_steps=0, + ) + + # simulate random failure in training_step + try: + trainer.fit(model) + except CustomException: + pass + + pr = trainer.fit_loop.epoch_loop.val_loop.progress + + assert pr.total == Tracker(ready=2, started=2, processed=1, completed=1) + assert pr.current == Tracker(ready=1, started=1, processed=0, completed=0) + assert pr.dataloader_idx == 1 + + assert trainer.fit_loop.epoch_loop.progress.should_check_val + + pr = trainer.fit_loop.epoch_loop.val_loop.epoch_loop.progress + + # 3 dataloaders with 3 samples for batch_idx == 1 + first dataloader on batch_idx == 1 + failure on batch_idx = 1 + current = 2 + total = 3 * 3 + 3 + current + assert pr.total == Tracker(ready=total, started=total, processed=total - 1, completed=total - 1) + assert pr.current == Tracker(ready=current, started=current, processed=current - 1, completed=current - 1) From 22fa5fb2903220ce469352e06ad626e3b69728a7 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 10 Jul 2021 14:46:52 +0200 Subject: [PATCH 052/157] validate json --- .../loops/batch/training_batch_loop.py | 2 +- pytorch_lightning/trainer/trainer.py | 7 + tests/trainer/test_progress.py | 124 +++++++++++++----- 3 files changed, 98 insertions(+), 35 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 89ea775977817..55244e7bfa1c0 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -110,7 +110,7 @@ def reset(self) -> None: self.batch_idx = 0 self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] - self.optim_progress.optimizer.reset_on_epoch() + self.optim_progress.reset_on_epoch() def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int): """Splits the data 
into tbptt splits diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 32b61992166a0..d4df26941f919 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -76,6 +76,7 @@ from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import fault_tolerant_enabled from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -996,6 +997,7 @@ def _run_train(self) -> None: self.training_type_plugin.reconciliate_processes(traceback.format_exc()) # give accelerators a chance to finish self.accelerator.on_train_end() + self.on_expection() # reset bookkeeping self.state.stage = None raise @@ -1235,3 +1237,8 @@ def _log_device_info(self) -> None: "IPU available but not used. Set the `ipus` flag in your trainer" " `Trainer(ipus=8)` or script `--ipus=8`." ) + + def on_expection(self): + if fault_tolerant_enabled(): + # save a checkpoint for fault tolerant training + self.fit_loop._check_checkpoint_callback(True) diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index ec203ae7cac76..187b10e9dd0df 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -24,7 +24,6 @@ BatchProgress, EpochLoopProgress, EpochProgress, - OptimizationProgress, OptimizerProgress, Progress, Tracker, @@ -224,59 +223,116 @@ def configure_optimizers_3(self): except CustomException: pass - assert isinstance(trainer.fit_loop.epoch_loop.batch_loop.optim_progress, OptimizationProgress) + ####################### + # VALIDATE CHECKPOINT # + ####################### - pr = trainer.fit_loop.epoch_loop.progress - - assert pr.total == Tracker(ready=2, started=2, processed=1, completed=1) - assert pr.current == Tracker(ready=2, started=2, processed=1, completed=1) - - pr = trainer.fit_loop.epoch_loop.batch_loop.progress - - assert pr.total == Tracker(ready=5, started=5, processed=4, completed=4) - assert pr.current == Tracker(ready=2, started=2, processed=1, completed=1) + checkpoint = torch.load(trainer.checkpoint_callback.last_model_path) num_optimizers = 3 if use_multiple_optimizers else 1 - optim = trainer.fit_loop.epoch_loop.batch_loop.optim_progress - # 4 optimizer steps because breaking on the second batch of the second epoch (3 + 1) - total = (4 * num_optimizers + (1 if use_multiple_optimizers else 0)) // accumulate_grad_batches + total_optimizer_step = (4 * num_optimizers + (1 if use_multiple_optimizers else 0)) // accumulate_grad_batches # we raised expection on the first optimizer - current = (1 if use_multiple_optimizers else 0) + current_optimize_step = (1 if use_multiple_optimizers else 0) if accumulate_grad_batches == 2 and use_multiple_optimizers: - total += 1 + total_optimizer_step += 1 - assert optim.optimizer.step.total == Tracker(ready=total + 1, started=total, processed=None, completed=total) - assert optim.optimizer.step.current == Tracker( - ready=current + 1, started=current, processed=None, completed=current - ) + total_optimizer_zero_grad = total_optimizer_step + current_optimizer_zero_grad = current_optimize_step if accumulate_grad_batches == 2: # that's weird ! 
todo (tchaton) investigate this - total = (9 if use_multiple_optimizers else 3) - current = 0 # same there. + total_optimizer_zero_grad = (9 if use_multiple_optimizers else 3) + current_optimizer_zero_grad = 0 # same there. - assert optim.optimizer.zero_grad.total == Tracker(ready=total, started=total, processed=None, completed=total) - assert optim.optimizer.zero_grad.current == Tracker( - ready=current, started=current, processed=None, completed=current - ) + total_scheduler_step = (5 if use_multiple_optimizers else 1) // accumulate_grad_batches - # for multiple optimizers: 4 batches + 1 on epoch - total = (5 if use_multiple_optimizers else 1) // accumulate_grad_batches + current_scheduler_step = 0 if accumulate_grad_batches == 2: - total += 1 + total_scheduler_step += 1 - assert optim.scheduler.total == Tracker(ready=total, started=None, processed=None, completed=total) - # assert optim.scheduler.current == Tracker(ready=0, started=None, processed=None, completed=0) + optimizer_idx = (1 if use_multiple_optimizers else 0) - assert optim.optimizer_idx == (1 if use_multiple_optimizers else 0) + # yapf: disable + expected = { + "state_dict": {}, + "epoch_loop.state_dict": {}, + "epoch_loop.progress": { + "total": {"ready": 2, "started": 2, "processed": 1, "completed": 1}, + "current": {"ready": 2, "started": 2, "processed": 1, "completed": 1}, + "should_check_val": False, + }, + "epoch_loop.batch_loop.state_dict": {}, + "epoch_loop.batch_loop.progress": { + "total": {"ready": 5, "started": 5, "processed": 4, "completed": 4}, + "current": {"ready": 2, "started": 2, "processed": 1, "completed": 1}, + }, + "epoch_loop.batch_loop.optim_progress": { + "optimizer_idx": optimizer_idx, + "optimizer": { + "step": { + "total": { + "ready": total_optimizer_step + 1, + "started": total_optimizer_step, + "processed": None, + "completed": total_optimizer_step + }, + "current": { + "ready": current_optimize_step + 1, + "started": current_optimize_step, + "processed": None, + "completed": current_optimize_step, + }, + }, + "zero_grad": { + "total": { + "ready": total_optimizer_zero_grad, + "started": total_optimizer_zero_grad, + "processed": None, + "completed": total_optimizer_zero_grad + }, + "current": { + "ready": current_optimizer_zero_grad, + "started": current_optimizer_zero_grad, + "processed": None, + "completed": current_optimizer_zero_grad, + }, + }, + }, + "scheduler": { + "total": { + "ready": total_scheduler_step, + "started": None, + "processed": None, + "completed": total_scheduler_step + }, + "current": { + "ready": current_scheduler_step, + "started": None, + "processed": None, + "completed": current_scheduler_step + }, + }, + }, + "epoch_loop.val_loop.state_dict": {}, + "epoch_loop.val_loop.progress": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "dataloader_idx": 0, + }, + "epoch_loop.val_loop.epoch_loop.state_dict": {}, + "epoch_loop.val_loop.epoch_loop.progress": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, + } + # yapf: enable - checkpoint = torch.load(trainer.checkpoint_callback.last_model_path) - assert "loops" in checkpoint + assert checkpoint["loops"]["fit_loop"] == expected @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) From 6d45fe26e3542a9523e792263b1e968e28f16fd1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 10 Jul 2021 14:49:45 +0200 Subject: [PATCH 
053/157] update --- tests/trainer/test_progress.py | 35 +++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index 187b10e9dd0df..043233532379a 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -205,12 +205,14 @@ def configure_optimizers_3(self): model = TestModel() model.training_epoch_end = None + limit_train_batches = 3 + chk = ModelCheckpoint(dirpath=tmpdir, filename=str(use_multiple_optimizers), save_last=True) chk.last_model_path = None trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, - limit_train_batches=3, + limit_train_batches=limit_train_batches, limit_val_batches=0, callbacks=chk, accumulate_grad_batches=accumulate_grad_batches, @@ -229,6 +231,9 @@ def configure_optimizers_3(self): checkpoint = torch.load(trainer.checkpoint_callback.last_model_path) + num_epochs = 1 + num_batches = 4 + num_optimizers = 3 if use_multiple_optimizers else 1 # 4 optimizer steps because breaking on the second batch of the second epoch (3 + 1) @@ -262,14 +267,34 @@ def configure_optimizers_3(self): "state_dict": {}, "epoch_loop.state_dict": {}, "epoch_loop.progress": { - "total": {"ready": 2, "started": 2, "processed": 1, "completed": 1}, - "current": {"ready": 2, "started": 2, "processed": 1, "completed": 1}, + "total": { + "ready": num_epochs + 1, + "started": num_epochs + 1, + "processed": 1, + "completed": 1 + }, + "current": { + "ready": num_epochs + 1, + "started": num_epochs + 1, + "processed": 1, + "completed": 1 + }, "should_check_val": False, }, "epoch_loop.batch_loop.state_dict": {}, "epoch_loop.batch_loop.progress": { - "total": {"ready": 5, "started": 5, "processed": 4, "completed": 4}, - "current": {"ready": 2, "started": 2, "processed": 1, "completed": 1}, + "total": { + "ready": num_batches + 1, + "started": num_batches + 1, + "processed": num_batches, + "completed": num_batches + }, + "current": { + "ready": num_batches - limit_train_batches + 1, + "started": num_batches - limit_train_batches + 1, + "processed": num_batches - limit_train_batches, + "completed": num_batches - limit_train_batches + }, }, "epoch_loop.batch_loop.optim_progress": { "optimizer_idx": optimizer_idx, From 71d01d696b984bba161316b0d01d8549cfe52d7f Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 10 Jul 2021 14:55:53 +0200 Subject: [PATCH 054/157] convert to dict for better readability --- tests/trainer/test_progress.py | 46 ++++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index 043233532379a..053373f354989 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -398,18 +398,48 @@ def val_dataloader(self): except CustomException: pass - pr = trainer.fit_loop.epoch_loop.val_loop.progress + ####################### + # VALIDATE CHECKPOINT # + ####################### - assert pr.total == Tracker(ready=2, started=2, processed=1, completed=1) - assert pr.current == Tracker(ready=1, started=1, processed=0, completed=0) - assert pr.dataloader_idx == 1 + checkpoint = torch.load(trainer.checkpoint_callback.last_model_path)["loops"]["fit_loop"] - assert trainer.fit_loop.epoch_loop.progress.should_check_val + checkpoint = torch.load(trainer.checkpoint_callback.last_model_path)["loops"]["fit_loop"] - pr = trainer.fit_loop.epoch_loop.val_loop.epoch_loop.progress + expected = { + "total": { + "ready": 2, + "started": 2, + "processed": 1, + 
"completed": 1 + }, + "current": { + "ready": 1, + "started": 1, + "processed": 0, + "completed": 0 + }, + "dataloader_idx": 1, + } + + assert checkpoint["epoch_loop.val_loop.progress"] == expected # 3 dataloaders with 3 samples for batch_idx == 1 + first dataloader on batch_idx == 1 + failure on batch_idx = 1 current = 2 total = 3 * 3 + 3 + current - assert pr.total == Tracker(ready=total, started=total, processed=total - 1, completed=total - 1) - assert pr.current == Tracker(ready=current, started=current, processed=current - 1, completed=current - 1) + expected = { + "total": { + "ready": total, + "started": total, + "processed": total - 1, + "completed": total - 1 + }, + "current": { + "ready": current, + "started": current, + "processed": current - 1, + "completed": current - 1 + }, + } + + assert checkpoint["epoch_loop.val_loop.epoch_loop.progress"] == expected From 1c6c5661e29f74619b16de8362661a2a9755fe47 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 10 Jul 2021 15:14:57 +0200 Subject: [PATCH 055/157] validate reload --- tests/trainer/test_progress.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index 053373f354989..6322ab5be33bb 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -359,6 +359,15 @@ def configure_optimizers_3(self): assert checkpoint["loops"]["fit_loop"] == expected + trainer = Trainer() + trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False) + assert trainer.fit_loop.state_dict() == checkpoint["loops"]["fit_loop"] + + trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"]) + state_dict = trainer.fit_loop.state_dict() + assert state_dict != checkpoint["loops"]["fit_loop"] + assert state_dict['epoch_loop.progress']["total"]["started"] == num_epochs + @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_progress_tracking_validation_multiple_datasets(tmpdir): @@ -443,3 +452,10 @@ def val_dataloader(self): } assert checkpoint["epoch_loop.val_loop.epoch_loop.progress"] == expected + + trainer = Trainer() + trainer.fit_loop.load_state_dict(checkpoint, restart_progress=False) + assert trainer.fit_loop.state_dict() == checkpoint + + trainer.fit_loop.load_state_dict(checkpoint) + assert trainer.fit_loop.state_dict() != checkpoint From bc49cc72829cdd96c1c94af42df42948595a80ca Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 10 Jul 2021 15:16:03 +0200 Subject: [PATCH 056/157] update --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 14db839d0d6e6..399795869209f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -140,6 +140,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added `FastForwardSampler` and `CaptureIterableDataset` ([#8307](https://github.com/PyTorchLightning/pytorch-lightning/pull/8307)) +- Added `progress` tracking on loops ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) + + ### Changed From 0a0b5e35eff02ddedf5d9998a5a2665eec09d564 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 10 Jul 2021 19:27:16 +0200 Subject: [PATCH 057/157] update --- tests/loops/test_loop_progress_integration.py | 20 +- tests/loops/test_loop_state_dict.py | 201 ++++++------------ 2 files changed, 72 insertions(+), 149 deletions(-) diff --git a/tests/loops/test_loop_progress_integration.py b/tests/loops/test_loop_progress_integration.py index 986ea2543d6d8..82d56fb5dd872 100644 --- a/tests/loops/test_loop_progress_integration.py +++ b/tests/loops/test_loop_progress_integration.py @@ -3,20 +3,16 @@ def test_loop_progress_integration(): trainer = Trainer() - fit_loop = trainer.fit_loop - # check identities inside the fit loop - assert fit_loop.progress.epoch is fit_loop.epoch_loop.progress - assert fit_loop.epoch_loop.progress.batch is fit_loop.epoch_loop.batch_loop.progress - assert fit_loop.epoch_loop.progress.optim is fit_loop.epoch_loop.batch_loop.optim_progress - assert fit_loop.epoch_loop.progress.val is fit_loop.epoch_loop.val_loop.progress - assert fit_loop.epoch_loop.val_loop.progress.epoch is fit_loop.epoch_loop.val_loop.epoch_loop.progress - # check identities inside the evaluation and predict loops - assert trainer.validate_loop.progress.epoch is trainer.validate_loop.epoch_loop.progress - assert trainer.test_loop.progress.epoch is trainer.test_loop.epoch_loop.progress - assert trainer.predict_loop.progress.epoch is trainer.predict_loop.epoch_loop.progress # check no progresses are shared - assert trainer.fit_loop.progress is not trainer.validate_loop.progress assert trainer.validate_loop.progress is not trainer.test_loop.progress assert trainer.test_loop.progress is not trainer.predict_loop.progress # check the validation progresses are not shared assert trainer.fit_loop.epoch_loop.val_loop.progress is not trainer.validate_loop.progress + expected = trainer.fit_loop.loop_progress["epoch_loop"]["progress"] + assert expected == trainer.fit_loop.epoch_loop.progress + expected = trainer.fit_loop.loop_progress["epoch_loop"]["batch_loop"]["progress"] + assert expected == trainer.fit_loop.epoch_loop.batch_loop.progress + expected = trainer.fit_loop.loop_progress["epoch_loop"]["val_loop"]["progress"] + assert expected == trainer.fit_loop.epoch_loop.val_loop.progress + expected = trainer.fit_loop.loop_progress["epoch_loop"]["val_loop"]["epoch_loop"]["progress"] + assert expected == trainer.fit_loop.epoch_loop.val_loop.epoch_loop.progress diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index eed23a89a8b36..591f0c0f297b8 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -44,176 +44,101 @@ def test_loops_state_dict_structure(): # yapf: disable expected = { "fit_loop": { - "epoch_loop": { - "batch_loop": { - "state_dict": {}, - "progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "state_dict": {}, + "epoch_loop.state_dict": {}, + "epoch_loop.progress": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "should_check_val": False, + }, + "epoch_loop.batch_loop.state_dict": {}, + "epoch_loop.batch_loop.progress": { + "total": {"ready": 
0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, + "epoch_loop.batch_loop.optim_progress": { + "optimizer_idx": 0, + "optimizer": { + "step": { + "total": { + "ready": 0, + "started": 0, + "processed": None, + "completed": 0, + }, "current": { "ready": 0, "started": 0, - "processed": 0, + "processed": None, "completed": 0, }, }, - "optim_progress": { - "optimizer": { - "step": { - "total": { - "ready": 0, - "started": 0, - "processed": None, - "completed": 0, - }, - "current": { - "ready": 0, - "started": 0, - "processed": None, - "completed": 0, - }, - }, - "zero_grad": { - "total": { - "ready": 0, - "started": 0, - "processed": None, - "completed": 0, - }, - "current": { - "ready": 0, - "started": 0, - "processed": None, - "completed": 0, - }, - }, - }, - "scheduler": { - "total": { - "ready": 0, - "started": None, - "processed": None, - "completed": 0, - }, - "current": { - "ready": 0, - "started": None, - "processed": None, - "completed": 0, - }, + "zero_grad": { + "total": { + "ready": 0, + "started": 0, + "processed": None, + "completed": 0, }, - }, - }, - "val_loop": { - "state_dict": {}, - "progress": { - "epoch": { - "total": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - "current": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - "batch": { - "total": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - "current": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - }, - } - }, - "epoch_loop.state_dict": {}, - "epoch_loop.progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": { "ready": 0, "started": 0, - "processed": 0, + "processed": None, "completed": 0, }, - "batch": { - "total": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - "current": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - }, }, }, - } + "scheduler": { + "total": { + "ready": 0, + "started": None, + "processed": None, + "completed": 0, + }, + "current": { + "ready": 0, + "started": None, + "processed": None, + "completed": 0, + }, + }, + }, + "epoch_loop.val_loop.state_dict": {}, + "epoch_loop.val_loop.progress": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "dataloader_idx": 0, + }, + "epoch_loop.val_loop.epoch_loop.state_dict": {}, + "epoch_loop.val_loop.epoch_loop.progress": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, }, "validate_loop": { "state_dict": {}, "progress": { - "epoch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "batch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - }, - } + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "dataloader_idx": 0, }, "epoch_loop.state_dict": {}, "epoch_loop.progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "batch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, 
- "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - }, }, }, "test_loop": { "state_dict": {}, "progress": { - "epoch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "batch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - }, - } + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "dataloader_idx": 0, }, "epoch_loop.state_dict": {}, "epoch_loop.progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "batch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - }, }, }, "predict_loop": { @@ -222,6 +147,7 @@ def test_loops_state_dict_structure(): "epoch": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "dataloader_idx": 0, "batch": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": { @@ -237,6 +163,7 @@ def test_loops_state_dict_structure(): "epoch_loop.progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "dataloader_idx": 0, "batch": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, From 45fb6576c234220e0b6ffa9babf1bf1c1157af40 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 12 Jul 2021 11:58:45 +0200 Subject: [PATCH 058/157] update on comments --- pytorch_lightning/core/optimizer.py | 1 - pytorch_lightning/loops/batch/training_batch_loop.py | 11 +++++++++-- pytorch_lightning/loops/epoch/training_epoch_loop.py | 1 + .../trainer/connectors/optimizer_connector.py | 9 +++------ 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 33b44a35d31f1..25e4519eb39fc 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -207,7 +207,6 @@ def closure_dis(): profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}" self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs) - self._trainer.fit_loop.epoch_loop.batch_loop.optim_progress.optimizer.step.increment_processed() self._total_optimizer_step_calls += 1 def __repr__(self): diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 55244e7bfa1c0..8b4ed1140144c 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -28,7 +28,7 @@ from pytorch_lightning.loops.base import Loop from pytorch_lightning.plugins import ParallelPlugin from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import OptimizationProgress, Progress +from pytorch_lightning.trainer.progress import BatchProgress, OptimizationProgress from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, 
grad_norm from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -50,7 +50,7 @@ def __init__(self) -> None: self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20) self.batch_idx: int = 0 self.split_idx: Optional[int] = None - self.progress = Progress() + self.progress = BatchProgress() self.optim_progress = OptimizationProgress() self._warning_cache: WarningCache = WarningCache() self._hiddens: Optional[Tensor] = None @@ -441,6 +441,7 @@ def _optimizer_step( using_lbfgs=is_lbfgs, ) + self.optim_progress.optimizer.step.increment_processed() self.optim_progress.optimizer.step.increment_completed() def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None: @@ -724,3 +725,9 @@ def _truncated_bptt_steps(self) -> int: if lightning_module.truncated_bptt_steps > 0: return lightning_module.truncated_bptt_steps return self.trainer.truncated_bptt_steps or 0 + + def increment_scheduler_ready(self): + self.optim_progress.scheduler.increment_ready() + + def increment_scheduler_completed(self): + self.optim_progress.scheduler.increment_completed() diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index af4e4fc52d63f..8f6cc13e64fd5 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -394,6 +394,7 @@ def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) - """updates the lr schedulers based on the given interval""" if interval == "step" and self.batch_loop.should_accumulate(): return + self.trainer.optimizer_connector.update_learning_rates( interval=interval, update_plateau_schedulers=update_plateau_schedulers, diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index a71356710b5a7..16b751e7db4b9 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -15,7 +15,6 @@ from weakref import proxy import pytorch_lightning as pl -from pytorch_lightning.trainer.progress import OptimizationProgress from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -49,8 +48,6 @@ def update_learning_rates( if opt_indices is None: opt_indices = [] - progress: OptimizationProgress = self.trainer.fit_loop.epoch_loop.batch_loop.optim_progress - for scheduler_idx, lr_scheduler in enumerate(self.trainer.lr_schedulers): if isinstance(lr_scheduler['opt_idx'], int) and lr_scheduler['opt_idx'] not in opt_indices: continue @@ -86,17 +83,17 @@ def update_learning_rates( # update LR old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] - progress.scheduler.increment_ready() + self.trainer.fit_loop.epoch_loop.batch_loop.increment_scheduler_ready() if lr_scheduler['reduce_on_plateau']: lr_scheduler['scheduler'].step(monitor_val) else: lr_scheduler['scheduler'].step() - progress.scheduler.increment_completed() - new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] + self.trainer.fit_loop.epoch_loop.batch_loop.increment_scheduler_completed() + if self.trainer.dev_debugger.enabled: self.trainer.dev_debugger.track_lr_schedulers_update( self.trainer.fit_loop.batch_idx, From 65821c9f79b02b044c61a6cd623c9048b32aa826 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 12 Jul 2021 13:32:47 +0200 Subject: [PATCH 059/157] remove deadcode --- 
pytorch_lightning/loops/base.py | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 61d445fff760f..9997baac79cc5 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -54,33 +54,6 @@ def __init__(self) -> None: def trainer(self) -> Optional['pl.Trainer']: return self._trainer - @trainer.setter - def trainer(self, trainer: 'pl.Trainer'): - """Connect the Trainer to this loop and all children.""" - if not isinstance(trainer, pl.Trainer) and trainer is not None: - raise MisconfigurationException( - f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}." - ) - self._trainer = trainer - for v in self.__dict__.values(): - if isinstance(v, Loop): - v.trainer = trainer - - @property - def loop_progress(self) -> Dict[str, Any]: - """Return the progress for the current loop and children loop.""" - progress = {} - for k, v in self.__dict__.items(): - if isinstance(v, BaseProgress): - progress[k] = v - elif isinstance(v, Loop): - progress[k] = v.loop_progress - return progress - - @property - def trainer(self) -> Optional['pl.Trainer']: - return self._trainer - @trainer.setter def trainer(self, trainer: 'pl.Trainer'): """Connect the Trainer to itself and all its children loops""" @@ -126,9 +99,6 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: Returns: the output of :attr:`on_run_end` (often outputs collected from each step of the loop) """ - if self.trainer is None: - raise MisconfigurationException(f"The {self.__class__.__name__} Loop hasn't been attached to any Trainer.") - if self.skip: return self.on_skip() From d0492b519d1915473a2acc4bade4c6d5cf0b3c3c Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 12 Jul 2021 13:33:58 +0200 Subject: [PATCH 060/157] clean changelog --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 47c5c5e8a866c..4fe41a10fa655 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -205,7 +205,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Refactored prediction loop interface; added new classes `PredictionLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700), [#8077](https://github.com/PyTorchLightning/pytorch-lightning/pull/8077)) * Removed `pytorch_lightning/trainer/predict_loop.py` ([#8094](https://github.com/PyTorchLightning/pytorch-lightning/pull/8094)) * Moved result teardown to the loops ([#8245](https://github.com/PyTorchLightning/pytorch-lightning/pull/8245)) - * Improve `Loop` API to better handle children `state_dict` and `progress` ([#8334](https://github.com/PyTorchLightning/pytorch-lightning/pull/8334)) - Refactored logging * Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736)) From 462b35718032012a2988d82d0f9140c63d3165df Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 12 Jul 2021 13:34:44 +0200 Subject: [PATCH 061/157] clean changelog --- CHANGELOG.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4fe41a10fa655..ef98267575d0d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -205,6 +205,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
* Refactored prediction loop interface; added new classes `PredictionLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700), [#8077](https://github.com/PyTorchLightning/pytorch-lightning/pull/8077)) * Removed `pytorch_lightning/trainer/predict_loop.py` ([#8094](https://github.com/PyTorchLightning/pytorch-lightning/pull/8094)) * Moved result teardown to the loops ([#8245](https://github.com/PyTorchLightning/pytorch-lightning/pull/8245)) + * Improve `Loop` API to better handle children `state_dict` and `progress` ([#8334](https://github.com/PyTorchLightning/pytorch-lightning/pull/8334)) - Refactored logging * Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736)) @@ -291,9 +292,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `Trainer(resume_from_checkpoint=...)` now restores the model directly after `LightningModule.setup()`, which is before `LightningModule.configure_sharded_model()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652)) -- Improve `Loop` API to better handle children `state_dict` and `progress` ([#8334](https://github.com/PyTorchLightning/pytorch-lightning/pull/8334)) - - ### Deprecated From 8c0426b8e82882e7209b4fae6bed14efe7803e23 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 12 Jul 2021 13:35:31 +0200 Subject: [PATCH 062/157] update --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ef98267575d0d..0fb37fa52af79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -375,6 +375,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed deprecated `optimizer` argument in `LightningModule.manual_backward()`; Toggling optimizers in manual optimization should be done using `LightningModule.{un}toggle_optimizer()` ([#8287](https://github.com/PyTorchLightning/pytorch-lightning/pull/8287)) + + ### Fixed - Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877)) From b7c411325d968518285b81acffc944a9f2e97a1b Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 12 Jul 2021 14:05:50 +0200 Subject: [PATCH 063/157] update on comments --- tests/loops/test_loop_progress_integration.py | 9 ++-- tests/loops/test_loops.py | 49 +++++++++++-------- 2 files changed, 34 insertions(+), 24 deletions(-) diff --git a/tests/loops/test_loop_progress_integration.py b/tests/loops/test_loop_progress_integration.py index 82d56fb5dd872..465ec7ad15655 100644 --- a/tests/loops/test_loop_progress_integration.py +++ b/tests/loops/test_loop_progress_integration.py @@ -1,4 +1,5 @@ from pytorch_lightning import Trainer +from tests.loops.test_loops import _collect_loop_progress def test_loop_progress_integration(): @@ -8,11 +9,11 @@ def test_loop_progress_integration(): assert trainer.test_loop.progress is not trainer.predict_loop.progress # check the validation progresses are not shared assert trainer.fit_loop.epoch_loop.val_loop.progress is not trainer.validate_loop.progress - expected = trainer.fit_loop.loop_progress["epoch_loop"]["progress"] + expected = _collect_loop_progress(trainer.fit_loop)["epoch_loop"]["progress"] assert expected == trainer.fit_loop.epoch_loop.progress - expected = trainer.fit_loop.loop_progress["epoch_loop"]["batch_loop"]["progress"] + expected = 
_collect_loop_progress(trainer.fit_loop)["epoch_loop"]["batch_loop"]["progress"] assert expected == trainer.fit_loop.epoch_loop.batch_loop.progress - expected = trainer.fit_loop.loop_progress["epoch_loop"]["val_loop"]["progress"] + expected = _collect_loop_progress(trainer.fit_loop)["epoch_loop"]["val_loop"]["progress"] assert expected == trainer.fit_loop.epoch_loop.val_loop.progress - expected = trainer.fit_loop.loop_progress["epoch_loop"]["val_loop"]["epoch_loop"]["progress"] + expected = _collect_loop_progress(trainer.fit_loop)["epoch_loop"]["val_loop"]["epoch_loop"]["progress"] assert expected == trainer.fit_loop.epoch_loop.val_loop.epoch_loop.progress diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index aa1a0a74750a3..70e2ca7a62d3e 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import OrderedDict from copy import deepcopy from dataclasses import dataclass from typing import Any, Dict, Iterator @@ -162,15 +161,20 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: loop_parent.trainer = Trainer() assert loop_child.trainer == loop_parent.trainer - assert state_dict == OrderedDict([('state_dict', { - 'a': 1 - }), ('progress', { - 'increment': 0 - }), ('loop_child.state_dict', { - 'a': 2 - }), ('loop_child.progress', { - 'increment': 0 - })]) + assert state_dict == { + 'state_dict': { + 'a': 1 + }, + 'progress': { + 'increment': 0 + }, + 'loop_child.state_dict': { + 'a': 2 + }, + 'loop_child.progress': { + 'increment': 0 + } + } loop_parent.progress @@ -190,15 +194,20 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: assert not loop_parent.restarting state_dict = loop_parent.state_dict() - assert state_dict == OrderedDict([('state_dict', { - 'a': 1 - }), ('progress', { - 'increment': 1 - }), ('loop_child.state_dict', { - 'a': 3 - }), ('loop_child.progress', { - 'increment': 0 - })]) + assert state_dict == { + 'state_dict': { + 'a': 1 + }, + 'progress': { + 'increment': 1 + }, + 'loop_child.state_dict': { + 'a': 3 + }, + 'loop_child.progress': { + 'increment': 0 + } + } loop_parent = Simple(1) loop_child = Simple(2) @@ -209,7 +218,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: del loop_parent.loop_child state_dict = loop_parent.state_dict() - assert state_dict == OrderedDict([('state_dict', {'a': 1}), ('progress', {'increment': 1})]) + assert state_dict == {'state_dict': {'a': 1}, 'progress': {'increment': 1}} grand_loop_parent = Simple(0) loop_parent = Simple(1) From 7e0456b23c75d3fac2b78cb6e7619a8ab717efac Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 12 Jul 2021 15:00:38 +0200 Subject: [PATCH 064/157] CHANGELOG --- CHANGELOG.md | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0fb37fa52af79..14c9adc46d25a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,7 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Progress tracking * Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574), [#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140)) * Add `{,load_}state_dict` to the progress tracking dataclasses ([#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140)) - * Connect the progress tracking dataclasses to the loops ([#8244](https://github.com/PyTorchLightning/pytorch-lightning/pull/8244)) + * Connect the progress tracking dataclasses to the loops ([#8244](https://github.com/PyTorchLightning/pytorch-lightning/pull/8244), [#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) - Added support for passing a `LightningDataModule` positionally as the second argument to `trainer.{validate,test,predict}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431)) @@ -146,10 +146,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `restore` function and `restarting` attribute to base `Loop` ([#8247](https://github.com/PyTorchLightning/pytorch-lightning/pull/8247)) -- Added `FastForwardSampler` and `CaptureIterableDataset` ([#8307](https://github.com/PyTorchLightning/pytorch-lightning/pull/8307)) - - -- Added `progress` tracking on loops ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) +- Added `FastForwardSampler` and `CaptureIterableDataset` ([#8307](https://github.com/PyTorchLightning/pytorch-lightning/pull/8307))`` - Added support for `save_hyperparameters` in `LightningDataModule` ([#3792](https://github.com/PyTorchLightning/pytorch-lightning/pull/3792)) @@ -375,8 +372,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed deprecated `optimizer` argument in `LightningModule.manual_backward()`; Toggling optimizers in manual optimization should be done using `LightningModule.{un}toggle_optimizer()` ([#8287](https://github.com/PyTorchLightning/pytorch-lightning/pull/8287)) - - ### Fixed - Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877)) From c2665328ee258702373d944c3d442a16b4eb97df Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 12 Jul 2021 15:01:58 +0200 Subject: [PATCH 065/157] CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 14c9adc46d25a..2a9cd56df7478 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -146,7 +146,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added `restore` function and `restarting` attribute to base `Loop` ([#8247](https://github.com/PyTorchLightning/pytorch-lightning/pull/8247)) -- Added `FastForwardSampler` and `CaptureIterableDataset` ([#8307](https://github.com/PyTorchLightning/pytorch-lightning/pull/8307))`` +- Added `FastForwardSampler` and `CaptureIterableDataset` ([#8307](https://github.com/PyTorchLightning/pytorch-lightning/pull/8307)) - Added support for `save_hyperparameters` in `LightningDataModule` ([#3792](https://github.com/PyTorchLightning/pytorch-lightning/pull/3792)) From 30ddd1030f9d0fc779a03b243927089bbdbb64ec Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 12 Jul 2021 15:35:07 +0200 Subject: [PATCH 066/157] Update pytorch_lightning/loops/base.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/loops/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 9997baac79cc5..a8173d523de3d 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -56,7 +56,7 @@ def trainer(self) -> Optional['pl.Trainer']: @trainer.setter def trainer(self, trainer: 'pl.Trainer'): - """Connect the Trainer to itself and all its children loops""" + """Connects this loop's trainer and it's children""" if not isinstance(trainer, pl.Trainer): raise MisconfigurationException( f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}." From ffc6ca71f6938a9594592a2897d4331a01303c90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 12 Jul 2021 16:21:32 +0200 Subject: [PATCH 067/157] whitespace suggestions --- pytorch_lightning/loops/batch/training_batch_loop.py | 1 + pytorch_lightning/loops/dataloader/evaluation_loop.py | 1 + pytorch_lightning/loops/dataloader/prediction_loop.py | 1 + 3 files changed, 3 insertions(+) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index e76ebf704cf38..d27b2a34987b5 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -52,6 +52,7 @@ def __init__(self) -> None: self.split_idx: Optional[int] = None self.progress = BatchProgress() self.optim_progress = OptimizationProgress() + self._warning_cache: WarningCache = WarningCache() self._hiddens: Optional[Tensor] = None self._optimizer_freq_cumsum: Optional[int] = None diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 5dc0270f58774..ba554bf9c1a29 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -35,6 +35,7 @@ def __init__(self): self.outputs = [] self.progress = DataLoaderProgress() self.epoch_loop = EvaluationEpochLoop() + self._results = ResultCollection(training=False) self._max_batches: Optional[Union[int, Sequence[int]]] = None self._has_run: bool = False diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index 1bdd38ed950b0..6a58a2c78f4b1 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -21,6 +21,7 @@ def __init__(self): self.epoch_batch_indices: Optional[List[List[int]]] = None self.progress = EpochLoopProgress() self.epoch_loop = 
PredictionEpochLoop() + self._results = None # for `trainer._results` access self._return_predictions: bool = False From 9ac0b61967619eb81c493ca7b5479c2c2601bfa4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Jul 2021 14:22:45 +0000 Subject: [PATCH 068/157] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/dataloader/prediction_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index 6a58a2c78f4b1..345a6296578f5 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -21,7 +21,7 @@ def __init__(self): self.epoch_batch_indices: Optional[List[List[int]]] = None self.progress = EpochLoopProgress() self.epoch_loop = PredictionEpochLoop() - + self._results = None # for `trainer._results` access self._return_predictions: bool = False From 8ddb020530277bfc2462cbeed0d6c8e821d15b98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 12 Jul 2021 16:23:07 +0200 Subject: [PATCH 069/157] make fault_tolerant_enabled protected --- pytorch_lightning/loops/dataloader/prediction_loop.py | 2 +- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 4 ++-- pytorch_lightning/utilities/imports.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index 6a58a2c78f4b1..345a6296578f5 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -21,7 +21,7 @@ def __init__(self): self.epoch_batch_indices: Optional[List[List[int]]] = None self.progress = EpochLoopProgress() self.epoch_loop = PredictionEpochLoop() - + self._results = None # for `trainer._results` access self._return_predictions: bool = False diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index df1328d668305..40b59f8c93f54 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -23,7 +23,7 @@ from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import fault_tolerant_enabled +from pytorch_lightning.utilities.imports import _fault_tolerant_enabled from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS if _OMEGACONF_AVAILABLE: @@ -347,7 +347,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: 'state_dict': self.trainer.accelerator.lightning_module_state_dict(), } - if fault_tolerant_enabled(): + if _fault_tolerant_enabled(): checkpoint.update({"loops": self.get_loops_state_dict()}) if not weights_only: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a9e316bfc5a2f..c7e5224593744 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -76,7 +76,7 @@ from 
pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import fault_tolerant_enabled +from pytorch_lightning.utilities.imports import _fault_tolerant_enabled from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -1252,6 +1252,6 @@ def _log_device_info(self) -> None: ) def on_expection(self): - if fault_tolerant_enabled(): + if _fault_tolerant_enabled(): # save a checkpoint for fault tolerant training self.fit_loop._check_checkpoint_callback(True) diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index fdd5382ca751d..347bcd1ecf544 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -104,5 +104,5 @@ def _compare_version(package: str, op, version) -> bool: _IPU_AVAILABLE = False -def fault_tolerant_enabled(): +def _fault_tolerant_enabled(): return os.getenv("PL_FAULT_TOLERANT_TRAINING", "0") == "1" From 50b6f49c1801955c71c31fcd9ea081e7dadd7efc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 12 Jul 2021 16:24:10 +0200 Subject: [PATCH 070/157] whitespace fixes around Args --- pytorch_lightning/trainer/progress.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index db2321365bfa6..e5746fa05b283 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -35,6 +35,7 @@ def from_state_dict(cls, state_dict: dict) -> "BaseProgress": class Tracker(BaseProgress): """ Track an event's progress. + Args: ready: Intended to track the number of events ready to start. started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs). @@ -86,6 +87,7 @@ def reset_on_restart(self): class Progress(BaseProgress): """ Track aggregated and current progress. + Args: total: Intended to track the total progress of an event current: Intended to track the current progress of an event @@ -131,6 +133,7 @@ def load_state_dict(self, state_dict: dict) -> None: class BatchProgress(Progress): """ Tracks the batch progress + Args: total: Tracks the total epoch progress current: Tracks the current epoch progress @@ -141,6 +144,7 @@ class BatchProgress(Progress): class TrainingEpochProgress(Progress): """ Tracks the batch progress + Args: total: Tracks the total epoch progress current: Tracks the current epoch progress @@ -167,6 +171,7 @@ class EpochProgress(Progress): """ Tracks the epoch progress These counters are local to a trainer rank. By default, they are not globally synced across all ranks. + Args: total: Tracks the total epoch progress current: Tracks the current epoch progress @@ -188,6 +193,7 @@ def load_state_dict(self, state_dict: dict) -> None: class OptimizerProgress(BaseProgress): """ Track optimizer progress. + Args: step: Tracks ``optimizer.step`` calls. zero_grad: Tracks ``optimizer.zero_grad`` calls. @@ -209,6 +215,7 @@ def load_state_dict(self, state_dict: dict) -> None: class OptimizationProgress(BaseProgress): """ Track optimization progress. + Args: optimizer: Tracks optimizer progress. scheduler: Tracks scheduler progress. 
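As a quick orientation on the counter semantics documented above (a minimal sketch, assuming the `Progress`/`Tracker` API as it stands at this point in the series, including the `state_dict` support added earlier), each tracked event advances through ready -> started -> processed -> completed on both the `total` and `current` trackers:

    from pytorch_lightning.trainer.progress import Progress, Tracker

    # one "event" (e.g. a batch) walked through its four stages
    batch = Progress()
    batch.increment_ready()      # fetched / about to start
    batch.increment_started()    # the on_*_start hooks ran
    batch.increment_processed()  # the step itself finished
    batch.increment_completed()  # the on_*_end hooks ran

    assert batch.total == Tracker(ready=1, started=1, processed=1, completed=1)
    assert batch.current == batch.total  # nothing has been reset yet

    # the counters serialize for fault-tolerant checkpointing
    state = batch.state_dict()  # {'total': {...}, 'current': {...}}
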
@@ -242,6 +249,7 @@ class EpochLoopProgress(BaseProgress): """ Tracks epoch loop progress. These counters are local to a trainer rank. By default, they are not globally synced across all ranks. + Args: epoch: Tracks epochs progress. """ From 8e9682ed8e3abda025179aa58dc3f5073f3d0c2a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Jul 2021 14:25:40 +0000 Subject: [PATCH 071/157] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index e5746fa05b283..e37f14960220d 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -35,7 +35,7 @@ def from_state_dict(cls, state_dict: dict) -> "BaseProgress": class Tracker(BaseProgress): """ Track an event's progress. - + Args: ready: Intended to track the number of events ready to start. started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs). From 0838d7a7fa957f1f20cbd5afe9128400a94101b0 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 12 Jul 2021 20:21:39 +0200 Subject: [PATCH 072/157] update --- tests/loops/test_loop_progress_integration.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/loops/test_loop_progress_integration.py b/tests/loops/test_loop_progress_integration.py index 465ec7ad15655..4395cb5cdcf3b 100644 --- a/tests/loops/test_loop_progress_integration.py +++ b/tests/loops/test_loop_progress_integration.py @@ -9,11 +9,8 @@ def test_loop_progress_integration(): assert trainer.test_loop.progress is not trainer.predict_loop.progress # check the validation progresses are not shared assert trainer.fit_loop.epoch_loop.val_loop.progress is not trainer.validate_loop.progress - expected = _collect_loop_progress(trainer.fit_loop)["epoch_loop"]["progress"] - assert expected == trainer.fit_loop.epoch_loop.progress - expected = _collect_loop_progress(trainer.fit_loop)["epoch_loop"]["batch_loop"]["progress"] - assert expected == trainer.fit_loop.epoch_loop.batch_loop.progress - expected = _collect_loop_progress(trainer.fit_loop)["epoch_loop"]["val_loop"]["progress"] - assert expected == trainer.fit_loop.epoch_loop.val_loop.progress - expected = _collect_loop_progress(trainer.fit_loop)["epoch_loop"]["val_loop"]["epoch_loop"]["progress"] - assert expected == trainer.fit_loop.epoch_loop.val_loop.epoch_loop.progress + generated = _collect_loop_progress(trainer.fit_loop)["epoch_loop"] + assert generated["progress"] is trainer.fit_loop.epoch_loop.progress + assert generated["batch_loop"]["progress"] is trainer.fit_loop.epoch_loop.batch_loop.progress + assert generated["val_loop"]["progress"] is trainer.fit_loop.epoch_loop.val_loop.progress + assert generated["val_loop"]["epoch_loop"]["progress"] is trainer.fit_loop.epoch_loop.val_loop.epoch_loop.progress From 107e1437dd6ae90e161ac365d38bae06f0e52a01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 13 Jul 2021 00:38:50 +0200 Subject: [PATCH 073/157] typo it's -> its --- pytorch_lightning/loops/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index a8173d523de3d..9209dcb993284 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -56,7 +56,7 @@ 
def trainer(self) -> Optional['pl.Trainer']: @trainer.setter def trainer(self, trainer: 'pl.Trainer'): - """Connects this loop's trainer and it's children""" + """Connects this loop's trainer and its children""" if not isinstance(trainer, pl.Trainer): raise MisconfigurationException( f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}." From e49cd508ec3dcb4b09733d0bc2b80e3f7f545146 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 13 Jul 2021 00:39:13 +0200 Subject: [PATCH 074/157] fix copy-paste typo in progress docstring --- pytorch_lightning/trainer/progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index e37f14960220d..66921ccc94a4d 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -143,7 +143,7 @@ class BatchProgress(Progress): @dataclass class TrainingEpochProgress(Progress): """ - Tracks the batch progress + Tracks the epoch progress Args: total: Tracks the total epoch progress From 2e0423a046a8f55cafacd886e1ec834f17d9b1cf Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 13 Jul 2021 14:11:23 +0200 Subject: [PATCH 075/157] Delete classes --- .../loops/batch/training_batch_loop.py | 8 +-- .../loops/dataloader/prediction_loop.py | 12 ++--- .../loops/epoch/prediction_epoch_loop.py | 13 +---- pytorch_lightning/trainer/progress.py | 54 ++----------------- 4 files changed, 13 insertions(+), 74 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index d27b2a34987b5..334a80241fe59 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -28,7 +28,7 @@ from pytorch_lightning.loops.base import Loop from pytorch_lightning.plugins import ParallelPlugin from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import BatchProgress, OptimizationProgress +from pytorch_lightning.trainer.progress import OptimizationProgress, Progress from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -50,7 +50,7 @@ def __init__(self) -> None: self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20) self.batch_idx: int = 0 self.split_idx: Optional[int] = None - self.progress = BatchProgress() + self.progress = Progress() self.optim_progress = OptimizationProgress() self._warning_cache: WarningCache = WarningCache() @@ -437,9 +437,9 @@ def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None: Args: optimizer: the current optimizer """ - self.optim_progress.optimizer.zero_grad.increment_started() - self.trainer.call_hook('on_before_zero_grad', optimizer) self.optim_progress.optimizer.zero_grad.increment_ready() + self.trainer.call_hook('on_before_zero_grad', optimizer) + self.optim_progress.optimizer.zero_grad.increment_started() def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None: """Zeroes out all gradients of parameters optimized by the current optimizer. 
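The reordering above changes when the `zero_grad` counters tick relative to the `on_before_zero_grad` hook. A standalone sketch of the resulting sequence (the hook call is stubbed out; `OptimizationProgress` is assumed importable at this revision):

    import torch
    from pytorch_lightning.trainer.progress import OptimizationProgress

    optim_progress = OptimizationProgress()
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=0.1)

    # mirrors _on_before_zero_grad / _optimizer_zero_grad after this patch:
    # mark ready, run the hook, then mark started and zero the gradients
    optim_progress.optimizer.zero_grad.increment_ready()
    # trainer.call_hook("on_before_zero_grad", optimizer) would run here
    optim_progress.optimizer.zero_grad.increment_started()
    optimizer.zero_grad()

    assert optim_progress.optimizer.zero_grad.total.ready == 1
    assert optim_progress.optimizer.zero_grad.total.started == 1
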
diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index 345a6296578f5..51eccdf202051 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -7,7 +7,7 @@ from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop from pytorch_lightning.plugins import DDPSpawnPlugin -from pytorch_lightning.trainer.progress import EpochLoopProgress +from pytorch_lightning.trainer.progress import DataLoaderProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import _PREDICT_OUTPUT @@ -19,7 +19,7 @@ def __init__(self): super().__init__() self.predictions: Optional[List[List[Any]]] = None self.epoch_batch_indices: Optional[List[List[int]]] = None - self.progress = EpochLoopProgress() + self.progress = DataLoaderProgress() self.epoch_loop = PredictionEpochLoop() self._results = None # for `trainer._results` access @@ -75,14 +75,10 @@ def done(self) -> bool: def skip(self) -> bool: return sum(self.max_batches) == 0 - def connect( - self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any - ) -> None: + def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - self.epoch_loop.connect(trainer, progress=self.progress.epoch) + self.epoch_loop.connect(trainer) def reset(self) -> None: """Resets the internal state of the loop for a new run""" diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index ea03be5ef0096..f94e106a8c444 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -3,10 +3,9 @@ from deprecate import void -import pytorch_lightning as pl from pytorch_lightning.loops.base import Loop from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper -from pytorch_lightning.trainer.progress import EpochProgress +from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.utilities.warnings import WarningCache @@ -18,21 +17,13 @@ def __init__(self) -> None: self.return_predictions: bool = False self.predictions: List[Any] = [] self.current_batch_indices: List[int] = [] - self.progress = EpochProgress() + self.progress = Progress() self._dl_max_batches: Optional[int] = None self._num_dataloaders: Optional[int] = None self._warning_cache = WarningCache() self._all_batch_indices: List[int] = [] - def connect( - self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochProgress] = None, **kwargs: Any - ) -> None: - """Connects the loop with necessary arguments like the trainer""" - super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - @property def done(self) -> bool: """Ends prediction when the iteration count exceeds the total number of available batches""" diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 66921ccc94a4d..78215031a76e8 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -41,6 +41,7 @@ class Tracker(BaseProgress): started: Intended to be 
incremented after the event is started (e.g. after ``on_*_start`` runs). processed: Intended to be incremented after the event is processed. completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs). + Attributes set to ``None`` are treated as unused and are restricted. """ @@ -129,17 +130,6 @@ def load_state_dict(self, state_dict: dict) -> None: self.current.load_state_dict(state_dict["current"]) -@dataclass -class BatchProgress(Progress): - """ - Tracks the batch progress - - Args: - total: Tracks the total epoch progress - current: Tracks the current epoch progress - """ - - @dataclass class TrainingEpochProgress(Progress): """ @@ -158,34 +148,19 @@ def load_state_dict(self, state_dict: dict) -> None: @dataclass class DataLoaderProgress(Progress): - - dataloader_idx: int = 0 - - def load_state_dict(self, state_dict: dict) -> None: - super().load_state_dict(state_dict) - self.dataloader_idx = state_dict["dataloader_idx"] - - -@dataclass -class EpochProgress(Progress): """ - Tracks the epoch progress + Tracks the data-loader progress These counters are local to a trainer rank. By default, they are not globally synced across all ranks. Args: total: Tracks the total epoch progress current: Tracks the current epoch progress - batch: Tracks batch progress. + dataloader_idx: The index of the current dataloader. """ dataloader_idx: int = 0 - batch: BatchProgress = field(default_factory=BatchProgress) - - def reset_on_epoch(self) -> None: - self.batch.current.reset() def load_state_dict(self, state_dict: dict) -> None: super().load_state_dict(state_dict) - self.batch.load_state_dict(state_dict["batch"]) self.dataloader_idx = state_dict["dataloader_idx"] @@ -242,26 +217,3 @@ def load_state_dict(self, state_dict: dict) -> None: self.optimizer.load_state_dict(state_dict["optimizer"]) self.scheduler.load_state_dict(state_dict["scheduler"]) self.optimizer_idx = state_dict["optimizer_idx"] - - -@dataclass -class EpochLoopProgress(BaseProgress): - """ - Tracks epoch loop progress. - These counters are local to a trainer rank. By default, they are not globally synced across all ranks. - - Args: - epoch: Tracks epochs progress. 
- """ - epoch: EpochProgress = field(default_factory=EpochProgress) - - def increment_epoch_completed(self) -> None: - self.epoch.increment_completed() - self.reset_on_epoch() - - def reset_on_epoch(self) -> None: - self.epoch.reset_on_epoch() - self.epoch.current.reset() - - def load_state_dict(self, state_dict: dict) -> None: - self.epoch.load_state_dict(state_dict["epoch"]) From 7caca875f18e02d8fd0725942561d027ae987687 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 13 Jul 2021 14:18:35 +0200 Subject: [PATCH 076/157] Minor change --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 40b59f8c93f54..c4114c554962b 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -257,7 +257,7 @@ def restore_loops(self) -> None: if not self._loaded_checkpoint: return - state_dict = self._loaded_checkpoint.get("loops", None) + state_dict = self._loaded_checkpoint.get("loops") if state_dict: self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"]) self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) @@ -346,9 +346,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: 'pytorch-lightning_version': pl.__version__, 'state_dict': self.trainer.accelerator.lightning_module_state_dict(), } - if _fault_tolerant_enabled(): - checkpoint.update({"loops": self.get_loops_state_dict()}) + checkpoint["loops"] = self.get_loops_state_dict() if not weights_only: # dump callbacks From 2800eaec82a28373a6c9966fce3384868c7a8d11 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 13 Jul 2021 14:19:31 +0200 Subject: [PATCH 077/157] docs --- pytorch_lightning/trainer/progress.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 78215031a76e8..5f153f45002c0 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -194,12 +194,13 @@ class OptimizationProgress(BaseProgress): Args: optimizer: Tracks optimizer progress. scheduler: Tracks scheduler progress. + optimizer_idx: The index of the current optimizer. 
""" # TODO: support for multiple optimizers - optimizer_idx: int = 0 optimizer: OptimizerProgress = field(default_factory=OptimizerProgress) scheduler: Progress = field(default_factory=lambda: Progress.from_defaults(started=None, processed=None)) + optimizer_idx: int = 0 @property def optimizer_steps(self) -> int: From feec34fba58007cc4beea6463c21ad389853fa13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 13 Jul 2021 14:30:38 +0200 Subject: [PATCH 078/157] protected get_loops_state --- .../trainer/connectors/checkpoint_connector.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index c4114c554962b..21bd347958237 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -347,7 +347,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: 'state_dict': self.trainer.accelerator.lightning_module_state_dict(), } if _fault_tolerant_enabled(): - checkpoint["loops"] = self.get_loops_state_dict() + checkpoint["loops"] = self._get_loops_state_dict() if not weights_only: # dump callbacks @@ -387,14 +387,6 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint - def get_loops_state_dict(self): - return { - "fit_loop": self.trainer.fit_loop.state_dict(), - "validate_loop": self.trainer.validate_loop.state_dict(), - "test_loop": self.trainer.test_loop.state_dict(), - "predict_loop": self.trainer.predict_loop.state_dict(), - } - def hpc_load(self, checkpoint_path: str) -> None: """ Attempts to restore the full training and model state from a HPC checkpoint file. @@ -453,3 +445,11 @@ def save_checkpoint(self, filepath, weights_only: bool = False) -> None: """ _checkpoint = self.dump_checkpoint(weights_only) self.trainer.accelerator.save_checkpoint(_checkpoint, filepath) + + def _get_loops_state_dict(self): + return { + "fit_loop": self.trainer.fit_loop.state_dict(), + "validate_loop": self.trainer.validate_loop.state_dict(), + "test_loop": self.trainer.test_loop.state_dict(), + "predict_loop": self.trainer.predict_loop.state_dict(), + } \ No newline at end of file From ccdd09d700124b242dbfb9410115ac16b3d24d6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 13 Jul 2021 14:34:35 +0200 Subject: [PATCH 079/157] merge restore_loops with restore_progress --- .../connectors/checkpoint_connector.py | 33 ++++++++----------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 21bd347958237..7c09804add72b 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -161,13 +161,11 @@ def restore_training_state(self) -> None: # restore precision plugin (scaler etc.) self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint) - # restore progress (loops etc.) - self.restore_progress() + # restore loops and their progress + self.restore_loops() self.restore_optimizers_and_schedulers() - self.restore_loops() - def restore_callbacks(self) -> None: """ Restores all callbacks from the pre-loaded checkpoint. 
""" if not self._loaded_checkpoint: @@ -182,10 +180,10 @@ def restore_callbacks(self) -> None: ) self.trainer.on_load_checkpoint(self._loaded_checkpoint) - def restore_progress(self) -> None: + def restore_loops(self) -> None: """ - Restores the training progress from the pre-loaded checkpoint. This currently includes only the global step - and current epoch. + Restores the loop progress from the pre-loaded checkpoint. + Calls hooks on the loops to give it a chance to restore its state from the checkpoint. """ if not self._loaded_checkpoint: return @@ -212,6 +210,13 @@ def restore_progress(self) -> None: " consider using an end of epoch checkpoint." ) + state_dict = self._loaded_checkpoint.get("loops") + if state_dict: + self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"]) + self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) + self.trainer.test_loop.load_state_dict(state_dict["test_loop"]) + self.trainer.predict_loop.load_state_dict(state_dict["predict_loop"]) + def restore_optimizers_and_schedulers(self) -> None: """ Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint. """ if not self._loaded_checkpoint: @@ -252,18 +257,6 @@ def restore_lr_schedulers(self) -> None: for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers): scheduler['scheduler'].load_state_dict(lrs_state) - def restore_loops(self) -> None: - """ Calls hooks on the loops to give it a chance to restore its state from the checkpoint. """ - if not self._loaded_checkpoint: - return - - state_dict = self._loaded_checkpoint.get("loops") - if state_dict: - self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"]) - self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) - self.trainer.test_loop.load_state_dict(state_dict["test_loop"]) - self.trainer.predict_loop.load_state_dict(state_dict["predict_loop"]) - # ---------------------------------- # PRIVATE OPS # ---------------------------------- @@ -452,4 +445,4 @@ def _get_loops_state_dict(self): "validate_loop": self.trainer.validate_loop.state_dict(), "test_loop": self.trainer.test_loop.state_dict(), "predict_loop": self.trainer.predict_loop.state_dict(), - } \ No newline at end of file + } From 01768cb23a28daf00c684e80f8f337765e8bd723 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 13 Jul 2021 14:41:25 +0200 Subject: [PATCH 080/157] Fix tests after removals --- tests/trainer/test_progress.py | 72 ++++++---------------------------- 1 file changed, 11 insertions(+), 61 deletions(-) diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index 6322ab5be33bb..32d97be167b52 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -21,12 +21,12 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.progress import ( - BatchProgress, - EpochLoopProgress, - EpochProgress, + DataLoaderProgress, + OptimizationProgress, OptimizerProgress, Progress, Tracker, + TrainingEpochProgress, ) from tests.helpers import BoringModel @@ -79,20 +79,9 @@ def test_base_progress_from_defaults(): assert actual == expected -def test_epoch_loop_progress_increment_epoch(): - p = EpochLoopProgress() - p.increment_epoch_completed() - p.increment_epoch_completed() - assert p.epoch.total == Tracker(completed=2) - assert p.epoch.current == Tracker() - assert p.epoch.batch.current == Tracker() - - def test_epoch_loop_progress_increment_sequence(): """Test sequences for incrementing 
batches reads and epochs.""" - batch = BatchProgress(total=Tracker(started=None)) - epoch = EpochProgress(batch=batch) - loop = EpochLoopProgress(epoch=epoch) + batch = Progress(total=Tracker(started=None)) batch.increment_ready() assert batch.total == Tracker(ready=1, started=None) @@ -110,26 +99,6 @@ def test_epoch_loop_progress_increment_sequence(): assert batch.total == Tracker(ready=1, started=None, processed=1, completed=1) assert batch.current == Tracker(ready=1, processed=1, completed=1) - assert epoch.total == Tracker() - assert epoch.current == Tracker() - loop.increment_epoch_completed() - assert batch.total == Tracker(ready=1, started=None, processed=1, completed=1) - assert batch.current == Tracker() - assert epoch.total == Tracker(completed=1) - assert epoch.current == Tracker() - - batch.increment_ready() - assert batch.total == Tracker(ready=2, started=None, processed=1, completed=1) - assert batch.current == Tracker(ready=1) - assert epoch.total == Tracker(completed=1) - assert epoch.current == Tracker() - - loop.reset_on_epoch() - assert batch.total == Tracker(ready=2, started=None, processed=1, completed=1) - assert batch.current == Tracker() - assert epoch.total == Tracker(completed=1) - assert epoch.current == Tracker() - def test_optimizer_progress_default_factory(): """ @@ -144,32 +113,13 @@ def test_optimizer_progress_default_factory(): assert p2.step.total.completed == 0 -def test_epoch_loop_progress_serialization(): - loop = EpochLoopProgress() - loop.epoch.dataloader_idx = 1 - _ = deepcopy(loop) - state_dict = loop.state_dict() - - # yapf: disable - assert state_dict == { - 'epoch': { - # number of times `validate` has been called - 'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - # either 0 or 1 as `max_epochs` does not apply to the `validate` loop - 'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - 'batch': { - # number of batches across `validate` calls - 'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - # number of batches this `validate` call - 'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - }, - 'dataloader_idx': 1 - } - } - # yapf: enable - - new_loop = EpochLoopProgress.from_state_dict(state_dict) - assert loop == new_loop +def test_deepcopy(): + _ = deepcopy(Tracker()) + _ = deepcopy(Progress()) + _ = deepcopy(TrainingEpochProgress()) + _ = deepcopy(DataLoaderProgress()) + _ = deepcopy(OptimizerProgress()) + _ = deepcopy(OptimizationProgress()) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) From 71e05d3553b480fcfb7813d4ed8e3effe34b7d1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 10:04:18 +0200 Subject: [PATCH 081/157] explicit save with trainer.save_checkpoint() --- pytorch_lightning/trainer/trainer.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c7e5224593744..bb65ff47817e7 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Trainer to automate the training.""" import logging +import os import traceback import warnings from datetime import timedelta @@ -1010,7 +1011,7 @@ def _run_train(self) -> None: self.training_type_plugin.reconciliate_processes(traceback.format_exc()) # give accelerators a chance to finish self.accelerator.on_train_end() - self.on_expection() + self._on_expection() # reset bookkeeping self.state.stage = None raise @@ -1251,7 +1252,10 @@ def _log_device_info(self) -> None: " `Trainer(ipus=8)` or script `--ipus=8`." ) - def on_expection(self): - if _fault_tolerant_enabled(): - # save a checkpoint for fault tolerant training - self.fit_loop._check_checkpoint_callback(True) + def _on_expection(self): + if not self.is_global_zero or not _fault_tolerant_enabled(): + return + + # save a checkpoint for fault tolerant training + file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt") + self.save_checkpoint(file_path) From 6ca7b9cd04dd5ac7e4f9002965a6f05193e8b96a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 10:43:24 +0200 Subject: [PATCH 082/157] update setter for trainer and connect all loops --- pl_examples/loop_examples/example3.py | 10 +++++++--- .../loops/batch/training_batch_loop.py | 18 +++--------------- .../loops/dataloader/evaluation_loop.py | 16 ---------------- .../loops/dataloader/prediction_loop.py | 18 +----------------- .../loops/epoch/evaluation_epoch_loop.py | 11 +++-------- .../loops/epoch/prediction_epoch_loop.py | 11 +++-------- .../loops/epoch/training_epoch_loop.py | 13 ++++++------- pytorch_lightning/loops/fit_loop.py | 7 ------- pytorch_lightning/trainer/trainer.py | 11 +++++++++++ 9 files changed, 34 insertions(+), 81 deletions(-) diff --git a/pl_examples/loop_examples/example3.py b/pl_examples/loop_examples/example3.py index 317293c362746..a11f7aa84a2d5 100644 --- a/pl_examples/loop_examples/example3.py +++ b/pl_examples/loop_examples/example3.py @@ -37,13 +37,17 @@ def run(): # call connect on the existing, default fit_loop.epoch_loop trainer.fit_loop.epoch_loop.connect(batch_loop=new_batch_loop, val_loop=new_val_loop) - # the new batch loop is registered and the trainer got linked internally - assert trainer.fit_loop.epoch_loop.batch_loop == new_batch_loop - assert trainer.fit_loop.epoch_loop.batch_loop.trainer == trainer + # the new batch loop is registered + assert trainer.fit_loop.epoch_loop.batch_loop is new_batch_loop + + # the trainer is not yet registered, will be done by the trainer internally + assert trainer.fit_loop.epoch_loop.batch_loop.trainer is None # this uses the new custom batch loop trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) + assert trainer.fit_loop.epoch_loop.batch_loop.trainer is trainer + if __name__ == '__main__': run() diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index a56e123aba923..690263f3aa992 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -59,21 +59,6 @@ def __init__(self) -> None: self._remaining_splits: Optional[List[Any]] = None self._skip_backward: bool = False - # def connect( - # self, - # trainer: 'pl.Trainer', - # *args: Any, - # progress: Optional[BatchProgress] = None, - # optim_progress: Optional[OptimizationProgress] = None, - # **kwargs: Any - # ) -> None: - # """Called by the Trainer. 
Connects a Loop with all the necessary components like progress, etc.""" - # super().connect(trainer, *args, **kwargs) - # if progress is not None: - # self.progress = progress - # if optim_progress is not None: - # self.optim_progress = optim_progress - @property def done(self) -> bool: """Returns if all batch splits have been processed already""" @@ -86,6 +71,9 @@ def optimizer_freq_cumsum(self) -> int: self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) return self._optimizer_freq_cumsum + def connect(self, **kwargs: "Loop") -> None: + raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.") + def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict: """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 933890c26d6c6..ac582861ac821 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -40,11 +40,6 @@ def __init__(self): self._max_batches: Optional[Union[int, Sequence[int]]] = None self._has_run: bool = False - @DataLoaderLoop.trainer.setter - def trainer(self, trainer): - self._trainer = trainer - self.epoch_loop.trainer = trainer - @property def num_dataloaders(self) -> int: """Returns the total number of dataloaders""" @@ -73,17 +68,6 @@ def predictions(self): def connect(self, epoch_loop: EvaluationEpochLoop): """Connect the evaluation epoch loop with this loop.""" self.epoch_loop = epoch_loop - if self.trainer is not None: - self.epoch_loop.trainer = self.trainer - - # def connect( - # self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any - # ) -> None: - # """Called by the Trainer. 
Connects a Loop with all the necessary components like progress, etc.""" - # super().connect(trainer, *args, **kwargs) - # if progress is not None: - # self.progress = progress - # self.epoch_loop.connect(trainer, progress=self.progress.epoch) @property def done(self) -> bool: diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index 4d19d62f536a9..61794885f0a86 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -3,7 +3,6 @@ from deprecate.utils import void from torch.utils.data import DataLoader -import pytorch_lightning as pl from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop from pytorch_lightning.plugins import DDPSpawnPlugin @@ -25,11 +24,6 @@ def __init__(self): self._results = None # for `trainer._results` access self._return_predictions: bool = False - @DataLoaderLoop.trainer.setter - def trainer(self, trainer): - self._trainer = trainer - self.epoch_loop.trainer = trainer - @property def return_predictions(self) -> bool: """Whether to return the predictions or not""" @@ -81,18 +75,8 @@ def skip(self) -> bool: return sum(self.max_batches) == 0 def connect(self, epoch_loop: PredictionEpochLoop): + """Connect the prediction epoch loop with this loop.""" self.epoch_loop = epoch_loop - if self.trainer is not None: - self.epoch_loop.trainer = self.trainer - - # def connect( - # self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any - # ) -> None: - # """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" - # super().connect(trainer, *args, **kwargs) - # if progress is not None: - # self.progress = progress - # self.epoch_loop.connect(trainer, progress=self.progress.epoch) def reset(self) -> None: """Resets the internal state of the loop for a new run""" diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 5e723a1ba3b81..9d742571f8212 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -43,19 +43,14 @@ def __init__(self) -> None: self.outputs: List[STEP_OUTPUT] = [] self.progress = EpochProgress() - # def connect( - # self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochProgress] = None, **kwargs: Any - # ) -> None: - # """Called by the Trainer. 
Connects a Loop with all the necessary components like progress, etc.""" - # super().connect(trainer, *args, **kwargs) - # if progress is not None: - # self.progress = progress - @property def done(self) -> bool: """Returns ``True`` if the current iteration count reaches the number of dataloader batches.""" return self.iteration_count >= self.dl_max_batches + def connect(self, **kwargs: "Loop") -> None: + raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.") + def reset(self) -> None: """Resets the loop's internal state.""" self.iteration_count = 0 diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index 43ccbc74209c8..bc66f8570bc80 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -25,14 +25,6 @@ def __init__(self) -> None: self._warning_cache = WarningCache() self._all_batch_indices: List[int] = [] - # def connect( - # self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochProgress] = None, **kwargs: Any - # ) -> None: - # """Called by the Trainer. Connects a Loop with all the necessary components like progress, etc.""" - # super().connect(trainer, *args, **kwargs) - # if progress is not None: - # self.progress = progress - @property def done(self) -> bool: """Ends prediction when the iteration count exceeds the total number of available batches""" @@ -44,6 +36,9 @@ def should_store_predictions(self) -> bool: any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks) return self.return_predictions or any_pred + def connect(self, **kwargs: "Loop") -> None: + raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.") + def reset(self) -> None: """Resets the loops internal state""" self.iteration_count = 0 diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index c987080647df7..773622c097451 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -75,13 +75,12 @@ def done(self) -> bool: max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) - def connect(self, batch_loop: TrainingBatchLoop, val_loop: "loops.EvaluationLoop") -> None: - """Called by the Trainer. 
Connects a Loop with all the necessary components like progress, etc.""" - self.batch_loop = batch_loop - self.val_loop = val_loop - if self.trainer is not None: - self.batch_loop.trainer = self.trainer - self.val_loop.trainer = self.trainer + def connect(self, batch_loop: TrainingBatchLoop = None, val_loop: "loops.EvaluationLoop" = None) -> None: + """Optionally connect a custom batch or validation loop to this training epoch loop.""" + if batch_loop is not None: + self.batch_loop = batch_loop + if val_loop is not None: + self.val_loop = val_loop def reset(self) -> None: """Resets the internal state of the loop for a new run""" diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 2af5cc73986d5..3f99da9c68a08 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -47,11 +47,6 @@ def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = self.epoch_loop = None self.progress: Optional[FitLoopProgress] = None - @Loop.trainer.setter - def trainer(self, trainer): - self._trainer = trainer - self.epoch_loop.trainer = trainer - @property def current_epoch(self) -> int: """Return the current epoch""" @@ -169,8 +164,6 @@ def skip(self) -> bool: def connect(self, epoch_loop: TrainingEpochLoop): """Connects a training epoch loop to this fit loop.""" self.epoch_loop = epoch_loop - if self.trainer is not None: - self.epoch_loop.trainer = self.trainer def reset(self) -> None: """Resets the internal state of this loop""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ad68bcd71149c..cdf49a3ef482d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -370,15 +370,19 @@ def __init__( # default .fit() loop self.fit_loop = fit_loop + self.fit_loop.trainer = self # default .validate() loop self.validate_loop = EvaluationLoop() + self.fit_loop.trainer = self # default .test() loop self.test_loop = EvaluationLoop() + self.fit_loop.trainer = self # default .predict() loop self.predict_loop = PredictionLoop() + self.fit_loop.trainer = self # training state if weights_summary is not None and weights_summary not in ModelSummary.MODES: @@ -1056,6 +1060,8 @@ def _run_train(self) -> None: self.reset_train_val_dataloaders(model) try: + # reset trainer on this loop and all child loops in case user connected a custom loop + self.fit_loop.trainer = self self.fit_loop.run() except KeyboardInterrupt: rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...') @@ -1085,6 +1091,9 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT: # reload dataloaders self._evaluation_loop.reload_evaluation_dataloaders() + # reset trainer on this loop and all child loops in case user connected a custom loop + self._evaluation_loop.trainer = self + with self.profiler.profile(f"run_{self.state.stage}_evaluation"), torch.no_grad(): eval_loop_results = self._evaluation_loop.run() @@ -1099,6 +1108,8 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT: def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: self.reset_predict_dataloader(self.lightning_module) + # reset trainer on this loop and all child loops in case user connected a custom loop + self.predict_loop.trainer = self with torch.no_grad(): return self.predict_loop.run() From 704543b101c13bfd3144b4fec70d30c1554304a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 10:46:36 +0200 Subject: [PATCH 083/157] add missing types --- 
pytorch_lightning/loops/epoch/training_epoch_loop.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 773622c097451..6cd5d7d29f98c 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -75,7 +75,11 @@ def done(self) -> bool: max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) - def connect(self, batch_loop: TrainingBatchLoop = None, val_loop: "loops.EvaluationLoop" = None) -> None: + def connect( + self, + batch_loop: Optional[TrainingBatchLoop] = None, + val_loop: Optional["loops.EvaluationLoop"] = None, + ) -> None: """Optionally connect a custom batch or validation loop to this training epoch loop.""" if batch_loop is not None: self.batch_loop = batch_loop From f4200614b86a4cd4378cf940df6c3d82048067c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 10:52:15 +0200 Subject: [PATCH 084/157] update docs for fit loop --- pytorch_lightning/loops/fit_loop.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 3f99da9c68a08..439b98f11f477 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -28,18 +28,13 @@ class FitLoop(Loop): - """This Loop iterates over the epochs to run the training + """ + This Loop iterates over the epochs to run the training. Args: min_epochs: The minimum number of epochs max_epochs: The maximum number of epochs - - .. note:: - If neither the minimum epochs nor steps are specified the minimum number of epochs is set to 1 - and if neither the maximum steps nor epochs are specified, the maximum epochs are set to 1000. """ - - # FIXME: update the note above def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None): super().__init__() self.max_epochs = max_epochs From 6562e2f7f687eedba65c632dce39ee73e288a6b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 10:52:35 +0200 Subject: [PATCH 085/157] update docs for training epoch loop --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 6cd5d7d29f98c..dfdc5739459c0 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -29,7 +29,13 @@ class TrainingEpochLoop(loops.Loop): - """ Runs over all batches in a dataloader (one epoch). """ + """ + Runs over all batches in a dataloader (one epoch). 
+ + Args: + min_steps: The minimum number of steps (batches) to process + max_steps: The maximum number of steps (batches) to process + """ def __init__(self, min_steps: int, max_steps: int): super().__init__() @@ -80,7 +86,7 @@ def connect( batch_loop: Optional[TrainingBatchLoop] = None, val_loop: Optional["loops.EvaluationLoop"] = None, ) -> None: - """Optionally connect a custom batch or validation loop to this training epoch loop.""" + """Optionally connect a custom batch- or validation loop to this training epoch loop.""" if batch_loop is not None: self.batch_loop = batch_loop if val_loop is not None: From 64b8b2064f024196cc1a0888883ffa9f33d77fd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 10:52:59 +0200 Subject: [PATCH 086/157] update type hints for training epoch loop --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index dfdc5739459c0..39c5b620299d9 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -53,8 +53,8 @@ def __init__(self, min_steps: int, max_steps: int): self.is_last_batch: Optional[bool] = None self.progress = TrainingEpochProgress() - self.batch_loop = None - self.val_loop = None + self.batch_loop = Optional[TrainingBatchLoop] + self.val_loop = Optional["loops.EvaluationLoop"] self._results = ResultCollection(training=True) self._dataloader_idx: Optional[int] = None From 2d1a7fc2ba547485d19ca8c87545f5f1bcafe994 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 10:53:09 +0200 Subject: [PATCH 087/157] remove redundant setter --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 39c5b620299d9..67d4ac654e71a 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -61,12 +61,6 @@ def __init__(self, min_steps: int, max_steps: int): self._warning_cache: WarningCache = WarningCache() self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None - @loops.Loop.trainer.setter - def trainer(self, trainer): - self._trainer = trainer - self.batch_loop.trainer = trainer - self.val_loop.trainer = trainer - @property def batch_idx(self) -> int: """Returns the current batch index (within this epoch)""" From 68d90064794766de9b0b0659c1b878b94417f566 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Jul 2021 08:54:21 +0000 Subject: [PATCH 088/157] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/fit_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 439b98f11f477..450de6c9c69b5 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -35,6 +35,7 @@ class FitLoop(Loop): min_epochs: The minimum number of epochs max_epochs: The maximum number of epochs """ + def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None): super().__init__() self.max_epochs = max_epochs From 015aec758a3224b0b420e575a8d88cb7deeb581d Mon Sep 
17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 11:55:00 +0200 Subject: [PATCH 089/157] remove redundant setter --- pytorch_lightning/trainer/trainer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cdf49a3ef482d..128aedcc200a2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -370,19 +370,15 @@ def __init__( # default .fit() loop self.fit_loop = fit_loop - self.fit_loop.trainer = self # default .validate() loop self.validate_loop = EvaluationLoop() - self.fit_loop.trainer = self # default .test() loop self.test_loop = EvaluationLoop() - self.fit_loop.trainer = self # default .predict() loop self.predict_loop = PredictionLoop() - self.fit_loop.trainer = self # training state if weights_summary is not None and weights_summary not in ModelSummary.MODES: From 3d13b645954ca4eefcbdf5a5304a5514a29910fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 12:48:38 +0200 Subject: [PATCH 090/157] handle optimization restart based on optimizer_idx --- pytorch_lightning/loops/batch/training_batch_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 334a80241fe59..30bf3a3dc482c 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -150,7 +150,7 @@ def advance(self, batch, batch_idx, dataloader_idx): # handle optimization restart if self.restarting: - if len(active_optimizers) > 1 and opt_idx < self.progress.current.completed: + if opt_idx < self.optim_progress.optimizer_idx: continue self.restarting = False From 78d13e2c09dab959cd0da9045984ac4d146a0464 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 13:54:51 +0200 Subject: [PATCH 091/157] update increments --- .../loops/batch/training_batch_loop.py | 17 +++---------- .../loops/epoch/training_epoch_loop.py | 25 +++++++++++-------- pytorch_lightning/loops/fit_loop.py | 6 +++++ 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 30bf3a3dc482c..912c4f8fbe548 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -50,7 +50,8 @@ def __init__(self) -> None: self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20) self.batch_idx: int = 0 self.split_idx: Optional[int] = None - self.progress = Progress() + # TODO: add progress updates for batch splits + self.split_progress = Progress() self.optim_progress = OptimizationProgress() self._warning_cache: WarningCache = WarningCache() @@ -87,8 +88,6 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict: self._warning_cache.warn("train_dataloader yielded None. 
If this was on purpose, ignore this warning...") return AttributeDict(signal=0, training_step_output=[[]]) - self.progress.increment_ready() - # hook self.trainer.logger_connector.on_batch_start() response = self.trainer.call_hook("on_batch_start") @@ -100,6 +99,8 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict: if response == -1: return AttributeDict(signal=-1) + self.trainer.fit_loop.epoch_loop.batch_progress.increment_started() + super().run(batch, batch_idx, dataloader_idx) output = AttributeDict(signal=0, training_step_output=self.batch_outputs) self.batch_outputs = None # free memory @@ -124,10 +125,6 @@ def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int): void(batch_idx, dataloader_idx) self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch))) - def on_advance_start(self, *args: Any, **kwargs: Any) -> None: - super().on_advance_start(*args, **kwargs) - self.progress.increment_started() - def advance(self, batch, batch_idx, dataloader_idx): """Runs the train step together with optimization (if necessary) on the current batch split @@ -166,12 +163,6 @@ def advance(self, batch, batch_idx, dataloader_idx): if result: self.batch_outputs[0].append(result.training_step_output) - self.progress.increment_processed() - - def on_advance_end(self) -> None: - super().on_advance_end() - self.progress.increment_completed() - def teardown(self) -> None: # release memory self._remaining_splits = None diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 8f6cc13e64fd5..0ff4613162c4f 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -20,7 +20,7 @@ from pytorch_lightning import loops # import as loops to avoid circular imports from pytorch_lightning.loops.batch import TrainingBatchLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import TrainingEpochProgress +from pytorch_lightning.trainer.progress import TrainingEpochProgress, Progress from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature @@ -45,7 +45,7 @@ def __init__(self, min_steps: int, max_steps: int): # the number of batches seen this run, updates immediately after batch_loop.run() self.batches_seen: int = 0 self.is_last_batch: Optional[bool] = None - self.progress = TrainingEpochProgress() + self.batch_progress = Progress() self.batch_loop = TrainingBatchLoop() self.val_loop = loops.EvaluationLoop() @@ -92,17 +92,14 @@ def reset(self) -> None: self.restarting = False else: # todo (tchaton) the batch_loop should be responsible for that. - self.batch_loop.progress.current.reset() + self.batch_loop.split_progress.current.reset() def on_run_start(self, *args: Any, **kwargs: Any) -> None: - self.progress.increment_ready() - # hook self.trainer.logger_connector.on_epoch_start() self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") - - self.progress.increment_started() + self.trainer.fit_loop.epoch_progress.increment_started() def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: """Runs a single training batch. 
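Taken together, the increments moved around in this patch give the per-batch sequence sketched below (a plain `Progress` stands in for `epoch_loop.batch_progress`; the hooks and the actual batch loop are stubbed):

    from pytorch_lightning.trainer.progress import Progress

    batch_progress = Progress()

    def run_one_training_batch():
        batch_progress.increment_ready()      # batch fetched and moved to device
        # TrainingBatchLoop.run(): on_batch_start / on_train_batch_start hooks, then
        batch_progress.increment_started()
        # ... training_step, backward and optimizer.step() happen here ...
        batch_progress.increment_processed()  # batch_loop.run() returned
        # ... on_train_batch_end / on_batch_end hooks ...
        batch_progress.increment_completed()

    run_one_training_batch()
    assert batch_progress.current == batch_progress.total
    assert batch_progress.total.completed == 1
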
@@ -122,10 +119,16 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: with self.trainer.profiler.profile("training_batch_to_device"): batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=self._dataloader_idx) + self.batch_progress.increment_ready() + with self.trainer.profiler.profile("run_training_batch"): batch_output = self.batch_loop.run(batch, self.iteration_count, self._dataloader_idx) + + # TODO: remove with progress tracking self.batches_seen += 1 + self.batch_progress.increment_processed() + # when returning -1 from train_step, we end epoch early if batch_output.signal == -1: raise StopIteration @@ -146,6 +149,8 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: self.trainer.call_hook('on_batch_end') self.trainer.logger_connector.on_batch_end() + self.batch_progress.increment_completed() + # figure out what to track for epoch end self._track_epoch_end_reduce_metrics(self._epoch_output, batch_end_outputs) @@ -163,7 +168,7 @@ def on_advance_end(self): # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- - self.progress.should_check_val = should_check_val = self._should_check_val_fx( + should_check_val = self._should_check_val_fx( self.iteration_count, self.is_last_batch ) @@ -224,15 +229,13 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]: 'HINT: remove the return statement in training_epoch_end' ) - self.progress.increment_processed() + self.trainer.fit_loop.epoch_progress.increment_processed() # call train epoch end hooks self._on_train_epoch_end_hook(processed_outputs) self.trainer.call_hook('on_epoch_end') self.trainer.logger_connector.on_epoch_end() - self.progress.increment_completed() - epoch_output = self._epoch_output # free memory self._epoch_output = None diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 6963f4b3f2c4a..7af7b52cbea05 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -20,6 +20,7 @@ from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection +from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import rank_zero_info @@ -51,6 +52,7 @@ def __init__( self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs self.epoch_loop = TrainingEpochLoop(min_steps, max_steps) + self.epoch_progress = Progress() @property def current_epoch(self) -> int: @@ -200,6 +202,8 @@ def on_advance_start(self) -> None: window_length=self.trainer.accumulate_grad_batches ) + self.epoch_progress.increment_ready() + def advance(self) -> None: """Runs one whole epoch.""" train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) @@ -234,6 +238,8 @@ def on_advance_end(self) -> None: self._check_checkpoint_callback(True) self.global_step += 1 + self.epoch_progress.increment_completed() + def on_run_end(self) -> None: """Calls the ``on_train_end`` hook""" # NOTE: the iteration_count/current_epoch is already incremented From 1048259c81394227d5ac0bb22f3aa38762f8c0ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 14:07:30 +0200 Subject: [PATCH 092/157] update 
val batch progress and remove iteration count --- .../loops/epoch/evaluation_epoch_loop.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index c56b4a7f097d1..9591672a7d3c7 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -41,7 +41,7 @@ def __init__(self) -> None: self.dataloader_idx: Optional[int] = None self.num_dataloaders: Optional[int] = None self.outputs: List[STEP_OUTPUT] = [] - self.progress = Progress() + self.batch_progress = Progress() def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: """Connects the loop with necessary arguments like the trainer""" @@ -50,11 +50,10 @@ def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: @property def done(self) -> bool: """Returns ``True`` if the current iteration count reaches the number of dataloader batches.""" - return self.iteration_count >= self.dl_max_batches + return self.batch_progress.current.completed >= self.dl_max_batches def reset(self) -> None: """Resets the loop's internal state.""" - self.iteration_count = 0 self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) self.dl_max_batches = None self.dataloader_idx = None @@ -62,11 +61,9 @@ def reset(self) -> None: self.outputs = [] if self.restarting: - self.iteration_count = self.progress.current.completed self.restarting = False else: - self.iteration_count = 0 - self.progress.current.reset() + self.batch_progress.current.reset() def on_run_start( self, @@ -117,31 +114,31 @@ def advance( with self.trainer.profiler.profile("evaluation_batch_to_device"): batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx) - self.progress.increment_started() + self.batch_progress.increment_ready() # hook self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) - self.progress.increment_ready() + self.batch_progress.increment_started() # lightning module methods with self.trainer.profiler.profile("evaluation_step_and_end"): output = self.evaluation_step(batch, batch_idx, dataloader_idx) output = self.evaluation_step_end(output) + self.batch_progress.increment_processed() + # hook + store predictions self.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx) + self.batch_progress.increment_completed() + # log batch metrics self.trainer.logger_connector.update_eval_step_metrics() # track epoch level outputs self.outputs = self._track_output_for_epoch_end(self.outputs, output) - self.progress.increment_processed() - - self.progress.increment_completed() - def on_run_end(self) -> List[STEP_OUTPUT]: """Returns the outputs of the whole run""" outputs = self.outputs From 668a4cfaac1c524ce54fa54af7d148802ec14f31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 14:37:05 +0200 Subject: [PATCH 093/157] update progress tracking for dataloader loops --- .../loops/dataloader/dataloader_loop.py | 23 ++++++++++++++---- .../loops/dataloader/evaluation_loop.py | 24 ++----------------- .../loops/dataloader/prediction_loop.py | 7 ------ pytorch_lightning/trainer/progress.py | 12 ++++------ 4 files changed, 25 insertions(+), 41 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/dataloader_loop.py b/pytorch_lightning/loops/dataloader/dataloader_loop.py index ce255b73d0bba..1b5bf6a2402fe 100644 --- 
a/pytorch_lightning/loops/dataloader/dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/dataloader_loop.py @@ -13,16 +13,21 @@ # limitations under the License. from abc import abstractmethod -from typing import Sequence +from typing import Sequence, Any from torch.utils.data import DataLoader from pytorch_lightning.loops.base import Loop +from pytorch_lightning.trainer.progress import DataLoaderProgress class DataLoaderLoop(Loop): """Base class to loop over all dataloaders""" + def __init__(self): + super().__init__() + self.dataloader_progress = DataLoaderProgress() + @property @abstractmethod def dataloaders(self) -> Sequence[DataLoader]: @@ -31,7 +36,7 @@ def dataloaders(self) -> Sequence[DataLoader]: @property def current_dataloader_idx(self) -> int: """Returns the index of the current dataloader""" - return self.iteration_count + return self.dataloader_progress.current.ready - 1 @property def current_dataloader(self) -> DataLoader: @@ -46,8 +51,18 @@ def num_dataloaders(self) -> int: @property def done(self) -> bool: """Returns whether all dataloaders have been processed""" - return self.current_dataloader_idx >= self.num_dataloaders + return self.dataloader_progress.current.completed >= self.num_dataloaders def reset(self) -> None: """Resets the internal state""" - self.iteration_count = 0 + if self.restarting: + self.restarting = False + else: + # reset batch / epoch progress tracking + self.dataloader_progress.current.reset() + + def on_advance_start(self, *args: Any, **kwargs: Any) -> None: + self.dataloader_progress.increment_ready() + + def on_advance_end(self) -> None: + self.dataloader_progress.increment_completed() diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index ba554bf9c1a29..eab89eaf415b8 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -21,7 +21,6 @@ from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import DataLoaderProgress from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -33,7 +32,6 @@ class EvaluationLoop(DataLoaderLoop): def __init__(self): super().__init__() self.outputs = [] - self.progress = DataLoaderProgress() self.epoch_loop = EvaluationEpochLoop() self._results = ResultCollection(training=False) @@ -73,7 +71,7 @@ def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: @property def done(self) -> bool: """Returns whether all dataloaders are processed or evaluation should be skipped altogether""" - return (self.current_dataloader_idx >= len(self.dataloaders)) or self.skip + return super().done or self.skip @property def skip(self) -> bool: @@ -83,7 +81,6 @@ def skip(self) -> bool: def reset(self) -> None: """Resets the internal state of the loop""" - self.iteration_count = 0 self._max_batches = self.get_max_batches() # bookkeeping self.outputs = [] @@ -91,13 +88,7 @@ def reset(self) -> None: if isinstance(self._max_batches, int): self._max_batches = [self._max_batches] * len(self.dataloaders) - if self.restarting: - self.iteration_count = self.progress.dataloader_idx - self.restarting = False - else: - self.iteration_count = 0 - # reset 
batch / epoch progress tracking - self.progress.current.reset() + super().reset() def on_skip(self) -> List: return [] @@ -105,17 +96,12 @@ def on_skip(self) -> List: def on_run_start(self, *args: Any, **kwargs: Any) -> None: """Runs the ``on_evaluation_model_eval``, ``on_evaluation_start`` and ``on_evaluation_epoch_start`` hooks""" void(*args, **kwargs) - - self.progress.increment_started() - # hook self.on_evaluation_model_eval() self.trainer.lightning_module.zero_grad() self.on_evaluation_start() self.on_evaluation_epoch_start() - self.progress.increment_ready() - def advance(self, *args: Any, **kwargs: Any) -> None: """Performs evaluation on one single dataloader""" void(*args, **kwargs) @@ -123,8 +109,6 @@ def advance(self, *args: Any, **kwargs: Any) -> None: dataloader_iter = enumerate(dataloader) dl_max_batches = self._max_batches[self.current_dataloader_idx] - self.progress.dataloader_idx = self.iteration_count - dl_outputs = self.epoch_loop.run( dataloader_iter, self.current_dataloader_idx, @@ -151,8 +135,6 @@ def on_run_end(self) -> Any: if len(outputs) > 0 and self.num_dataloaders == 1: outputs = outputs[0] - self.progress.increment_processed() - # lightning module method self.evaluation_epoch_end(outputs) @@ -171,8 +153,6 @@ def on_run_end(self) -> Any: # enable train mode again self.on_evaluation_model_train() - self.progress.increment_completed() - return eval_loop_results def get_max_batches(self) -> List[Union[int, float]]: diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index 51eccdf202051..e1de8669ddf68 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -7,7 +7,6 @@ from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop from pytorch_lightning.plugins import DDPSpawnPlugin -from pytorch_lightning.trainer.progress import DataLoaderProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import _PREDICT_OUTPUT @@ -19,7 +18,6 @@ def __init__(self): super().__init__() self.predictions: Optional[List[List[Any]]] = None self.epoch_batch_indices: Optional[List[List[int]]] = None - self.progress = DataLoaderProgress() self.epoch_loop = PredictionEpochLoop() self._results = None # for `trainer._results` access @@ -66,11 +64,6 @@ def dataloaders(self) -> Sequence[DataLoader]: """Returns all prediction dataloaders""" return self.trainer.predict_dataloaders - @property - def done(self) -> bool: - """Whether prediction is finished: Max batches run or all dataloaders processed""" - return self.current_dataloader_idx >= len(self.dataloaders) - @property def skip(self) -> bool: return sum(self.max_batches) == 0 diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 5f153f45002c0..b34b448d2a0e9 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -153,15 +153,11 @@ class DataLoaderProgress(Progress): These counters are local to a trainer rank. By default, they are not globally synced across all ranks. Args: - total: Tracks the total epoch progress - current: Tracks the current epoch progress - dataloader_idx: The index of the current dataloader. 
+ total: Tracks the total dataloader progress + current: Tracks the current dataloader progress """ - dataloader_idx: int = 0 - - def load_state_dict(self, state_dict: dict) -> None: - super().load_state_dict(state_dict) - self.dataloader_idx = state_dict["dataloader_idx"] + total: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) + current: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) @dataclass From ad8b342b593bf0de93b0766582df8d144311efaa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 14:42:29 +0200 Subject: [PATCH 094/157] remove self.dataloader_idx from eval_epoch_loop --- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 9591672a7d3c7..757df0fb29ef6 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -38,7 +38,6 @@ def __init__(self) -> None: self.predictions: Optional[PredictionCollection] = None self.dataloader: Optional[Iterator] = None self.dl_max_batches: Optional[int] = None - self.dataloader_idx: Optional[int] = None self.num_dataloaders: Optional[int] = None self.outputs: List[STEP_OUTPUT] = [] self.batch_progress = Progress() @@ -56,7 +55,6 @@ def reset(self) -> None: """Resets the loop's internal state.""" self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) self.dl_max_batches = None - self.dataloader_idx = None self.num_dataloaders = None self.outputs = [] @@ -80,10 +78,8 @@ def on_run_start( dl_max_batches: maximum number of batches the dataloader can produce num_dataloaders: the total number of dataloaders """ - void(dataloader_iter) - + void(dataloader_iter, dataloader_idx) self.dl_max_batches = dl_max_batches - self.dataloader_idx = dataloader_idx self.num_dataloaders = num_dataloaders def advance( From 512ee0d51c8cf3098fb88c7b9c13610ba862edbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 14:54:27 +0200 Subject: [PATCH 095/157] add batch progress to predict loop --- .../loops/epoch/prediction_epoch_loop.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index f94e106a8c444..da1aa0e42f210 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -17,7 +17,7 @@ def __init__(self) -> None: self.return_predictions: bool = False self.predictions: List[Any] = [] self.current_batch_indices: List[int] = [] - self.progress = Progress() + self.batch_progress = Progress() self._dl_max_batches: Optional[int] = None self._num_dataloaders: Optional[int] = None @@ -27,7 +27,7 @@ def __init__(self) -> None: @property def done(self) -> bool: """Ends prediction when the iteration count exceeds the total number of available batches""" - return self.iteration_count >= self._dl_max_batches + return self.batch_progress.current.completed >= self._dl_max_batches @property def should_store_predictions(self) -> bool: @@ -37,9 +37,9 @@ def should_store_predictions(self) -> bool: def reset(self) -> None: """Resets the loops internal state""" - self.iteration_count = 0 self._all_batch_indices: List[int] = [] self.predictions: 
List[Any] = [] + self.batch_progress.current.reset() def on_run_start( self, @@ -89,6 +89,8 @@ def advance( with self.trainer.profiler.profile("predict_batch_to_device"): batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx) + self.batch_progress.increment_ready() + with self.trainer.profiler.profile("predict_step"): self._predict_step(batch, batch_idx, dataloader_idx) @@ -120,14 +122,20 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx) + self.batch_progress.increment_started() + model_ref._current_fx_name = "predict_step" predictions = self.trainer.accelerator.predict_step(step_kwargs) + self.batch_progress.increment_processed() + if predictions is None: self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...") self.trainer.call_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx) + self.batch_progress.increment_completed() + if self.should_store_predictions: self.predictions.append(predictions) From 2633d515673201a4f2f805c3d1f0c82f92b6e19a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Jul 2021 12:56:18 +0000 Subject: [PATCH 096/157] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/dataloader/dataloader_loop.py | 2 +- pytorch_lightning/loops/epoch/training_epoch_loop.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/dataloader_loop.py b/pytorch_lightning/loops/dataloader/dataloader_loop.py index 1b5bf6a2402fe..d8bdd67b41c17 100644 --- a/pytorch_lightning/loops/dataloader/dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/dataloader_loop.py @@ -13,7 +13,7 @@ # limitations under the License. 
from abc import abstractmethod -from typing import Sequence, Any +from typing import Any, Sequence from torch.utils.data import DataLoader diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 0ff4613162c4f..393dbd02ec824 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -20,7 +20,7 @@ from pytorch_lightning import loops # import as loops to avoid circular imports from pytorch_lightning.loops.batch import TrainingBatchLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import TrainingEpochProgress, Progress +from pytorch_lightning.trainer.progress import Progress, TrainingEpochProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature @@ -168,9 +168,7 @@ def on_advance_end(self): # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- - should_check_val = self._should_check_val_fx( - self.iteration_count, self.is_last_batch - ) + should_check_val = self._should_check_val_fx(self.iteration_count, self.is_last_batch) if should_check_val: self.trainer.validating = True From 4bbc7acf1062c25efb81d36b20d5958669e79c2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 15:34:38 +0200 Subject: [PATCH 097/157] incorporate progress tracking for current_epoch --- pytorch_lightning/loops/base.py | 1 + .../loops/epoch/training_epoch_loop.py | 3 +-- pytorch_lightning/loops/fit_loop.py | 21 +++++++++---------- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 9209dcb993284..f67447fc19a03 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -46,6 +46,7 @@ class Loop(ABC): """ def __init__(self) -> None: + # TODO: replace by progress tracking self.iteration_count: int = 0 self.restarting = False self._trainer: Optional['pl.Trainer'] = None diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 393dbd02ec824..ee0e399a6c98b 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -38,11 +38,10 @@ def __init__(self, min_steps: int, max_steps: int): self.global_step: int = 0 # the total batch index across all epochs self.total_batch_idx: int = 0 - # the current batch index in the loop that runs over the dataloader(s) - self.iteration_count: int = 0 # the current split index when the batch gets split into chunks in truncated backprop through time self.split_idx: Optional[int] = None # the number of batches seen this run, updates immediately after batch_loop.run() + # TODO: replace by progress tracking self.batches_seen: int = 0 self.is_last_batch: Optional[bool] = None self.batch_progress = Progress() diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 7af7b52cbea05..21087d73a4662 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -57,12 +57,12 @@ def __init__( @property def current_epoch(self) -> int: """Return the current epoch""" - return 
self.iteration_count + return self.epoch_progress.current.completed @current_epoch.setter def current_epoch(self, value: int) -> None: """Setter for the current epoch""" - self.iteration_count = value + self.epoch_progress.current.completed = value @property def global_step(self) -> int: @@ -82,7 +82,7 @@ def total_batch_idx(self) -> int: @property def batch_idx(self) -> int: """Returns the number of batches already run within this epoch""" - return self.epoch_loop.iteration_count + return self.epoch_loop.batch_progress.current.ready - 1 @property def split_idx(self) -> int: @@ -227,16 +227,15 @@ def advance(self) -> None: def on_advance_end(self) -> None: """Updates the LR schedulers and does some internal bookkeeping""" - if self.epoch_loop.batches_seen == 0: - return - self.epoch_loop.update_lr_schedulers('epoch', update_plateau_schedulers=True) + if self.epoch_loop.batches_seen != 0: + self.epoch_loop.update_lr_schedulers('epoch', update_plateau_schedulers=True) - did_train_only = not self.trainer.enable_validation or self.epoch_loop.val_loop.skip - if did_train_only: - self.global_step -= 1 - self._check_checkpoint_callback(True) - self.global_step += 1 + did_train_only = not self.trainer.enable_validation or self.epoch_loop.val_loop.skip + if did_train_only: + self.global_step -= 1 + self._check_checkpoint_callback(True) + self.global_step += 1 self.epoch_progress.increment_completed() From 01f87145ce30bc63c06ffc9927d40cf56a4bdb70 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 14 Jul 2021 17:20:40 +0200 Subject: [PATCH 098/157] Fix test --- tests/loops/test_loop_progress_integration.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/tests/loops/test_loop_progress_integration.py b/tests/loops/test_loop_progress_integration.py index 4395cb5cdcf3b..32eac6d037a87 100644 --- a/tests/loops/test_loop_progress_integration.py +++ b/tests/loops/test_loop_progress_integration.py @@ -5,12 +5,16 @@ def test_loop_progress_integration(): trainer = Trainer() # check no progresses are shared - assert trainer.validate_loop.progress is not trainer.test_loop.progress - assert trainer.test_loop.progress is not trainer.predict_loop.progress + assert trainer.fit_loop.epoch_progress is not trainer.validate_loop.dataloader_progress + assert trainer.validate_loop.dataloader_progress is not trainer.test_loop.dataloader_progress + assert trainer.test_loop.dataloader_progress is not trainer.predict_loop.dataloader_progress # check the validation progresses are not shared - assert trainer.fit_loop.epoch_loop.val_loop.progress is not trainer.validate_loop.progress - generated = _collect_loop_progress(trainer.fit_loop)["epoch_loop"] - assert generated["progress"] is trainer.fit_loop.epoch_loop.progress - assert generated["batch_loop"]["progress"] is trainer.fit_loop.epoch_loop.batch_loop.progress - assert generated["val_loop"]["progress"] is trainer.fit_loop.epoch_loop.val_loop.progress - assert generated["val_loop"]["epoch_loop"]["progress"] is trainer.fit_loop.epoch_loop.val_loop.epoch_loop.progress + assert trainer.fit_loop.epoch_loop.val_loop.dataloader_progress is not trainer.validate_loop.dataloader_progress + # check recursive collection of progresses + progresses = _collect_loop_progress(trainer.fit_loop) + assert progresses["epoch_progress"] is trainer.fit_loop.epoch_progress + assert progresses["epoch_loop"]["batch_progress"] is trainer.fit_loop.epoch_loop.batch_progress + assert progresses["epoch_loop"]["val_loop"]["dataloader_progress" + ] is 
trainer.fit_loop.epoch_loop.val_loop.dataloader_progress + assert progresses["epoch_loop"]["val_loop"]["epoch_loop"][ + "batch_progress"] is trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress From 65405b806a116316b49b7eac69865043b4fc94db Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 14 Jul 2021 17:25:24 +0200 Subject: [PATCH 099/157] Actually remove it --- tests/loops/test_loop_progress_integration.py | 20 ------------------- 1 file changed, 20 deletions(-) delete mode 100644 tests/loops/test_loop_progress_integration.py diff --git a/tests/loops/test_loop_progress_integration.py b/tests/loops/test_loop_progress_integration.py deleted file mode 100644 index 32eac6d037a87..0000000000000 --- a/tests/loops/test_loop_progress_integration.py +++ /dev/null @@ -1,20 +0,0 @@ -from pytorch_lightning import Trainer -from tests.loops.test_loops import _collect_loop_progress - - -def test_loop_progress_integration(): - trainer = Trainer() - # check no progresses are shared - assert trainer.fit_loop.epoch_progress is not trainer.validate_loop.dataloader_progress - assert trainer.validate_loop.dataloader_progress is not trainer.test_loop.dataloader_progress - assert trainer.test_loop.dataloader_progress is not trainer.predict_loop.dataloader_progress - # check the validation progresses are not shared - assert trainer.fit_loop.epoch_loop.val_loop.dataloader_progress is not trainer.validate_loop.dataloader_progress - # check recursive collection of progresses - progresses = _collect_loop_progress(trainer.fit_loop) - assert progresses["epoch_progress"] is trainer.fit_loop.epoch_progress - assert progresses["epoch_loop"]["batch_progress"] is trainer.fit_loop.epoch_loop.batch_progress - assert progresses["epoch_loop"]["val_loop"]["dataloader_progress" - ] is trainer.fit_loop.epoch_loop.val_loop.dataloader_progress - assert progresses["epoch_loop"]["val_loop"]["epoch_loop"][ - "batch_progress"] is trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress From 6dd2182a97e91984435208e96457b3c6813962b7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 14 Jul 2021 17:38:14 +0200 Subject: [PATCH 100/157] Remove unused TrainingEpochProgress --- .../loops/epoch/training_epoch_loop.py | 2 +- pytorch_lightning/trainer/progress.py | 16 ---------------- tests/trainer/test_progress.py | 2 -- 3 files changed, 1 insertion(+), 19 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index ee0e399a6c98b..eb0f4040a6102 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -20,7 +20,7 @@ from pytorch_lightning import loops # import as loops to avoid circular imports from pytorch_lightning.loops.batch import TrainingBatchLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import Progress, TrainingEpochProgress +from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index b34b448d2a0e9..32e2ba0ea9b98 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -130,22 +130,6 @@ def load_state_dict(self, 
state_dict: dict) -> None: self.current.load_state_dict(state_dict["current"]) -@dataclass -class TrainingEpochProgress(Progress): - """ - Tracks the epoch progress - - Args: - total: Tracks the total epoch progress - current: Tracks the current epoch progress - """ - should_check_val: bool = False - - def load_state_dict(self, state_dict: dict) -> None: - super().load_state_dict(state_dict) - self.should_check_val = state_dict["should_check_val"] - - @dataclass class DataLoaderProgress(Progress): """ diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index 32d97be167b52..7205b72ad3963 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -26,7 +26,6 @@ OptimizerProgress, Progress, Tracker, - TrainingEpochProgress, ) from tests.helpers import BoringModel @@ -116,7 +115,6 @@ def test_optimizer_progress_default_factory(): def test_deepcopy(): _ = deepcopy(Tracker()) _ = deepcopy(Progress()) - _ = deepcopy(TrainingEpochProgress()) _ = deepcopy(DataLoaderProgress()) _ = deepcopy(OptimizerProgress()) _ = deepcopy(OptimizationProgress()) From b71e1516653901e3a222a82d4c56520b77046f75 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 14 Jul 2021 18:09:45 +0200 Subject: [PATCH 101/157] Fix optimization progress - missing scheduler --- pytorch_lightning/core/lightning.py | 3 +- .../loops/batch/training_batch_loop.py | 15 ++-------- .../trainer/connectors/optimizer_connector.py | 4 +-- pytorch_lightning/trainer/progress.py | 29 +++++++++++++------ 4 files changed, 26 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 735f8ab160c1f..aeed3c9304a76 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -31,6 +31,7 @@ from torch.optim.optimizer import Optimizer from torchmetrics import Metric +import pytorch_lightning as pl from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary @@ -89,7 +90,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: torch._C._log_api_usage_once(f"lightning.module.{self.__class__.__name__}") # pointer to the trainer object - self.trainer = None + self.trainer: Optional['pl.Trainer'] = None self._distrib_type = None self._device_type = None diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 912c4f8fbe548..26adb3234f44c 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -245,11 +245,6 @@ def _training_step_and_backward_closure( result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) if result is not None: return_result.update(result) - - # this should be done only if result.loss exists and ``optimizer step`` is being run - if not self.should_accumulate(): - self.optim_progress.optimizer.step.increment_started() - return return_result.loss def _make_closure(self, *closure_args: Any, **closure_kwargs: Any) -> Callable: @@ -419,7 +414,8 @@ def _optimizer_step( using_lbfgs=is_lbfgs, ) - self.optim_progress.optimizer.step.increment_processed() + # FIXME: why does it not fail? 
+ # self.optim_progress.optimizer.step.increment_processed() self.optim_progress.optimizer.step.increment_completed() def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None: @@ -441,7 +437,6 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: the index of the current optimizer """ self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) - self.optim_progress.optimizer.zero_grad.increment_completed() def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]: @@ -701,9 +696,3 @@ def _truncated_bptt_steps(self) -> int: if lightning_module.truncated_bptt_steps > 0: return lightning_module.truncated_bptt_steps return self.trainer.truncated_bptt_steps or 0 - - def increment_scheduler_ready(self): - self.optim_progress.scheduler.increment_ready() - - def increment_scheduler_completed(self): - self.optim_progress.scheduler.increment_completed() diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index 16b751e7db4b9..4c49b6e028cb4 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -83,7 +83,7 @@ def update_learning_rates( # update LR old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] - self.trainer.fit_loop.epoch_loop.batch_loop.increment_scheduler_ready() + self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready() if lr_scheduler['reduce_on_plateau']: lr_scheduler['scheduler'].step(monitor_val) @@ -92,7 +92,7 @@ def update_learning_rates( new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] - self.trainer.fit_loop.epoch_loop.batch_loop.increment_scheduler_completed() + self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_completed() if self.trainer.dev_debugger.enabled: self.trainer.dev_debugger.track_lr_schedulers_update( diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 32e2ba0ea9b98..99895266e6b3c 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -154,7 +154,7 @@ class OptimizerProgress(BaseProgress): zero_grad: Tracks ``optimizer.zero_grad`` calls. """ - step: Progress = field(default_factory=lambda: Progress.from_defaults(processed=None)) + step: Progress = field(default_factory=lambda: Progress.from_defaults(started=None, processed=None)) zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(processed=None)) def reset_on_epoch(self) -> None: @@ -173,28 +173,39 @@ class OptimizationProgress(BaseProgress): Args: optimizer: Tracks optimizer progress. - scheduler: Tracks scheduler progress. optimizer_idx: The index of the current optimizer. """ # TODO: support for multiple optimizers optimizer: OptimizerProgress = field(default_factory=OptimizerProgress) - scheduler: Progress = field(default_factory=lambda: Progress.from_defaults(started=None, processed=None)) optimizer_idx: int = 0 @property def optimizer_steps(self) -> int: return self.optimizer.step.total.completed + def reset_on_epoch(self) -> None: + self.optimizer.current.reset() + self.optimizer_idx = 0 + + def load_state_dict(self, state_dict: dict) -> None: + self.optimizer.load_state_dict(state_dict["optimizer"]) + self.optimizer_idx = state_dict["optimizer_idx"] + + +class SchedulerProgress(BaseProgress): + """ + Track scheduler progress. 
+ + Args: + scheduler: Tracks scheduler progress. + """ + + scheduler: Progress = field(default_factory=lambda: Progress.from_defaults(started=None, processed=None)) + @property def scheduler_steps(self) -> int: return self.scheduler.total.completed - def reset_on_epoch(self) -> None: - self.optimizer.reset_on_epoch() - self.scheduler.current.reset() - def load_state_dict(self, state_dict: dict) -> None: - self.optimizer.load_state_dict(state_dict["optimizer"]) self.scheduler.load_state_dict(state_dict["scheduler"]) - self.optimizer_idx = state_dict["optimizer_idx"] From e5a392a4ee30c7dd9115f1bbece66e7ea9e715d9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 14 Jul 2021 18:14:03 +0200 Subject: [PATCH 102/157] Restarting changes --- pytorch_lightning/loops/base.py | 2 ++ pytorch_lightning/loops/batch/training_batch_loop.py | 1 - pytorch_lightning/loops/dataloader/dataloader_loop.py | 4 +--- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 4 +--- pytorch_lightning/loops/epoch/training_epoch_loop.py | 5 +---- tests/loops/test_loops.py | 6 +++--- 6 files changed, 8 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index f67447fc19a03..66844a9dc5ddd 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -113,6 +113,8 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: self.advance(*args, **kwargs) self.on_advance_end() self.iteration_count += 1 + if self.restarting: + self.restarting = False except StopIteration: break diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 26adb3234f44c..bd82659fe5476 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -149,7 +149,6 @@ def advance(self, batch, batch_idx, dataloader_idx): if self.restarting: if opt_idx < self.optim_progress.optimizer_idx: continue - self.restarting = False # track optimizer_idx self.optim_progress.optimizer_idx = opt_idx diff --git a/pytorch_lightning/loops/dataloader/dataloader_loop.py b/pytorch_lightning/loops/dataloader/dataloader_loop.py index d8bdd67b41c17..ed7f776fcad7d 100644 --- a/pytorch_lightning/loops/dataloader/dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/dataloader_loop.py @@ -55,9 +55,7 @@ def done(self) -> bool: def reset(self) -> None: """Resets the internal state""" - if self.restarting: - self.restarting = False - else: + if not self.restarting: # reset batch / epoch progress tracking self.dataloader_progress.current.reset() diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 757df0fb29ef6..1c76d33acd404 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -58,9 +58,7 @@ def reset(self) -> None: self.num_dataloaders = None self.outputs = [] - if self.restarting: - self.restarting = False - else: + if not self.restarting: self.batch_progress.current.reset() def on_run_start( diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index eb0f4040a6102..cae9093291f44 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -85,10 +85,7 @@ def reset(self) -> None: self._epoch_output = [[] for _ in 
range(self.batch_loop.num_active_optimizers(self.total_batch_idx))] if self.restarting: - self.iteration_count = self.batch_loop.current_batch_completed - self.batches_seen = self.batch_loop.current_batch_completed - # restarting is finished. - self.restarting = False + self.iteration_count = self.batches_seen = self.batch_progress.current.completed else: # todo (tchaton) the batch_loop should be responsible for that. self.batch_loop.split_progress.current.reset() diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 70e2ca7a62d3e..34828bf8d59f1 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -57,7 +57,7 @@ def reset(self) -> None: for _ in range(self.iteration_count): next(self.iter_dataset) self.iteration_count += 1 - self.restarting = False + # self.restarting = False else: self.outputs = [] @@ -132,8 +132,8 @@ def skip(self) -> bool: def done(self) -> bool: return self.iteration_count > 0 - def reset(self) -> None: - self.restarting = False + # def reset(self) -> None: + # self.restarting = False def on_save_checkpoint(self) -> Dict: return {"a": self.a} From 49c511277f8275175024e232d225fc37be786219 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 14 Jul 2021 18:39:59 +0200 Subject: [PATCH 103/157] Scheduler progress --- .../loops/batch/training_batch_loop.py | 2 +- .../loops/epoch/training_epoch_loop.py | 5 ++- pytorch_lightning/loops/fit_loop.py | 4 +- .../trainer/connectors/optimizer_connector.py | 2 +- pytorch_lightning/trainer/progress.py | 38 +++++++++---------- 5 files changed, 26 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index bd82659fe5476..7cd7c76a2f6bd 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -150,7 +150,6 @@ def advance(self, batch, batch_idx, dataloader_idx): if opt_idx < self.optim_progress.optimizer_idx: continue - # track optimizer_idx self.optim_progress.optimizer_idx = opt_idx result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer) @@ -416,6 +415,7 @@ def _optimizer_step( # FIXME: why does it not fail? # self.optim_progress.optimizer.step.increment_processed() self.optim_progress.optimizer.step.increment_completed() + self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready() def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None: """Calls the ``on_before_zero_grad`` hook. 
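The training_batch_loop.py half of this patch (above) only marks the scheduler as ready once the optimizer step has completed; the optimizer_connector.py change further down then moves the started/completed bookkeeping to where the LR scheduler actually steps. A rough, self-contained sketch of that intended ordering follows; it uses throwaway Counter stand-ins rather than the real OptimizationProgress/SchedulerProgress objects, and the attribute paths in the comments come from the diffs while the stand-in names are illustrative only.

from collections import Counter

optimizer_step = Counter()   # stands in for optim_progress.optimizer.step
scheduler = Counter()        # stands in for epoch_loop.scheduler_progress

# TrainingBatchLoop._optimizer_step, after the accelerator ran optimizer.step():
optimizer_step["completed"] += 1
scheduler["ready"] += 1          # a scheduler step may now happen

# OptimizerConnector.update_learning_rates, around lr_scheduler.step():
scheduler["started"] += 1
# ... lr_scheduler['scheduler'].step() runs here ...
scheduler["completed"] += 1

assert scheduler["ready"] == scheduler["completed"] == 1
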
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index cae9093291f44..0a3ba558e2f64 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -20,7 +20,7 @@ from pytorch_lightning import loops # import as loops to avoid circular imports from pytorch_lightning.loops.batch import TrainingBatchLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import Progress +from pytorch_lightning.trainer.progress import Progress, SchedulerProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature @@ -45,6 +45,7 @@ def __init__(self, min_steps: int, max_steps: int): self.batches_seen: int = 0 self.is_last_batch: Optional[bool] = None self.batch_progress = Progress() + self.scheduler_progress = SchedulerProgress() self.batch_loop = TrainingBatchLoop() self.val_loop = loops.EvaluationLoop() @@ -230,6 +231,8 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]: self.trainer.call_hook('on_epoch_end') self.trainer.logger_connector.on_epoch_end() + self.update_lr_schedulers('epoch', update_plateau_schedulers=True) + epoch_output = self._epoch_output # free memory self._epoch_output = None diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 21087d73a4662..75dacdcec4ba9 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -222,15 +222,13 @@ def advance(self) -> None: # TODO(@carmocca): deprecate and rename so users don't get confused self.global_step -= 1 # log epoch metrics + # FIXME: was this wrong??? self.trainer.logger_connector.update_train_epoch_metrics() self.global_step += 1 def on_advance_end(self) -> None: """Updates the LR schedulers and does some internal bookkeeping""" - if self.epoch_loop.batches_seen != 0: - self.epoch_loop.update_lr_schedulers('epoch', update_plateau_schedulers=True) - did_train_only = not self.trainer.enable_validation or self.epoch_loop.val_loop.skip if did_train_only: self.global_step -= 1 diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index 4c49b6e028cb4..9939901832c0e 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -83,7 +83,7 @@ def update_learning_rates( # update LR old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] - self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready() + self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_started() if lr_scheduler['reduce_on_plateau']: lr_scheduler['scheduler'].step(monitor_val) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 99895266e6b3c..043c6ebe9d4c4 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -133,7 +133,7 @@ def load_state_dict(self, state_dict: dict) -> None: @dataclass class DataLoaderProgress(Progress): """ - Tracks the data-loader progress + Tracks the dataloader progress These counters are local to a trainer rank. By default, they are not globally synced across all ranks. 
Args: @@ -144,6 +144,24 @@ class DataLoaderProgress(Progress): current: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) +class SchedulerProgress(Progress): + """ + Tracks the scheduler progress + These counters are local to a trainer rank. By default, they are not globally synced across all ranks. + + Args: + total: Tracks the total scheduler progress + current: Tracks the current scheduler progress + """ + + total: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) + current: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) + + @property + def scheduler_steps(self) -> int: + return self.total.completed + + @dataclass class OptimizerProgress(BaseProgress): """ @@ -191,21 +209,3 @@ def reset_on_epoch(self) -> None: def load_state_dict(self, state_dict: dict) -> None: self.optimizer.load_state_dict(state_dict["optimizer"]) self.optimizer_idx = state_dict["optimizer_idx"] - - -class SchedulerProgress(BaseProgress): - """ - Track scheduler progress. - - Args: - scheduler: Tracks scheduler progress. - """ - - scheduler: Progress = field(default_factory=lambda: Progress.from_defaults(started=None, processed=None)) - - @property - def scheduler_steps(self) -> int: - return self.scheduler.total.completed - - def load_state_dict(self, state_dict: dict) -> None: - self.scheduler.load_state_dict(state_dict["scheduler"]) From 018da6a530ab268295177dadcf54de96a25e1f12 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 14 Jul 2021 18:44:08 +0200 Subject: [PATCH 104/157] Unused property, reset on epoch --- pytorch_lightning/trainer/progress.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 043c6ebe9d4c4..be78d7b8208a2 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -157,10 +157,6 @@ class SchedulerProgress(Progress): total: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) current: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) - @property - def scheduler_steps(self) -> int: - return self.total.completed - @dataclass class OptimizerProgress(BaseProgress): @@ -203,7 +199,7 @@ def optimizer_steps(self) -> int: return self.optimizer.step.total.completed def reset_on_epoch(self) -> None: - self.optimizer.current.reset() + self.optimizer.reset_on_epoch() self.optimizer_idx = 0 def load_state_dict(self, state_dict: dict) -> None: From 0b1834c6b0726f25144c64d9056d4f7648982de2 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 14 Jul 2021 19:11:45 +0200 Subject: [PATCH 105/157] Resolve FIXME --- pytorch_lightning/loops/batch/training_batch_loop.py | 2 -- pytorch_lightning/trainer/progress.py | 8 -------- 2 files changed, 10 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 7cd7c76a2f6bd..60f43f5ba8825 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -412,8 +412,6 @@ def _optimizer_step( using_lbfgs=is_lbfgs, ) - # FIXME: why does it not fail? 
- # self.optim_progress.optimizer.step.increment_processed() self.optim_progress.optimizer.step.increment_completed() self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready() diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index be78d7b8208a2..4410cf3901ff2 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -98,26 +98,18 @@ class Progress(BaseProgress): current: Tracker = field(default_factory=Tracker) def increment_ready(self) -> None: - if self.total.ready is None or self.current.ready is None: - return self.total.ready += 1 self.current.ready += 1 def increment_started(self) -> None: - if self.total.started is None or self.current.started is None: - return self.total.started += 1 self.current.started += 1 def increment_processed(self) -> None: - if self.total.processed is None or self.current.processed is None: - return self.total.processed += 1 self.current.processed += 1 def increment_completed(self) -> None: - if self.total.completed is None or self.current.completed is None: - return self.total.completed += 1 self.current.completed += 1 From d7bcafa883d538b83b73db33fa6bf915937bba10 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 14 Jul 2021 19:12:30 +0200 Subject: [PATCH 106/157] Remove FIXME --- pytorch_lightning/loops/fit_loop.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 75dacdcec4ba9..d2f0dfe954a6f 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -222,7 +222,6 @@ def advance(self) -> None: # TODO(@carmocca): deprecate and rename so users don't get confused self.global_step -= 1 # log epoch metrics - # FIXME: was this wrong??? 
self.trainer.logger_connector.update_train_epoch_metrics() self.global_step += 1 From e794fbe746b5fd09a7d1b89629a2c5ad3747a6b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 23:29:56 +0200 Subject: [PATCH 107/157] fix test_progress (wip) --- tests/trainer/test_progress.py | 109 +++++++++++++++++++-------------- 1 file changed, 62 insertions(+), 47 deletions(-) diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index 7205b72ad3963..c6e4a5d266be0 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -34,7 +34,7 @@ class CustomException(BaseException): pass -def test_progress_geattr_setattr(): +def test_progress_getattr_setattr(): p = Tracker(ready=10, completed=None) # can read assert p.completed is None @@ -134,7 +134,7 @@ def __init__(self): self.should_fail = True def training_step(self, batch, batch_idx, optimizer_idx: int = None): - # breaking on global_step 4 + # simulate failure during the the 5-th training step, 2nd epoch (global_step = 4) if self.should_fail and self.trainer.current_epoch == 1 and batch_idx == 1 and optimizer_idx == ( 1 if use_multiple_optimizers else None ): @@ -177,7 +177,7 @@ def configure_optimizers_3(self): # VALIDATE CHECKPOINT # ####################### - checkpoint = torch.load(trainer.checkpoint_callback.last_model_path) + checkpoint = torch.load(str(tmpdir / ".pl_auto_save.ckpt")) num_epochs = 1 num_batches = 4 @@ -188,13 +188,13 @@ def configure_optimizers_3(self): total_optimizer_step = (4 * num_optimizers + (1 if use_multiple_optimizers else 0)) // accumulate_grad_batches # we raised expection on the first optimizer - current_optimize_step = (1 if use_multiple_optimizers else 0) + current_optimizer_step = (1 if use_multiple_optimizers else 0) if accumulate_grad_batches == 2 and use_multiple_optimizers: total_optimizer_step += 1 total_optimizer_zero_grad = total_optimizer_step - current_optimizer_zero_grad = current_optimize_step + current_optimizer_zero_grad = current_optimizer_step if accumulate_grad_batches == 2: # that's weird ! 
todo (tchaton) investigate this @@ -214,34 +214,49 @@ def configure_optimizers_3(self): expected = { "state_dict": {}, "epoch_loop.state_dict": {}, - "epoch_loop.progress": { + "epoch_loop.batch_progress": { "total": { - "ready": num_epochs + 1, - "started": num_epochs + 1, - "processed": 1, - "completed": 1 + "ready": 5, + "started": 5, + "processed": 4, + "completed": 4, }, "current": { - "ready": num_epochs + 1, - "started": num_epochs + 1, + "ready": 2, + "started": 2, "processed": 1, - "completed": 1 + "completed": 1, + }, + }, + "epoch_loop.scheduler_progress": { + "scheduler": { + "total": { + "ready": total_scheduler_step, + "started": None, + "processed": None, + "completed": total_scheduler_step, + }, + "current": { + "ready": current_scheduler_step, + "started": None, + "processed": None, + "completed": current_scheduler_step, + }, }, - "should_check_val": False, }, "epoch_loop.batch_loop.state_dict": {}, - "epoch_loop.batch_loop.progress": { + "epoch_loop.batch_loop.split_progress": { "total": { - "ready": num_batches + 1, - "started": num_batches + 1, - "processed": num_batches, - "completed": num_batches + "ready": 0, + "started": 0, + "processed": 0, + "completed": 0, }, "current": { - "ready": num_batches - limit_train_batches + 1, - "started": num_batches - limit_train_batches + 1, - "processed": num_batches - limit_train_batches, - "completed": num_batches - limit_train_batches + "ready": 0, + "started": 0, + "processed": 0, + "completed": 0, }, }, "epoch_loop.batch_loop.optim_progress": { @@ -250,15 +265,15 @@ def configure_optimizers_3(self): "step": { "total": { "ready": total_optimizer_step + 1, - "started": total_optimizer_step, + "started": None, "processed": None, - "completed": total_optimizer_step + "completed": total_optimizer_step, }, "current": { - "ready": current_optimize_step + 1, - "started": current_optimize_step, + "ready": current_optimizer_step + 1, + "started": None, "processed": None, - "completed": current_optimize_step, + "completed": current_optimizer_step, }, }, "zero_grad": { @@ -266,7 +281,7 @@ def configure_optimizers_3(self): "ready": total_optimizer_zero_grad, "started": total_optimizer_zero_grad, "processed": None, - "completed": total_optimizer_zero_grad + "completed": total_optimizer_zero_grad, }, "current": { "ready": current_optimizer_zero_grad, @@ -276,32 +291,32 @@ def configure_optimizers_3(self): }, }, }, - "scheduler": { - "total": { - "ready": total_scheduler_step, - "started": None, - "processed": None, - "completed": total_scheduler_step - }, - "current": { - "ready": current_scheduler_step, - "started": None, - "processed": None, - "completed": current_scheduler_step - }, - }, }, "epoch_loop.val_loop.state_dict": {}, - "epoch_loop.val_loop.progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "epoch_loop.val_loop.dataloader_progress": { + "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, "dataloader_idx": 0, }, "epoch_loop.val_loop.epoch_loop.state_dict": {}, - "epoch_loop.val_loop.epoch_loop.progress": { + "epoch_loop.val_loop.epoch_loop.batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, + "epoch_progress": { + "total": { + "ready": 2, + "started": 2, + "processed": 1, + "completed": 1, + }, + "current": { + 
"ready": 2, + "started": 2, + "processed": 1, + "completed": 1, + }, + }, } # yapf: enable From c98bd292a2c3e3330c8f314086f29aef1d509e05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Jul 2021 23:34:12 +0200 Subject: [PATCH 108/157] fix batch_progress.current.reset --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 0a3ba558e2f64..9a6f6bd59eac8 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -88,8 +88,7 @@ def reset(self) -> None: if self.restarting: self.iteration_count = self.batches_seen = self.batch_progress.current.completed else: - # todo (tchaton) the batch_loop should be responsible for that. - self.batch_loop.split_progress.current.reset() + self.batch_progress.current.reset() def on_run_start(self, *args: Any, **kwargs: Any) -> None: # hook From f90334cc7dc3e8fe98c4bc89dc6741103354c62a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 14 Jul 2021 23:39:36 +0200 Subject: [PATCH 109/157] Hold off on split progress. Out of scope of this PR --- pytorch_lightning/loops/batch/training_batch_loop.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 60f43f5ba8825..7976dfb5159f6 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -28,7 +28,7 @@ from pytorch_lightning.loops.base import Loop from pytorch_lightning.plugins import ParallelPlugin from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import OptimizationProgress, Progress +from pytorch_lightning.trainer.progress import OptimizationProgress from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -50,8 +50,6 @@ def __init__(self) -> None: self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20) self.batch_idx: int = 0 self.split_idx: Optional[int] = None - # TODO: add progress updates for batch splits - self.split_progress = Progress() self.optim_progress = OptimizationProgress() self._warning_cache: WarningCache = WarningCache() From 7fb78deba5e18b30ad828a495c371feecf83cdc0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 00:07:53 +0200 Subject: [PATCH 110/157] Unnecessary if --- pytorch_lightning/loops/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 66844a9dc5ddd..6aa8ebefb60b1 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -113,8 +113,7 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: self.advance(*args, **kwargs) self.on_advance_end() self.iteration_count += 1 - if self.restarting: - self.restarting = False + self.restarting = False except StopIteration: break From 8130a47b9beaa9b9d80ad44111736279d61fbd8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Jul 2021 00:09:34 +0200 Subject: [PATCH 111/157] fix structure in test_progress --- 
tests/trainer/test_progress.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index c6e4a5d266be0..442d446a8f7d1 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -229,19 +229,17 @@ def configure_optimizers_3(self): }, }, "epoch_loop.scheduler_progress": { - "scheduler": { - "total": { - "ready": total_scheduler_step, - "started": None, - "processed": None, - "completed": total_scheduler_step, - }, - "current": { - "ready": current_scheduler_step, - "started": None, - "processed": None, - "completed": current_scheduler_step, - }, + "total": { + "ready": total_scheduler_step, + "started": None, + "processed": None, + "completed": total_scheduler_step, + }, + "current": { + "ready": current_scheduler_step, + "started": None, + "processed": None, + "completed": current_scheduler_step, }, }, "epoch_loop.batch_loop.state_dict": {}, @@ -296,7 +294,6 @@ def configure_optimizers_3(self): "epoch_loop.val_loop.dataloader_progress": { "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, - "dataloader_idx": 0, }, "epoch_loop.val_loop.epoch_loop.state_dict": {}, "epoch_loop.val_loop.epoch_loop.batch_progress": { From b6b9ea4c1d4e220e918739ee5f961a381fc47ee8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Jul 2021 00:10:15 +0200 Subject: [PATCH 112/157] structure --- tests/trainer/test_progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index 442d446a8f7d1..26fd16e43a8c1 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -326,7 +326,7 @@ def configure_optimizers_3(self): trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"]) state_dict = trainer.fit_loop.state_dict() assert state_dict != checkpoint["loops"]["fit_loop"] - assert state_dict['epoch_loop.progress']["total"]["started"] == num_epochs + assert state_dict["epoch_progress"]["total"]["started"] == 1 @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) From 4780b19fa74e2aec22a21a0b5ebba6824d9580e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Jul 2021 00:10:52 +0200 Subject: [PATCH 113/157] clean up unused variables in test_progress --- tests/trainer/test_progress.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index 26fd16e43a8c1..6318c29c39701 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -131,11 +131,10 @@ def __init__(self): super().__init__() if use_multiple_optimizers: self.configure_optimizers = self.configure_optimizers_3 - self.should_fail = True def training_step(self, batch, batch_idx, optimizer_idx: int = None): # simulate failure during the the 5-th training step, 2nd epoch (global_step = 4) - if self.should_fail and self.trainer.current_epoch == 1 and batch_idx == 1 and optimizer_idx == ( + if self.trainer.current_epoch == 1 and batch_idx == 1 and optimizer_idx == ( 1 if use_multiple_optimizers else None ): raise CustomException @@ -155,16 +154,12 @@ def configure_optimizers_3(self): limit_train_batches = 3 - chk = ModelCheckpoint(dirpath=tmpdir, filename=str(use_multiple_optimizers), save_last=True) - chk.last_model_path = None trainer = Trainer( 
default_root_dir=tmpdir, max_epochs=2, limit_train_batches=limit_train_batches, limit_val_batches=0, - callbacks=chk, accumulate_grad_batches=accumulate_grad_batches, - resume_from_checkpoint=None, ) # simulate random failure in training_step @@ -179,9 +174,6 @@ def configure_optimizers_3(self): checkpoint = torch.load(str(tmpdir / ".pl_auto_save.ckpt")) - num_epochs = 1 - num_batches = 4 - num_optimizers = 3 if use_multiple_optimizers else 1 # 4 optimizer steps because breaking on the second batch of the second epoch (3 + 1) From 7eee718d421ded8e93afcc30155223f753ef2752 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Jul 2021 00:11:24 +0200 Subject: [PATCH 114/157] refactor naming and organization in test_progress --- tests/trainer/test_progress.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index 6318c29c39701..919ded5da60c6 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -141,13 +141,17 @@ def training_step(self, batch, batch_idx, optimizer_idx: int = None): return super().training_step(batch, batch_idx) def configure_optimizers_3(self): - optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + optimizer_0 = torch.optim.SGD(self.layer.parameters(), lr=0.1) optimizer_1 = torch.optim.Adam(self.layer.parameters(), lr=0.1) - lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) - return [optimizer, optimizer_1, optimizer_2], \ - [lr_scheduler, {"scheduler": lr_scheduler_1, "interval": "step"}] + optimizers = [optimizer_0, optimizer_1, optimizer_2] + + lr_scheduler_0 = torch.optim.lr_scheduler.StepLR(optimizer_0, step_size=1) + lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) + # no scheduler for optimizer_2 + lr_schedulers = [lr_scheduler_0, {"scheduler": lr_scheduler_1, "interval": "step"}] + + return optimizers, lr_schedulers model = TestModel() model.training_epoch_end = None @@ -177,15 +181,15 @@ def configure_optimizers_3(self): num_optimizers = 3 if use_multiple_optimizers else 1 # 4 optimizer steps because breaking on the second batch of the second epoch (3 + 1) - total_optimizer_step = (4 * num_optimizers + (1 if use_multiple_optimizers else 0)) // accumulate_grad_batches + completed_optimizer_steps = (4 * num_optimizers + (1 if use_multiple_optimizers else 0)) // accumulate_grad_batches # we raised expection on the first optimizer current_optimizer_step = (1 if use_multiple_optimizers else 0) if accumulate_grad_batches == 2 and use_multiple_optimizers: - total_optimizer_step += 1 + completed_optimizer_steps += 1 - total_optimizer_zero_grad = total_optimizer_step + total_optimizer_zero_grad = completed_optimizer_steps current_optimizer_zero_grad = current_optimizer_step if accumulate_grad_batches == 2: @@ -254,10 +258,10 @@ def configure_optimizers_3(self): "optimizer": { "step": { "total": { - "ready": total_optimizer_step + 1, + "ready": completed_optimizer_steps + 1, "started": None, "processed": None, - "completed": total_optimizer_step, + "completed": completed_optimizer_steps, }, "current": { "ready": current_optimizer_step + 1, From a1bd9892a64d91bfc6a531e278d58dbd931de91b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 00:09:01 +0200 Subject: [PATCH 115/157] Unnecessary variable --- 
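Note (editor's addition, not part of the committed diff; placed in the patch-notes area that `git am` ignores): `get_active_optimizers(batch_idx)` already returns `(opt_idx, optimizer)` pairs, so the temporary `active_optimizers` list adds nothing. A rough sketch of the resulting loop together with the restart guard the surrounding hunk relies on; the `continue` body is an assumption for illustration, not taken verbatim from the patch:

    for opt_idx, optimizer in self.get_active_optimizers(batch_idx):
        if self.restarting and opt_idx < self.optim_progress.optimizer_idx:
            # assumed behaviour: skip optimizers that already finished before the failure
            continue
        ...
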
pytorch_lightning/loops/batch/training_batch_loop.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 7976dfb5159f6..d28b5a2bd39de 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -140,9 +140,7 @@ def advance(self, batch, batch_idx, dataloader_idx): self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch) if self.trainer.lightning_module.automatic_optimization: - active_optimizers = self.get_active_optimizers(batch_idx) - for opt_idx, optimizer in active_optimizers: - + for opt_idx, optimizer in self.get_active_optimizers(batch_idx): # handle optimization restart if self.restarting: if opt_idx < self.optim_progress.optimizer_idx: From f6d3a5f3e59dbdc955959f71325a8e515552b839 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 00:13:32 +0200 Subject: [PATCH 116/157] Remove unnecessary diff --- pytorch_lightning/loops/dataloader/dataloader_loop.py | 1 - pytorch_lightning/loops/epoch/training_epoch_loop.py | 4 ---- pytorch_lightning/loops/fit_loop.py | 3 ++- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/dataloader_loop.py b/pytorch_lightning/loops/dataloader/dataloader_loop.py index ed7f776fcad7d..65521aea547d8 100644 --- a/pytorch_lightning/loops/dataloader/dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/dataloader_loop.py @@ -56,7 +56,6 @@ def done(self) -> bool: def reset(self) -> None: """Resets the internal state""" if not self.restarting: - # reset batch / epoch progress tracking self.dataloader_progress.current.reset() def on_advance_start(self, *args: Any, **kwargs: Any) -> None: diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 9a6f6bd59eac8..91b938404f1ef 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -119,8 +119,6 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: with self.trainer.profiler.profile("run_training_batch"): batch_output = self.batch_loop.run(batch, self.iteration_count, self._dataloader_idx) - - # TODO: remove with progress tracking self.batches_seen += 1 self.batch_progress.increment_processed() @@ -165,7 +163,6 @@ def on_advance_end(self): # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- should_check_val = self._should_check_val_fx(self.iteration_count, self.is_last_batch) - if should_check_val: self.trainer.validating = True self._run_validation() @@ -393,7 +390,6 @@ def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) - """updates the lr schedulers based on the given interval""" if interval == "step" and self.batch_loop.should_accumulate(): return - self.trainer.optimizer_connector.update_learning_rates( interval=interval, update_plateau_schedulers=update_plateau_schedulers, diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index d2f0dfe954a6f..7df0d1445e3b3 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -51,9 +51,10 @@ def __init__( super().__init__() self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs - 
self.epoch_loop = TrainingEpochLoop(min_steps, max_steps) self.epoch_progress = Progress() + self.epoch_loop = TrainingEpochLoop(min_steps, max_steps) + @property def current_epoch(self) -> int: """Return the current epoch""" From d57bddffb66101187ba8ab4404143463571b5859 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 00:15:42 +0200 Subject: [PATCH 117/157] Improve comment --- pytorch_lightning/trainer/trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index bb65ff47817e7..80e0508601b46 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1255,7 +1255,6 @@ def _log_device_info(self) -> None: def _on_expection(self): if not self.is_global_zero or not _fault_tolerant_enabled(): return - - # save a checkpoint for fault tolerant training + # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure. file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt") self.save_checkpoint(file_path) From 099edd01cc28bebca506242a8ac633cca7823af6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 00:23:38 +0200 Subject: [PATCH 118/157] Undo typing change to avoid polluting everything with mypy fixes --- pytorch_lightning/core/lightning.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index aeed3c9304a76..735f8ab160c1f 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -31,7 +31,6 @@ from torch.optim.optimizer import Optimizer from torchmetrics import Metric -import pytorch_lightning as pl from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary @@ -90,7 +89,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: torch._C._log_api_usage_once(f"lightning.module.{self.__class__.__name__}") # pointer to the trainer object - self.trainer: Optional['pl.Trainer'] = None + self.trainer = None self._distrib_type = None self._device_type = None From 9145c82c1cd6da4c175bc4d4e88f65ee206ed444 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 01:32:02 +0200 Subject: [PATCH 119/157] Fix and improve test_loops.py --- tests/loops/test_loops.py | 81 +++++++++------------------------------ 1 file changed, 18 insertions(+), 63 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 34828bf8d59f1..59f84b36cf3dd 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -20,17 +20,6 @@ from pytorch_lightning.trainer.trainer import Trainer -def _collect_loop_progress(loop: Loop) -> Dict[str, Any]: - """Return the progress for the current loop and its children.""" - progress = {} - for k, v in loop.__dict__.items(): - if isinstance(v, BaseProgress): - progress[k] = v - elif isinstance(v, Loop): - progress[k] = _collect_loop_progress(v) - return progress - - def test_loop_restore(): class CustomExpection(Exception): @@ -52,12 +41,10 @@ def done(self) -> bool: def reset(self) -> None: self.iter_dataset = iter(self.dataset) - if self.restarting: for _ in range(self.iteration_count): next(self.iter_dataset) self.iteration_count += 1 - # self.restarting = False else: self.outputs = [] @@ -101,15 +88,8 @@ def test_loop_hierarchy(): @dataclass class 
SimpleProgress(BaseProgress): - increment: int = 0 - def state_dict(self): - return {"increment": self.increment} - - def load_state_dict(self, state_dict): - self.increment = state_dict["increment"] - class Simple(Loop): def __init__(self, a): @@ -122,18 +102,16 @@ def advance(self, *args: Any, **kwargs: Any) -> None: if not loop: return loop.run() - self.progress.increment += 1 - @property - def skip(self) -> bool: - return False + def on_advance_end(self): + self.progress.increment += 1 @property def done(self) -> bool: - return self.iteration_count > 0 + return self.progress.increment > 0 - # def reset(self) -> None: - # self.restarting = False + def reset(self) -> None: + ... def on_save_checkpoint(self) -> Dict: return {"a": self.a} @@ -141,26 +119,15 @@ def on_save_checkpoint(self) -> Dict: def on_load_checkpoint(self, state_dict: Dict) -> None: self.a = state_dict["a"] - grand_loop_parent = Simple(0) loop_parent = Simple(1) loop_child = Simple(2) - loop_parent.loop_child = loop_child - assert not loop_parent.skip - - state_dict = loop_parent.state_dict() - - loop_progress = _collect_loop_progress(loop_parent) - assert loop_progress["progress"] == loop_parent.progress - assert loop_progress["loop_child"]["progress"] == loop_child.progress - - loop_progress = _collect_loop_progress(loop_child) - assert loop_progress["progress"] == loop_child.progress - + # check the trainer reference is propagated loop_parent.trainer = Trainer() - assert loop_child.trainer == loop_parent.trainer + assert loop_child.trainer is loop_parent.trainer + state_dict = loop_parent.state_dict() assert state_dict == { 'state_dict': { 'a': 1 @@ -176,23 +143,14 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: } } - loop_parent.progress - state_dict["loop_child.state_dict"]["a"] = 3 - + # check restarting after `load_state_dict` loop_parent.load_state_dict(state_dict) assert loop_parent.restarting loop_parent.run() - loop_parent_copy = deepcopy(loop_parent) - assert loop_parent_copy.state_dict() == loop_parent.state_dict() - - assert loop_parent_copy.on_save_checkpoint() == {'a': 1} - assert loop_parent_copy.loop_child.on_save_checkpoint() == {'a': 3} - - assert not loop_parent.restarting - + # check the new state after `run` state_dict = loop_parent.state_dict() assert state_dict == { 'state_dict': { @@ -205,26 +163,23 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: 'a': 3 }, 'loop_child.progress': { - 'increment': 0 + 'increment': 1 } } + loop_parent_copy = deepcopy(loop_parent) + assert loop_parent_copy.state_dict() == loop_parent.state_dict() + + assert loop_parent_copy.on_save_checkpoint() == state_dict['state_dict'] + assert loop_parent_copy.loop_child.on_save_checkpoint() == state_dict['loop_child.state_dict'] + loop_parent = Simple(1) loop_child = Simple(2) loop_parent.loop_child = loop_child loop_parent.load_state_dict(state_dict) assert loop_parent.progress.increment == 1 - assert loop_parent.loop_child.progress.increment == 0 + assert loop_parent.loop_child.progress.increment == 1 del loop_parent.loop_child state_dict = loop_parent.state_dict() assert state_dict == {'state_dict': {'a': 1}, 'progress': {'increment': 1}} - - grand_loop_parent = Simple(0) - loop_parent = Simple(1) - loop_child = Simple(2) - grand_loop_parent.loop_child = loop_parent - loop_parent.loop_child = loop_child - - grand_loop_parent.trainer = Trainer() - assert loop_child.trainer is not None From b0fc845f80ca7131ba2774c326284c3e5edd305f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 
01:48:58 +0200 Subject: [PATCH 120/157] Fix and organize `test_loop_state_dict` --- tests/loops/test_loop_state_dict.py | 138 +++++++++------------------- 1 file changed, 43 insertions(+), 95 deletions(-) diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index 7dc182e2df8fd..cb6ed55d71b31 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import pytest from pytorch_lightning.loops import FitLoop @@ -33,140 +32,89 @@ def test_loops_state_dict(): def test_loops_state_dict_structure(): trainer = Trainer() - # structure saved by the checkpoint connector - state_dict = { - "fit_loop": trainer.fit_loop.state_dict(), - "validate_loop": trainer.validate_loop.state_dict(), - "test_loop": trainer.test_loop.state_dict(), - "predict_loop": trainer.predict_loop.state_dict(), - } + state_dict = trainer.checkpoint_connector._get_loops_state_dict() # yapf: disable expected = { "fit_loop": { "state_dict": {}, - "epoch_loop.state_dict": {}, - "epoch_loop.progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "epoch_progress": { "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "should_check_val": False, + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, - "epoch_loop.batch_loop.state_dict": {}, - "epoch_loop.batch_loop.progress": { + + "epoch_loop.state_dict": {}, + "epoch_loop.batch_progress": { + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, + "epoch_loop.scheduler_progress": { "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, + "epoch_loop.batch_loop.optim_progress": { - "optimizer_idx": 0, "optimizer": { "step": { - "total": { - "ready": 0, - "started": 0, - "processed": None, - "completed": 0, - }, - "current": { - "ready": 0, - "started": 0, - "processed": None, - "completed": 0, - }, + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, "zero_grad": { - "total": { - "ready": 0, - "started": 0, - "processed": None, - "completed": 0, - }, - "current": { - "ready": 0, - "started": 0, - "processed": None, - "completed": 0, - }, - }, - }, - "scheduler": { - "total": { - "ready": 0, - "started": None, - "processed": None, - "completed": 0, - }, - "current": { - "ready": 0, - "started": None, - "processed": None, - "completed": 0, + "current": {"ready": 0, "started": 0, "processed": None, "completed": 0}, + "total": {"ready": 0, "started": 0, "processed": None, "completed": 0}, }, }, + "optimizer_idx": 0, }, + "epoch_loop.batch_loop.state_dict": {}, + "epoch_loop.val_loop.state_dict": {}, - "epoch_loop.val_loop.progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "dataloader_idx": 0, + "epoch_loop.val_loop.dataloader_progress": { + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, + "epoch_loop.val_loop.epoch_loop.state_dict": {}, - 
"epoch_loop.val_loop.epoch_loop.progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "epoch_loop.val_loop.epoch_loop.batch_progress": { "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, }, - "validate_loop": { + "predict_loop": { "state_dict": {}, - "progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "dataloader_idx": 0, + "dataloader_progress": { + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, + "epoch_loop.state_dict": {}, - "epoch_loop.progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "epoch_loop.batch_progress": { "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, }, "test_loop": { "state_dict": {}, - "progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "dataloader_idx": 0, + "dataloader_progress": { + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, "epoch_loop.state_dict": {}, - "epoch_loop.progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "epoch_loop.batch_progress": { "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, }, - "predict_loop": { + "validate_loop": { "state_dict": {}, - "progress": { - "epoch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "dataloader_idx": 0, - "batch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - }, - } + "dataloader_progress": { + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, + "epoch_loop.state_dict": {}, - "epoch_loop.progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "epoch_loop.batch_progress": { "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "dataloader_idx": 0, - "batch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - }, + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, }, } From 1577aa8bd1dbcaf8733b23cd6f27c8646dd51bac Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 01:52:39 +0200 Subject: [PATCH 121/157] Remove unnecessary checks in test --- tests/trainer/test_progress.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index 919ded5da60c6..280cf22e4c378 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -20,13 +20,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.trainer.progress import ( - DataLoaderProgress, - 
OptimizationProgress, - OptimizerProgress, - Progress, - Tracker, -) +from pytorch_lightning.trainer.progress import BaseProgress, OptimizerProgress, Progress, Tracker from tests.helpers import BoringModel @@ -113,11 +107,9 @@ def test_optimizer_progress_default_factory(): def test_deepcopy(): - _ = deepcopy(Tracker()) + _ = deepcopy(BaseProgress()) _ = deepcopy(Progress()) - _ = deepcopy(DataLoaderProgress()) - _ = deepcopy(OptimizerProgress()) - _ = deepcopy(OptimizationProgress()) + _ = deepcopy(Tracker()) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) From 1f3ae633a7eb84032e2a9e2728a5c939d24e7da0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 01:55:21 +0200 Subject: [PATCH 122/157] Update test after disallowing updates on None attributes --- tests/trainer/test_progress.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index 280cf22e4c378..4fd484128a98e 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -74,23 +74,23 @@ def test_base_progress_from_defaults(): def test_epoch_loop_progress_increment_sequence(): """Test sequences for incrementing batches reads and epochs.""" - batch = Progress(total=Tracker(started=None)) + batch = Progress() batch.increment_ready() - assert batch.total == Tracker(ready=1, started=None) + assert batch.total == Tracker(ready=1) assert batch.current == Tracker(ready=1) batch.increment_started() - assert batch.total == Tracker(ready=1, started=None) - assert batch.current == Tracker(ready=1) + assert batch.total == Tracker(ready=1, started=1) + assert batch.current == Tracker(ready=1, started=1) batch.increment_processed() - assert batch.total == Tracker(ready=1, started=None, processed=1) - assert batch.current == Tracker(ready=1, processed=1) + assert batch.total == Tracker(ready=1, started=1, processed=1) + assert batch.current == Tracker(ready=1, started=1, processed=1) batch.increment_completed() - assert batch.total == Tracker(ready=1, started=None, processed=1, completed=1) - assert batch.current == Tracker(ready=1, processed=1, completed=1) + assert batch.total == Tracker(ready=1, started=1, processed=1, completed=1) + assert batch.current == Tracker(ready=1, started=1, processed=1, completed=1) def test_optimizer_progress_default_factory(): From ad8224ca4f1bf40c5dc3fbf23a40f52e29704633 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 02:02:29 +0200 Subject: [PATCH 123/157] Typing --- pytorch_lightning/loops/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 6aa8ebefb60b1..1efd67bb26f8e 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -160,7 +160,7 @@ def on_save_checkpoint(self) -> Dict: """ return {} - def on_load_checkpoint(self, state_dict: Dict): + def on_load_checkpoint(self, state_dict: Dict) -> None: """Called when loading a model checkpoint, use to reload loop state.""" def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = "") -> Dict: @@ -185,14 +185,14 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = return destination - def load_state_dict(self, state_dict: Dict, prefix="", restart_progress: bool = True): + def load_state_dict(self, state_dict: Dict, prefix: str = "", restart_progress: bool = True) -> None: """ Loads the state of this loop and all its 
children. """ self._load_from_state_dict(state_dict.copy(), prefix, restart_progress) for k, v in self.__dict__.items(): if isinstance(v, Loop): v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress) - def _load_from_state_dict(self, state_dict, prefix, restart_progress): + def _load_from_state_dict(self, state_dict: Dict, prefix: str, restart_progress: bool) -> None: for k, v in self.__dict__.items(): if isinstance(v, BaseProgress): v.load_state_dict(state_dict[prefix + k]) From 403ea9d1e7be6eb86ac6835a0fb4ccf4e9c46593 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 02:06:45 +0200 Subject: [PATCH 124/157] Minor test cleanup --- tests/trainer/test_progress.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index 4fd484128a98e..b1b60580cf179 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -336,14 +336,12 @@ def val_dataloader(self): model = ValidationModel() model.validation_epoch_end = None - chk = ModelCheckpoint(dirpath=tmpdir, save_last=True) - chk.last_model_path = None trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, limit_train_batches=5, limit_val_batches=3, - callbacks=chk, + callbacks=ModelCheckpoint(dirpath=tmpdir, save_last=True), resume_from_checkpoint=None, val_check_interval=2, num_sanity_val_steps=0, @@ -355,12 +353,6 @@ def val_dataloader(self): except CustomException: pass - ####################### - # VALIDATE CHECKPOINT # - ####################### - - checkpoint = torch.load(trainer.checkpoint_callback.last_model_path)["loops"]["fit_loop"] - checkpoint = torch.load(trainer.checkpoint_callback.last_model_path)["loops"]["fit_loop"] expected = { From 6492cde5350c4e49ccf3d6fdcc5c55a6231b678b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 02:52:48 +0200 Subject: [PATCH 125/157] Fix and move loop test --- tests/loops/test_loops.py | 111 +++++++++++++++++++++++++++++++-- tests/trainer/test_progress.py | 85 ------------------------- 2 files changed, 106 insertions(+), 90 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 59f84b36cf3dd..9219a1f832db0 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -11,19 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os from copy import deepcopy from dataclasses import dataclass from typing import Any, Dict, Iterator +from unittest import mock +import torch + +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loops.base import Loop from pytorch_lightning.trainer.progress import BaseProgress from pytorch_lightning.trainer.trainer import Trainer +from tests.helpers import BoringModel -def test_loop_restore(): +class CustomException(Exception): + pass - class CustomExpection(Exception): - pass + +def test_loop_restore(): class Simple(Loop): @@ -52,7 +59,7 @@ def advance(self) -> None: value = next(self.iter_dataset) if self.iteration_count == 5: - raise CustomExpection + raise CustomException self.outputs.append(value) @@ -71,7 +78,7 @@ def load_state_dict(self, state_dict: Dict) -> None: try: loop.run() state_dict = {} - except CustomExpection: + except CustomException: state_dict = loop.state_dict() loop = Simple(data) @@ -183,3 +190,97 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: del loop_parent.loop_child state_dict = loop_parent.state_dict() assert state_dict == {'state_dict': {'a': 1}, 'progress': {'increment': 1}} + + +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_loop_restart_progress_multiple_datasets(tmpdir): + stop_epoch = stop_batch = stop_dataloader = 1 + n_dataloaders = 3 + n_batches = 3 + n_epochs = 2 + + class ValidationModel(BoringModel): + + def __init__(self): + super().__init__() + + def validation_step(self, batch, batch_idx, dataloader_idx): + if self.current_epoch == stop_epoch and batch_idx == stop_batch and dataloader_idx == stop_dataloader: + raise CustomException + return super().validation_step(batch, batch_idx) + + def val_dataloader(self): + return [super().val_dataloader()] * n_dataloaders + + model = ValidationModel() + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=n_epochs, + limit_train_batches=1, + limit_val_batches=n_batches, + callbacks=ModelCheckpoint(dirpath=tmpdir, save_last=True), + num_sanity_val_steps=0, + ) + + # simulate random failure in training_step + try: + trainer.fit(model) + except CustomException: + pass + + ckpt_path = str(tmpdir / '.pl_auto_save.ckpt') + checkpoint = torch.load(ckpt_path)["loops"]["fit_loop"] + + total = (n_epochs - 1) * n_dataloaders + stop_dataloader + expected = { + "total": { + "ready": total + 1, + "started": None, + "processed": None, + "completed": total + }, + "current": { + "ready": stop_dataloader + 1, + "started": None, + "processed": None, + "completed": stop_dataloader, + }, + } + assert checkpoint["epoch_loop.val_loop.dataloader_progress"] == expected + + trainer.fit_loop.load_state_dict(checkpoint, restart_progress=False) + total = n_dataloaders * n_batches + n_batches + stop_epoch + expected = { + "total": { + "ready": total + 1, + "started": total + 1, + "processed": total, + "completed": total + }, + "current": { + "ready": stop_batch + 1, + "started": stop_batch + 1, + "processed": stop_batch, + "completed": stop_batch + }, + } + assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected + + trainer.fit_loop.load_state_dict(checkpoint) + expected = { + "total": { + "ready": total, + "started": total, + "processed": total, + "completed": total + }, + "current": { + "ready": stop_batch, + "started": stop_batch, + "processed": stop_batch, + "completed": stop_batch + }, + } + assert 
trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index b1b60580cf179..889689a241612 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -19,7 +19,6 @@ import torch from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.progress import BaseProgress, OptimizerProgress, Progress, Tracker from tests.helpers import BoringModel @@ -315,87 +314,3 @@ def configure_optimizers_3(self): state_dict = trainer.fit_loop.state_dict() assert state_dict != checkpoint["loops"]["fit_loop"] assert state_dict["epoch_progress"]["total"]["started"] == 1 - - -@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -def test_progress_tracking_validation_multiple_datasets(tmpdir): - - class ValidationModel(BoringModel): - - def __init__(self): - super().__init__() - - def validation_step(self, batch, batch_idx, dataloader_idx): - if self.trainer.fit_loop.epoch_loop.batch_idx == 3 and batch_idx == 1 and dataloader_idx == 1: - raise CustomException - return super().validation_step(batch, batch_idx) - - def val_dataloader(self): - return [super().val_dataloader(), super().val_dataloader(), super().val_dataloader()] - - model = ValidationModel() - model.validation_epoch_end = None - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=5, - limit_val_batches=3, - callbacks=ModelCheckpoint(dirpath=tmpdir, save_last=True), - resume_from_checkpoint=None, - val_check_interval=2, - num_sanity_val_steps=0, - ) - - # simulate random failure in training_step - try: - trainer.fit(model) - except CustomException: - pass - - checkpoint = torch.load(trainer.checkpoint_callback.last_model_path)["loops"]["fit_loop"] - - expected = { - "total": { - "ready": 2, - "started": 2, - "processed": 1, - "completed": 1 - }, - "current": { - "ready": 1, - "started": 1, - "processed": 0, - "completed": 0 - }, - "dataloader_idx": 1, - } - - assert checkpoint["epoch_loop.val_loop.progress"] == expected - - # 3 dataloaders with 3 samples for batch_idx == 1 + first dataloader on batch_idx == 1 + failure on batch_idx = 1 - current = 2 - total = 3 * 3 + 3 + current - expected = { - "total": { - "ready": total, - "started": total, - "processed": total - 1, - "completed": total - 1 - }, - "current": { - "ready": current, - "started": current, - "processed": current - 1, - "completed": current - 1 - }, - } - - assert checkpoint["epoch_loop.val_loop.epoch_loop.progress"] == expected - - trainer = Trainer() - trainer.fit_loop.load_state_dict(checkpoint, restart_progress=False) - assert trainer.fit_loop.state_dict() == checkpoint - - trainer.fit_loop.load_state_dict(checkpoint) - assert trainer.fit_loop.state_dict() != checkpoint From bc5544dcba75c4ff2f4c7d48cb7807d794a16e05 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 02:57:25 +0200 Subject: [PATCH 126/157] Move test from progress to loops --- tests/loops/test_loops.py | 206 +++++++++++++++++++++++++++++++ tests/trainer/test_progress.py | 214 --------------------------------- 2 files changed, 206 insertions(+), 214 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 9219a1f832db0..1792fbc41c497 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -17,6 +17,7 @@ from typing import Any, Dict, Iterator from unittest import mock +import pytest import torch from 
pytorch_lightning.callbacks import ModelCheckpoint @@ -284,3 +285,208 @@ def val_dataloader(self): }, } assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected + + +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@pytest.mark.parametrize("use_multiple_optimizers", [False, True]) +@pytest.mark.parametrize("accumulate_grad_batches", [1, 2]) +def test_progress_tracking(use_multiple_optimizers, accumulate_grad_batches, tmpdir): + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + if use_multiple_optimizers: + self.configure_optimizers = self.configure_optimizers_3 + + def training_step(self, batch, batch_idx, optimizer_idx: int = None): + # simulate failure during the the 5-th training step, 2nd epoch (global_step = 4) + if self.trainer.current_epoch == 1 and batch_idx == 1 and optimizer_idx == ( + 1 if use_multiple_optimizers else None + ): + raise CustomException + return super().training_step(batch, batch_idx) + + def configure_optimizers_3(self): + optimizer_0 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_1 = torch.optim.Adam(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) + optimizers = [optimizer_0, optimizer_1, optimizer_2] + + lr_scheduler_0 = torch.optim.lr_scheduler.StepLR(optimizer_0, step_size=1) + lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) + # no scheduler for optimizer_2 + lr_schedulers = [lr_scheduler_0, {"scheduler": lr_scheduler_1, "interval": "step"}] + + return optimizers, lr_schedulers + + model = TestModel() + model.training_epoch_end = None + + limit_train_batches = 3 + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=limit_train_batches, + limit_val_batches=0, + accumulate_grad_batches=accumulate_grad_batches, + ) + + # simulate random failure in training_step + try: + trainer.fit(model) + except CustomException: + pass + + ####################### + # VALIDATE CHECKPOINT # + ####################### + + checkpoint = torch.load(str(tmpdir / ".pl_auto_save.ckpt")) + + num_optimizers = 3 if use_multiple_optimizers else 1 + + # 4 optimizer steps because breaking on the second batch of the second epoch (3 + 1) + completed_optimizer_steps = (4 * num_optimizers + (1 if use_multiple_optimizers else 0)) // accumulate_grad_batches + + # we raised expection on the first optimizer + current_optimizer_step = (1 if use_multiple_optimizers else 0) + + if accumulate_grad_batches == 2 and use_multiple_optimizers: + completed_optimizer_steps += 1 + + total_optimizer_zero_grad = completed_optimizer_steps + current_optimizer_zero_grad = current_optimizer_step + + if accumulate_grad_batches == 2: + # that's weird ! todo (tchaton) investigate this + total_optimizer_zero_grad = (9 if use_multiple_optimizers else 3) + current_optimizer_zero_grad = 0 # same there. 
+ + total_scheduler_step = (5 if use_multiple_optimizers else 1) // accumulate_grad_batches + + current_scheduler_step = 0 + + if accumulate_grad_batches == 2: + total_scheduler_step += 1 + + optimizer_idx = (1 if use_multiple_optimizers else 0) + + # yapf: disable + expected = { + "state_dict": {}, + "epoch_loop.state_dict": {}, + "epoch_loop.batch_progress": { + "total": { + "ready": 5, + "started": 5, + "processed": 4, + "completed": 4, + }, + "current": { + "ready": 2, + "started": 2, + "processed": 1, + "completed": 1, + }, + }, + "epoch_loop.scheduler_progress": { + "total": { + "ready": total_scheduler_step, + "started": None, + "processed": None, + "completed": total_scheduler_step, + }, + "current": { + "ready": current_scheduler_step, + "started": None, + "processed": None, + "completed": current_scheduler_step, + }, + }, + "epoch_loop.batch_loop.state_dict": {}, + "epoch_loop.batch_loop.split_progress": { + "total": { + "ready": 0, + "started": 0, + "processed": 0, + "completed": 0, + }, + "current": { + "ready": 0, + "started": 0, + "processed": 0, + "completed": 0, + }, + }, + "epoch_loop.batch_loop.optim_progress": { + "optimizer_idx": optimizer_idx, + "optimizer": { + "step": { + "total": { + "ready": completed_optimizer_steps + 1, + "started": None, + "processed": None, + "completed": completed_optimizer_steps, + }, + "current": { + "ready": current_optimizer_step + 1, + "started": None, + "processed": None, + "completed": current_optimizer_step, + }, + }, + "zero_grad": { + "total": { + "ready": total_optimizer_zero_grad, + "started": total_optimizer_zero_grad, + "processed": None, + "completed": total_optimizer_zero_grad, + }, + "current": { + "ready": current_optimizer_zero_grad, + "started": current_optimizer_zero_grad, + "processed": None, + "completed": current_optimizer_zero_grad, + }, + }, + }, + }, + "epoch_loop.val_loop.state_dict": {}, + "epoch_loop.val_loop.dataloader_progress": { + "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, + }, + "epoch_loop.val_loop.epoch_loop.state_dict": {}, + "epoch_loop.val_loop.epoch_loop.batch_progress": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, + "epoch_progress": { + "total": { + "ready": 2, + "started": 2, + "processed": 1, + "completed": 1, + }, + "current": { + "ready": 2, + "started": 2, + "processed": 1, + "completed": 1, + }, + }, + } + # yapf: enable + + assert checkpoint["loops"]["fit_loop"] == expected + + trainer = Trainer() + trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False) + assert trainer.fit_loop.state_dict() == checkpoint["loops"]["fit_loop"] + + trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"]) + state_dict = trainer.fit_loop.state_dict() + assert state_dict != checkpoint["loops"]["fit_loop"] + assert state_dict["epoch_progress"]["total"]["started"] == 1 diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index 889689a241612..4057a2a686134 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -11,20 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os from copy import deepcopy -from unittest import mock import pytest -import torch -from pytorch_lightning import Trainer from pytorch_lightning.trainer.progress import BaseProgress, OptimizerProgress, Progress, Tracker -from tests.helpers import BoringModel - - -class CustomException(BaseException): - pass def test_progress_getattr_setattr(): @@ -109,208 +100,3 @@ def test_deepcopy(): _ = deepcopy(BaseProgress()) _ = deepcopy(Progress()) _ = deepcopy(Tracker()) - - -@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -@pytest.mark.parametrize("use_multiple_optimizers", [False, True]) -@pytest.mark.parametrize("accumulate_grad_batches", [1, 2]) -def test_progress_tracking(use_multiple_optimizers, accumulate_grad_batches, tmpdir): - - class TestModel(BoringModel): - - def __init__(self): - super().__init__() - if use_multiple_optimizers: - self.configure_optimizers = self.configure_optimizers_3 - - def training_step(self, batch, batch_idx, optimizer_idx: int = None): - # simulate failure during the the 5-th training step, 2nd epoch (global_step = 4) - if self.trainer.current_epoch == 1 and batch_idx == 1 and optimizer_idx == ( - 1 if use_multiple_optimizers else None - ): - raise CustomException - return super().training_step(batch, batch_idx) - - def configure_optimizers_3(self): - optimizer_0 = torch.optim.SGD(self.layer.parameters(), lr=0.1) - optimizer_1 = torch.optim.Adam(self.layer.parameters(), lr=0.1) - optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) - optimizers = [optimizer_0, optimizer_1, optimizer_2] - - lr_scheduler_0 = torch.optim.lr_scheduler.StepLR(optimizer_0, step_size=1) - lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) - # no scheduler for optimizer_2 - lr_schedulers = [lr_scheduler_0, {"scheduler": lr_scheduler_1, "interval": "step"}] - - return optimizers, lr_schedulers - - model = TestModel() - model.training_epoch_end = None - - limit_train_batches = 3 - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=2, - limit_train_batches=limit_train_batches, - limit_val_batches=0, - accumulate_grad_batches=accumulate_grad_batches, - ) - - # simulate random failure in training_step - try: - trainer.fit(model) - except CustomException: - pass - - ####################### - # VALIDATE CHECKPOINT # - ####################### - - checkpoint = torch.load(str(tmpdir / ".pl_auto_save.ckpt")) - - num_optimizers = 3 if use_multiple_optimizers else 1 - - # 4 optimizer steps because breaking on the second batch of the second epoch (3 + 1) - completed_optimizer_steps = (4 * num_optimizers + (1 if use_multiple_optimizers else 0)) // accumulate_grad_batches - - # we raised expection on the first optimizer - current_optimizer_step = (1 if use_multiple_optimizers else 0) - - if accumulate_grad_batches == 2 and use_multiple_optimizers: - completed_optimizer_steps += 1 - - total_optimizer_zero_grad = completed_optimizer_steps - current_optimizer_zero_grad = current_optimizer_step - - if accumulate_grad_batches == 2: - # that's weird ! todo (tchaton) investigate this - total_optimizer_zero_grad = (9 if use_multiple_optimizers else 3) - current_optimizer_zero_grad = 0 # same there. 
- - total_scheduler_step = (5 if use_multiple_optimizers else 1) // accumulate_grad_batches - - current_scheduler_step = 0 - - if accumulate_grad_batches == 2: - total_scheduler_step += 1 - - optimizer_idx = (1 if use_multiple_optimizers else 0) - - # yapf: disable - expected = { - "state_dict": {}, - "epoch_loop.state_dict": {}, - "epoch_loop.batch_progress": { - "total": { - "ready": 5, - "started": 5, - "processed": 4, - "completed": 4, - }, - "current": { - "ready": 2, - "started": 2, - "processed": 1, - "completed": 1, - }, - }, - "epoch_loop.scheduler_progress": { - "total": { - "ready": total_scheduler_step, - "started": None, - "processed": None, - "completed": total_scheduler_step, - }, - "current": { - "ready": current_scheduler_step, - "started": None, - "processed": None, - "completed": current_scheduler_step, - }, - }, - "epoch_loop.batch_loop.state_dict": {}, - "epoch_loop.batch_loop.split_progress": { - "total": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - "current": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - }, - "epoch_loop.batch_loop.optim_progress": { - "optimizer_idx": optimizer_idx, - "optimizer": { - "step": { - "total": { - "ready": completed_optimizer_steps + 1, - "started": None, - "processed": None, - "completed": completed_optimizer_steps, - }, - "current": { - "ready": current_optimizer_step + 1, - "started": None, - "processed": None, - "completed": current_optimizer_step, - }, - }, - "zero_grad": { - "total": { - "ready": total_optimizer_zero_grad, - "started": total_optimizer_zero_grad, - "processed": None, - "completed": total_optimizer_zero_grad, - }, - "current": { - "ready": current_optimizer_zero_grad, - "started": current_optimizer_zero_grad, - "processed": None, - "completed": current_optimizer_zero_grad, - }, - }, - }, - }, - "epoch_loop.val_loop.state_dict": {}, - "epoch_loop.val_loop.dataloader_progress": { - "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, - }, - "epoch_loop.val_loop.epoch_loop.state_dict": {}, - "epoch_loop.val_loop.epoch_loop.batch_progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - }, - "epoch_progress": { - "total": { - "ready": 2, - "started": 2, - "processed": 1, - "completed": 1, - }, - "current": { - "ready": 2, - "started": 2, - "processed": 1, - "completed": 1, - }, - }, - } - # yapf: enable - - assert checkpoint["loops"]["fit_loop"] == expected - - trainer = Trainer() - trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False) - assert trainer.fit_loop.state_dict() == checkpoint["loops"]["fit_loop"] - - trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"]) - state_dict = trainer.fit_loop.state_dict() - assert state_dict != checkpoint["loops"]["fit_loop"] - assert state_dict["epoch_progress"]["total"]["started"] == 1 From 098c7b5becc0b64d9784ea63553376ff47f79874 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 03:47:55 +0200 Subject: [PATCH 127/157] Reset the scheduler progress --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 91b938404f1ef..2e482a01132cb 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py 
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -89,6 +89,7 @@ def reset(self) -> None: self.iteration_count = self.batches_seen = self.batch_progress.current.completed else: self.batch_progress.current.reset() + self.scheduler_progress.current.reset() def on_run_start(self, *args: Any, **kwargs: Any) -> None: # hook From ef7c9e05059146e077bd32602f74804c428473b9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 04:03:38 +0200 Subject: [PATCH 128/157] SchedulerProgress fix --- pytorch_lightning/loops/batch/training_batch_loop.py | 1 - pytorch_lightning/trainer/connectors/optimizer_connector.py | 2 +- pytorch_lightning/trainer/progress.py | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index d28b5a2bd39de..a4d76e3547126 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -409,7 +409,6 @@ def _optimizer_step( ) self.optim_progress.optimizer.step.increment_completed() - self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready() def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None: """Calls the ``on_before_zero_grad`` hook. diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index 9939901832c0e..4c49b6e028cb4 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -83,7 +83,7 @@ def update_learning_rates( # update LR old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] - self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_started() + self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready() if lr_scheduler['reduce_on_plateau']: lr_scheduler['scheduler'].step(monitor_val) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 4410cf3901ff2..1321cdd596fe7 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -136,6 +136,7 @@ class DataLoaderProgress(Progress): current: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) +@dataclass class SchedulerProgress(Progress): """ Tracks the scheduler progress From 7938403a5e4cb6726e1739a35fdb36dcf144fc3f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 04:04:37 +0200 Subject: [PATCH 129/157] Consistent whitespace --- pytorch_lightning/trainer/progress.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 1321cdd596fe7..fe9f90613ea9c 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -132,6 +132,7 @@ class DataLoaderProgress(Progress): total: Tracks the total dataloader progress current: Tracks the current dataloader progress """ + total: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) current: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) From 7799101b0ce2a4cfbf1a828ed1a1aa374caef3c7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 04:15:51 +0200 Subject: [PATCH 130/157] Fix final test --- tests/loops/test_loops.py | 157 +++++++++++++++----------------------- 1 file changed, 61 insertions(+), 96 deletions(-) diff --git a/tests/loops/test_loops.py 
b/tests/loops/test_loops.py index 1792fbc41c497..0b6eb4205a802 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -16,6 +16,7 @@ from dataclasses import dataclass from typing import Any, Dict, Iterator from unittest import mock +from unittest.mock import ANY import pytest import torch @@ -225,7 +226,7 @@ def val_dataloader(self): num_sanity_val_steps=0, ) - # simulate random failure in training_step + # simulate a failure try: trainer.fit(model) except CustomException: @@ -290,7 +291,12 @@ def val_dataloader(self): @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) @pytest.mark.parametrize("use_multiple_optimizers", [False, True]) @pytest.mark.parametrize("accumulate_grad_batches", [1, 2]) -def test_progress_tracking(use_multiple_optimizers, accumulate_grad_batches, tmpdir): +def test_loop_state_on_exception(use_multiple_optimizers, accumulate_grad_batches, tmpdir): + stop_epoch = stop_batch = 1 + stop_optimizer = 1 if use_multiple_optimizers else 0 + n_optimizers = 3 if use_multiple_optimizers else 1 + n_epochs = 2 + n_batches = 3 class TestModel(BoringModel): @@ -299,11 +305,8 @@ def __init__(self): if use_multiple_optimizers: self.configure_optimizers = self.configure_optimizers_3 - def training_step(self, batch, batch_idx, optimizer_idx: int = None): - # simulate failure during the the 5-th training step, 2nd epoch (global_step = 4) - if self.trainer.current_epoch == 1 and batch_idx == 1 and optimizer_idx == ( - 1 if use_multiple_optimizers else None - ): + def training_step(self, batch, batch_idx, optimizer_idx=0): + if self.trainer.current_epoch == stop_epoch and batch_idx == stop_batch and optimizer_idx == stop_optimizer: raise CustomException return super().training_step(batch, batch_idx) @@ -323,118 +326,102 @@ def configure_optimizers_3(self): model = TestModel() model.training_epoch_end = None - limit_train_batches = 3 - trainer = Trainer( default_root_dir=tmpdir, - max_epochs=2, - limit_train_batches=limit_train_batches, + max_epochs=n_epochs, + limit_train_batches=n_batches, limit_val_batches=0, accumulate_grad_batches=accumulate_grad_batches, ) - # simulate random failure in training_step + # simulate a failure try: trainer.fit(model) except CustomException: pass - ####################### - # VALIDATE CHECKPOINT # - ####################### - - checkpoint = torch.load(str(tmpdir / ".pl_auto_save.ckpt")) - - num_optimizers = 3 if use_multiple_optimizers else 1 - - # 4 optimizer steps because breaking on the second batch of the second epoch (3 + 1) - completed_optimizer_steps = (4 * num_optimizers + (1 if use_multiple_optimizers else 0)) // accumulate_grad_batches - - # we raised expection on the first optimizer - current_optimizer_step = (1 if use_multiple_optimizers else 0) - - if accumulate_grad_batches == 2 and use_multiple_optimizers: - completed_optimizer_steps += 1 + ckpt_path = str(tmpdir / ".pl_auto_save.ckpt") + checkpoint = torch.load(ckpt_path) - total_optimizer_zero_grad = completed_optimizer_steps - current_optimizer_zero_grad = current_optimizer_step + batches_seen = (n_epochs - stop_epoch) * n_batches + stop_batch + total_optimizer_steps = batches_seen // accumulate_grad_batches * n_optimizers + stop_optimizer + total_optimizer_zero_grad = total_optimizer_steps + current_optimizer_zero_grad = stop_optimizer if accumulate_grad_batches == 2: - # that's weird ! todo (tchaton) investigate this + # FIXME: that's weird ! total_optimizer_zero_grad = (9 if use_multiple_optimizers else 3) current_optimizer_zero_grad = 0 # same there. 
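# A minimal sketch of the four progress counters that the expected dictionaries in this test
# compare against (`ready`, `started`, `processed`, `completed`). `SketchTracker` is a
# hypothetical stand-in, not the actual class in pytorch_lightning/trainer/progress.py; it only
# illustrates why an interrupted batch leaves `ready`/`started` one ahead of
# `processed`/`completed`, i.e. the recurring `x + 1` vs `x` pattern in the dictionaries below.
from dataclasses import dataclass


@dataclass
class SketchTracker:
    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0


batch = SketchTracker()
batch.ready += 1    # a batch was fetched from the dataloader
batch.started += 1  # training_step began
# the simulated CustomException is raised here, so the last two counters never increment
assert (batch.ready, batch.started, batch.processed, batch.completed) == (1, 1, 0, 0)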
- total_scheduler_step = (5 if use_multiple_optimizers else 1) // accumulate_grad_batches - - current_scheduler_step = 0 - - if accumulate_grad_batches == 2: - total_scheduler_step += 1 - - optimizer_idx = (1 if use_multiple_optimizers else 0) + total_scheduler_steps = n_epochs - stop_epoch + current_scheduler_steps = 0 # the current epoch did not complete + if use_multiple_optimizers: + # 1 for the epoch-interval scheduler and `batches_seen` for the batch-interval scheduler + total_scheduler_steps = 1 + batches_seen // accumulate_grad_batches + current_scheduler_steps = stop_batch // accumulate_grad_batches # yapf: disable expected = { "state_dict": {}, + "epoch_progress": { + "total": { + "ready": stop_epoch + 1, + "started": stop_epoch + 1, + "processed": stop_epoch, + "completed": stop_epoch, + }, + "current": { + "ready": stop_epoch + 1, + "started": stop_epoch + 1, + "processed": stop_epoch, + "completed": stop_epoch, + }, + }, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": { - "ready": 5, - "started": 5, - "processed": 4, - "completed": 4, + "ready": batches_seen + 1, + "started": batches_seen + 1, + "processed": batches_seen, + "completed": batches_seen, }, "current": { - "ready": 2, - "started": 2, - "processed": 1, - "completed": 1, + "ready": stop_batch + 1, + "started": stop_batch + 1, + "processed": stop_batch, + "completed": stop_batch, }, }, "epoch_loop.scheduler_progress": { "total": { - "ready": total_scheduler_step, + "ready": total_scheduler_steps, "started": None, "processed": None, - "completed": total_scheduler_step, + "completed": total_scheduler_steps, }, "current": { - "ready": current_scheduler_step, + "ready": current_scheduler_steps, "started": None, "processed": None, - "completed": current_scheduler_step, + "completed": current_scheduler_steps, }, }, "epoch_loop.batch_loop.state_dict": {}, - "epoch_loop.batch_loop.split_progress": { - "total": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - "current": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - }, "epoch_loop.batch_loop.optim_progress": { - "optimizer_idx": optimizer_idx, + "optimizer_idx": stop_optimizer, "optimizer": { "step": { "total": { - "ready": completed_optimizer_steps + 1, + "ready": total_optimizer_steps + 1, "started": None, "processed": None, - "completed": completed_optimizer_steps, + "completed": total_optimizer_steps, }, "current": { - "ready": current_optimizer_step + 1, + "ready": stop_optimizer + 1, "started": None, "processed": None, - "completed": current_optimizer_step, + "completed": stop_optimizer, }, }, "zero_grad": { @@ -453,36 +440,14 @@ def configure_optimizers_3(self): }, }, }, - "epoch_loop.val_loop.state_dict": {}, - "epoch_loop.val_loop.dataloader_progress": { - "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, - }, - "epoch_loop.val_loop.epoch_loop.state_dict": {}, - "epoch_loop.val_loop.epoch_loop.batch_progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - }, - "epoch_progress": { - "total": { - "ready": 2, - "started": 2, - "processed": 1, - "completed": 1, - }, - "current": { - "ready": 2, - "started": 2, - "processed": 1, - "completed": 1, - }, - }, + "epoch_loop.val_loop.state_dict": ANY, + "epoch_loop.val_loop.dataloader_progress": ANY, + "epoch_loop.val_loop.epoch_loop.state_dict": ANY, 
+ "epoch_loop.val_loop.epoch_loop.batch_progress": ANY, } # yapf: enable - assert checkpoint["loops"]["fit_loop"] == expected - trainer = Trainer() trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False) assert trainer.fit_loop.state_dict() == checkpoint["loops"]["fit_loop"] From a3756076fb7bd71441dacb8f02d63d21b35e6f0c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 05:22:07 +0200 Subject: [PATCH 131/157] Minor test changes --- tests/loops/test_loop_state_dict.py | 4 ++-- tests/loops/test_loops.py | 15 ++++++--------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index cb6ed55d71b31..f014f8c619b54 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -48,8 +48,8 @@ def test_loops_state_dict_structure(): "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, "epoch_loop.scheduler_progress": { - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, "epoch_loop.batch_loop.optim_progress": { diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 0b6eb4205a802..ec9ad3d2b9257 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -212,7 +212,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx): return super().validation_step(batch, batch_idx) def val_dataloader(self): - return [super().val_dataloader()] * n_dataloaders + return [super(ValidationModel, self).val_dataloader() for _ in range(n_dataloaders)] model = ValidationModel() model.validation_epoch_end = None @@ -303,21 +303,18 @@ class TestModel(BoringModel): def __init__(self): super().__init__() if use_multiple_optimizers: - self.configure_optimizers = self.configure_optimizers_3 + self.configure_optimizers = self.configure_optimizers_multiple def training_step(self, batch, batch_idx, optimizer_idx=0): if self.trainer.current_epoch == stop_epoch and batch_idx == stop_batch and optimizer_idx == stop_optimizer: raise CustomException return super().training_step(batch, batch_idx) - def configure_optimizers_3(self): - optimizer_0 = torch.optim.SGD(self.layer.parameters(), lr=0.1) - optimizer_1 = torch.optim.Adam(self.layer.parameters(), lr=0.1) - optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) - optimizers = [optimizer_0, optimizer_1, optimizer_2] + def configure_optimizers_multiple(self): + optimizers = [torch.optim.Adam(self.layer.parameters(), lr=0.1) for _ in range(n_optimizers)] - lr_scheduler_0 = torch.optim.lr_scheduler.StepLR(optimizer_0, step_size=1) - lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) + lr_scheduler_0 = torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=1) + lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizers[1], step_size=1) # no scheduler for optimizer_2 lr_schedulers = [lr_scheduler_0, {"scheduler": lr_scheduler_1, "interval": "step"}] From abb08a063120bcbe8d71b552b1c54519e4fe0c95 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 18:05:18 +0200 Subject: [PATCH 132/157] One test to rule them all --- .../loops/batch/training_batch_loop.py | 2 - .../loops/epoch/training_epoch_loop.py | 1 + tests/loops/test_loops.py | 128 ++++++++++++------ 3 files 
changed, 89 insertions(+), 42 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index a4d76e3547126..b7a5eceae916e 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -110,8 +110,6 @@ def reset(self) -> None: self.batch_idx = 0 self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] - self.optim_progress.reset_on_epoch() - def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int): """Splits the data into tbptt splits diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 2e482a01132cb..d9a2e6bb8cbb3 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -90,6 +90,7 @@ def reset(self) -> None: else: self.batch_progress.current.reset() self.scheduler_progress.current.reset() + self.batch_loop.optim_progress.reset_on_epoch() def on_run_start(self, *args: Any, **kwargs: Any) -> None: # hook diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index ec9ad3d2b9257..23e28338719e4 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -149,7 +149,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: }, 'loop_child.progress': { 'increment': 0 - } + }, } state_dict["loop_child.state_dict"]["a"] = 3 @@ -173,7 +173,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: }, 'loop_child.progress': { 'increment': 1 - } + }, } loop_parent_copy = deepcopy(loop_parent) @@ -195,7 +195,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -def test_loop_restart_progress_multiple_datasets(tmpdir): +def test_loop_restart_progress_multiple_dataloaders(tmpdir): stop_epoch = stop_batch = stop_dataloader = 1 n_dataloaders = 3 n_batches = 3 @@ -265,7 +265,7 @@ def val_dataloader(self): "ready": stop_batch + 1, "started": stop_batch + 1, "processed": stop_batch, - "completed": stop_batch + "completed": stop_batch, }, } assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected @@ -289,20 +289,21 @@ def val_dataloader(self): @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -@pytest.mark.parametrize("use_multiple_optimizers", [False, True]) -@pytest.mark.parametrize("accumulate_grad_batches", [1, 2]) -def test_loop_state_on_exception(use_multiple_optimizers, accumulate_grad_batches, tmpdir): - stop_epoch = stop_batch = 1 - stop_optimizer = 1 if use_multiple_optimizers else 0 - n_optimizers = 3 if use_multiple_optimizers else 1 - n_epochs = 2 +@pytest.mark.parametrize("accumulate_grad_batches", (1, 2)) # FIXME: 3 is broken +@pytest.mark.parametrize("n_optimizers", (1, 3, 5)) +@pytest.mark.parametrize("stop_epoch", (1, 2)) +@pytest.mark.parametrize("stop_batch", (1, )) # FIXME: 2 is broken +@pytest.mark.parametrize("stop_optimizer", (1, 2)) +def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch, stop_optimizer, n_optimizers, tmpdir): + stop_optimizer = stop_optimizer if stop_optimizer < n_optimizers else 0 + n_epochs = 3 n_batches = 3 class TestModel(BoringModel): def __init__(self): super().__init__() - if use_multiple_optimizers: + if n_optimizers > 1: self.configure_optimizers = self.configure_optimizers_multiple def training_step(self, batch, batch_idx, optimizer_idx=0): @@ -329,6 +330,8 @@ def 
configure_optimizers_multiple(self): limit_train_batches=n_batches, limit_val_batches=0, accumulate_grad_batches=accumulate_grad_batches, + progress_bar_refresh_rate=0, + logger=False, ) # simulate a failure @@ -340,22 +343,65 @@ def configure_optimizers_multiple(self): ckpt_path = str(tmpdir / ".pl_auto_save.ckpt") checkpoint = torch.load(ckpt_path) - batches_seen = (n_epochs - stop_epoch) * n_batches + stop_batch - total_optimizer_steps = batches_seen // accumulate_grad_batches * n_optimizers + stop_optimizer + optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optim_progress + scheduler_progress = trainer.fit_loop.epoch_loop.scheduler_progress - total_optimizer_zero_grad = total_optimizer_steps - current_optimizer_zero_grad = stop_optimizer - if accumulate_grad_batches == 2: - # FIXME: that's weird ! - total_optimizer_zero_grad = (9 if use_multiple_optimizers else 3) - current_optimizer_zero_grad = 0 # same there. + non_breaking_epoch_batches_completed = stop_epoch * n_batches + breaking_epoch_batches_completed = stop_batch + breaking_epoch_batches_ready = stop_batch + 1 + # lightning applies leftover accumulated gradients when the epoch ends + has_leftover_accumulation_batches = n_batches % accumulate_grad_batches != 0 - total_scheduler_steps = n_epochs - stop_epoch - current_scheduler_steps = 0 # the current epoch did not complete - if use_multiple_optimizers: - # 1 for the epoch-interval scheduler and `batches_seen` for the batch-interval scheduler - total_scheduler_steps = 1 + batches_seen // accumulate_grad_batches - current_scheduler_steps = stop_batch // accumulate_grad_batches + non_breaking_total_optimizer_steps = ( + non_breaking_epoch_batches_completed // accumulate_grad_batches * n_optimizers + + has_leftover_accumulation_batches * n_optimizers + ) + should_last_batch_step = breaking_epoch_batches_ready % accumulate_grad_batches == 0 + breaking_total_optimizer_steps = ( + breaking_epoch_batches_completed // accumulate_grad_batches * n_optimizers + + should_last_batch_step * stop_optimizer + ) + total_optimizer_steps = non_breaking_total_optimizer_steps + breaking_total_optimizer_steps + current_optimizer_steps = breaking_total_optimizer_steps + has_optimizer_step_in_breaking_epoch = accumulate_grad_batches == 1 or n_batches % accumulate_grad_batches != 0 + assert optim_progress.optimizer_steps == total_optimizer_steps + assert optim_progress.optimizer.step.current.completed == current_optimizer_steps + + non_breaking_total_zero_grad = ( + non_breaking_epoch_batches_completed // accumulate_grad_batches + has_leftover_accumulation_batches + ) * n_optimizers + # FIXME: What the hell + if accumulate_grad_batches > 1: + # FIXME: ready or completed? 0 or stop_optimizer? 
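# Reading of the FIXME above (an open question in this patch, not a resolution): with
# gradient accumulation, the first batch of the breaking epoch zero-grads every optimizer,
# but it is unclear whether the partially finished accumulation window should be counted
# from the "ready" or the "completed" batch tracker, and whether the interrupted batch
# contributes 0 or `stop_optimizer` additional zero_grad calls.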
+ breaking_total_zero_grad = ( + n_optimizers + (breaking_epoch_batches_ready // accumulate_grad_batches - (accumulate_grad_batches > 1)) * + (n_optimizers - 1) + 0 + ) + # breaking_total_zero_grad = breaking_epoch_batches_ready // accumulate_grad_batches * n_optimizers + 0 + else: + breaking_total_zero_grad = ( + breaking_epoch_batches_completed // accumulate_grad_batches * n_optimizers + stop_optimizer + ) + total_zero_grad = non_breaking_total_zero_grad + breaking_total_zero_grad + current_zero_grad = breaking_total_zero_grad + assert optim_progress.optimizer.zero_grad.total.completed == total_zero_grad + assert optim_progress.optimizer.zero_grad.current.completed == current_zero_grad + + non_breaking_scheduler_steps = stop_epoch + breaking_scheduler_steps = 0 # the current epoch did not complete + if n_optimizers > 1: + # assumes that the scheduler config is unchanged + # `* 1` because there is only one step-level scheduler + non_breaking_scheduler_steps = ( + stop_epoch + non_breaking_epoch_batches_completed // accumulate_grad_batches + + has_leftover_accumulation_batches * 1 + ) + # `0 +` for the epoch-level scheduler + breaking_scheduler_steps = 0 + breaking_epoch_batches_completed // accumulate_grad_batches + total_scheduler_steps = non_breaking_scheduler_steps + breaking_scheduler_steps + current_scheduler_steps = breaking_scheduler_steps + assert scheduler_progress.total.completed == total_scheduler_steps + assert scheduler_progress.current.completed == current_scheduler_steps # yapf: disable expected = { @@ -377,10 +423,10 @@ def configure_optimizers_multiple(self): "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": { - "ready": batches_seen + 1, - "started": batches_seen + 1, - "processed": batches_seen, - "completed": batches_seen, + "ready": non_breaking_epoch_batches_completed + breaking_epoch_batches_completed + 1, + "started": non_breaking_epoch_batches_completed + breaking_epoch_batches_completed + 1, + "processed": non_breaking_epoch_batches_completed + breaking_epoch_batches_completed, + "completed": non_breaking_epoch_batches_completed + breaking_epoch_batches_completed, }, "current": { "ready": stop_batch + 1, @@ -409,30 +455,30 @@ def configure_optimizers_multiple(self): "optimizer": { "step": { "total": { - "ready": total_optimizer_steps + 1, + "ready": total_optimizer_steps + has_optimizer_step_in_breaking_epoch, "started": None, "processed": None, "completed": total_optimizer_steps, }, "current": { - "ready": stop_optimizer + 1, + "ready": current_optimizer_steps + has_optimizer_step_in_breaking_epoch, "started": None, "processed": None, - "completed": stop_optimizer, + "completed": current_optimizer_steps, }, }, "zero_grad": { "total": { - "ready": total_optimizer_zero_grad, - "started": total_optimizer_zero_grad, + "ready": total_zero_grad, + "started": total_zero_grad, "processed": None, - "completed": total_optimizer_zero_grad, + "completed": total_zero_grad, }, "current": { - "ready": current_optimizer_zero_grad, - "started": current_optimizer_zero_grad, + "ready": current_zero_grad, + "started": current_zero_grad, "processed": None, - "completed": current_optimizer_zero_grad, + "completed": current_zero_grad, }, }, }, @@ -451,4 +497,6 @@ def configure_optimizers_multiple(self): trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"]) state_dict = trainer.fit_loop.state_dict() assert state_dict != checkpoint["loops"]["fit_loop"] - assert state_dict["epoch_progress"]["total"]["started"] == 1 + # TODO(@carmocca): do not reset for 
total + assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch From fc18c16adf508e9714e5bc513941a6ec82038ccb Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 19:10:28 +0200 Subject: [PATCH 133/157] Formatting --- tests/loops/test_loops.py | 99 ++++++++++----------------------------- 1 file changed, 26 insertions(+), 73 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 23e28338719e4..47a22512f2fd7 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -33,9 +33,7 @@ class CustomException(Exception): def test_loop_restore(): - class Simple(Loop): - def __init__(self, dataset: Iterator): super().__init__() self.dataset = dataset @@ -94,13 +92,11 @@ def load_state_dict(self, state_dict: Dict) -> None: def test_loop_hierarchy(): - @dataclass class SimpleProgress(BaseProgress): increment: int = 0 class Simple(Loop): - def __init__(self, a): super().__init__() self.a = a @@ -138,18 +134,10 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: state_dict = loop_parent.state_dict() assert state_dict == { - 'state_dict': { - 'a': 1 - }, - 'progress': { - 'increment': 0 - }, - 'loop_child.state_dict': { - 'a': 2 - }, - 'loop_child.progress': { - 'increment': 0 - }, + 'state_dict': {'a': 1}, + 'progress': {'increment': 0}, + 'loop_child.state_dict': {'a': 2}, + 'loop_child.progress': {'increment': 0}, } state_dict["loop_child.state_dict"]["a"] = 3 @@ -162,18 +150,10 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: # check the new state after `run` state_dict = loop_parent.state_dict() assert state_dict == { - 'state_dict': { - 'a': 1 - }, - 'progress': { - 'increment': 1 - }, - 'loop_child.state_dict': { - 'a': 3 - }, - 'loop_child.progress': { - 'increment': 1 - }, + 'state_dict': {'a': 1}, + 'progress': {'increment': 1}, + 'loop_child.state_dict': {'a': 3}, + 'loop_child.progress': {'increment': 1}, } loop_parent_copy = deepcopy(loop_parent) @@ -202,7 +182,6 @@ def test_loop_restart_progress_multiple_dataloaders(tmpdir): n_epochs = 2 class ValidationModel(BoringModel): - def __init__(self): super().__init__() @@ -237,12 +216,7 @@ def val_dataloader(self): total = (n_epochs - 1) * n_dataloaders + stop_dataloader expected = { - "total": { - "ready": total + 1, - "started": None, - "processed": None, - "completed": total - }, + "total": {"ready": total + 1, "started": None, "processed": None, "completed": total}, "current": { "ready": stop_dataloader + 1, "started": None, @@ -255,12 +229,7 @@ def val_dataloader(self): trainer.fit_loop.load_state_dict(checkpoint, restart_progress=False) total = n_dataloaders * n_batches + n_batches + stop_epoch expected = { - "total": { - "ready": total + 1, - "started": total + 1, - "processed": total, - "completed": total - }, + "total": {"ready": total + 1, "started": total + 1, "processed": total, "completed": total}, "current": { "ready": stop_batch + 1, "started": stop_batch + 1, @@ -272,18 +241,8 @@ def val_dataloader(self): trainer.fit_loop.load_state_dict(checkpoint) expected = { - "total": { - "ready": total, - "started": total, - "processed": total, - "completed": total - }, - "current": { - "ready": stop_batch, - "started": stop_batch, - "processed": stop_batch, - "completed": stop_batch - }, + "total": {"ready": total, "started": total, "processed": total, "completed": total}, + "current": {"ready": stop_batch, "started": stop_batch, "processed": stop_batch, "completed": 
stop_batch}, } assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected @@ -292,7 +251,7 @@ def val_dataloader(self): @pytest.mark.parametrize("accumulate_grad_batches", (1, 2)) # FIXME: 3 is broken @pytest.mark.parametrize("n_optimizers", (1, 3, 5)) @pytest.mark.parametrize("stop_epoch", (1, 2)) -@pytest.mark.parametrize("stop_batch", (1, )) # FIXME: 2 is broken +@pytest.mark.parametrize("stop_batch", (1,)) # FIXME: 2 is broken @pytest.mark.parametrize("stop_optimizer", (1, 2)) def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch, stop_optimizer, n_optimizers, tmpdir): stop_optimizer = stop_optimizer if stop_optimizer < n_optimizers else 0 @@ -300,7 +259,6 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch n_batches = 3 class TestModel(BoringModel): - def __init__(self): super().__init__() if n_optimizers > 1: @@ -351,37 +309,33 @@ def configure_optimizers_multiple(self): breaking_epoch_batches_ready = stop_batch + 1 # lightning applies leftover accumulated gradients when the epoch ends has_leftover_accumulation_batches = n_batches % accumulate_grad_batches != 0 + non_breaking_stepping_batches = non_breaking_epoch_batches_completed // accumulate_grad_batches + breaking_stepping_batches = breaking_epoch_batches_completed // accumulate_grad_batches non_breaking_total_optimizer_steps = ( - non_breaking_epoch_batches_completed // accumulate_grad_batches * n_optimizers - + has_leftover_accumulation_batches * n_optimizers - ) + non_breaking_stepping_batches + has_leftover_accumulation_batches + ) * n_optimizers should_last_batch_step = breaking_epoch_batches_ready % accumulate_grad_batches == 0 - breaking_total_optimizer_steps = ( - breaking_epoch_batches_completed // accumulate_grad_batches * n_optimizers - + should_last_batch_step * stop_optimizer - ) + breaking_total_optimizer_steps = breaking_stepping_batches * n_optimizers + should_last_batch_step * stop_optimizer total_optimizer_steps = non_breaking_total_optimizer_steps + breaking_total_optimizer_steps current_optimizer_steps = breaking_total_optimizer_steps has_optimizer_step_in_breaking_epoch = accumulate_grad_batches == 1 or n_batches % accumulate_grad_batches != 0 assert optim_progress.optimizer_steps == total_optimizer_steps assert optim_progress.optimizer.step.current.completed == current_optimizer_steps - non_breaking_total_zero_grad = ( - non_breaking_epoch_batches_completed // accumulate_grad_batches + has_leftover_accumulation_batches - ) * n_optimizers + non_breaking_total_zero_grad = (non_breaking_stepping_batches + has_leftover_accumulation_batches) * n_optimizers # FIXME: What the hell if accumulate_grad_batches > 1: # FIXME: ready or completed? 0 or stop_optimizer? 
breaking_total_zero_grad = ( - n_optimizers + (breaking_epoch_batches_ready // accumulate_grad_batches - (accumulate_grad_batches > 1)) * - (n_optimizers - 1) + 0 + n_optimizers + + (breaking_epoch_batches_ready // accumulate_grad_batches - (accumulate_grad_batches > 1)) + * (n_optimizers - 1) + + 0 ) # breaking_total_zero_grad = breaking_epoch_batches_ready // accumulate_grad_batches * n_optimizers + 0 else: - breaking_total_zero_grad = ( - breaking_epoch_batches_completed // accumulate_grad_batches * n_optimizers + stop_optimizer - ) + breaking_total_zero_grad = breaking_stepping_batches * n_optimizers + stop_optimizer total_zero_grad = non_breaking_total_zero_grad + breaking_total_zero_grad current_zero_grad = breaking_total_zero_grad assert optim_progress.optimizer.zero_grad.total.completed == total_zero_grad @@ -393,11 +347,10 @@ def configure_optimizers_multiple(self): # assumes that the scheduler config is unchanged # `* 1` because there is only one step-level scheduler non_breaking_scheduler_steps = ( - stop_epoch + non_breaking_epoch_batches_completed // accumulate_grad_batches - + has_leftover_accumulation_batches * 1 + stop_epoch + non_breaking_stepping_batches + has_leftover_accumulation_batches * 1 ) # `0 +` for the epoch-level scheduler - breaking_scheduler_steps = 0 + breaking_epoch_batches_completed // accumulate_grad_batches + breaking_scheduler_steps = 0 + breaking_stepping_batches total_scheduler_steps = non_breaking_scheduler_steps + breaking_scheduler_steps current_scheduler_steps = breaking_scheduler_steps assert scheduler_progress.total.completed == total_scheduler_steps From e550e6d03c8f55fbd57c7485d8d4906bca1214a0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 19:47:08 +0200 Subject: [PATCH 134/157] Rename and clean variables --- tests/loops/test_loops.py | 173 ++++++++++++++++++++++---------------- 1 file changed, 102 insertions(+), 71 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 47a22512f2fd7..49a9570649027 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -33,7 +33,9 @@ class CustomException(Exception): def test_loop_restore(): + class Simple(Loop): + def __init__(self, dataset: Iterator): super().__init__() self.dataset = dataset @@ -92,11 +94,13 @@ def load_state_dict(self, state_dict: Dict) -> None: def test_loop_hierarchy(): + @dataclass class SimpleProgress(BaseProgress): increment: int = 0 class Simple(Loop): + def __init__(self, a): super().__init__() self.a = a @@ -134,10 +138,18 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: state_dict = loop_parent.state_dict() assert state_dict == { - 'state_dict': {'a': 1}, - 'progress': {'increment': 0}, - 'loop_child.state_dict': {'a': 2}, - 'loop_child.progress': {'increment': 0}, + 'state_dict': { + 'a': 1 + }, + 'progress': { + 'increment': 0 + }, + 'loop_child.state_dict': { + 'a': 2 + }, + 'loop_child.progress': { + 'increment': 0 + }, } state_dict["loop_child.state_dict"]["a"] = 3 @@ -150,10 +162,18 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: # check the new state after `run` state_dict = loop_parent.state_dict() assert state_dict == { - 'state_dict': {'a': 1}, - 'progress': {'increment': 1}, - 'loop_child.state_dict': {'a': 3}, - 'loop_child.progress': {'increment': 1}, + 'state_dict': { + 'a': 1 + }, + 'progress': { + 'increment': 1 + }, + 'loop_child.state_dict': { + 'a': 3 + }, + 'loop_child.progress': { + 'increment': 1 + }, } loop_parent_copy = deepcopy(loop_parent) @@ -182,6 +202,7 @@ def 
test_loop_restart_progress_multiple_dataloaders(tmpdir): n_epochs = 2 class ValidationModel(BoringModel): + def __init__(self): super().__init__() @@ -216,7 +237,12 @@ def val_dataloader(self): total = (n_epochs - 1) * n_dataloaders + stop_dataloader expected = { - "total": {"ready": total + 1, "started": None, "processed": None, "completed": total}, + "total": { + "ready": total + 1, + "started": None, + "processed": None, + "completed": total + }, "current": { "ready": stop_dataloader + 1, "started": None, @@ -229,7 +255,12 @@ def val_dataloader(self): trainer.fit_loop.load_state_dict(checkpoint, restart_progress=False) total = n_dataloaders * n_batches + n_batches + stop_epoch expected = { - "total": {"ready": total + 1, "started": total + 1, "processed": total, "completed": total}, + "total": { + "ready": total + 1, + "started": total + 1, + "processed": total, + "completed": total + }, "current": { "ready": stop_batch + 1, "started": stop_batch + 1, @@ -241,8 +272,18 @@ def val_dataloader(self): trainer.fit_loop.load_state_dict(checkpoint) expected = { - "total": {"ready": total, "started": total, "processed": total, "completed": total}, - "current": {"ready": stop_batch, "started": stop_batch, "processed": stop_batch, "completed": stop_batch}, + "total": { + "ready": total, + "started": total, + "processed": total, + "completed": total + }, + "current": { + "ready": stop_batch, + "started": stop_batch, + "processed": stop_batch, + "completed": stop_batch + }, } assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected @@ -251,7 +292,7 @@ def val_dataloader(self): @pytest.mark.parametrize("accumulate_grad_batches", (1, 2)) # FIXME: 3 is broken @pytest.mark.parametrize("n_optimizers", (1, 3, 5)) @pytest.mark.parametrize("stop_epoch", (1, 2)) -@pytest.mark.parametrize("stop_batch", (1,)) # FIXME: 2 is broken +@pytest.mark.parametrize("stop_batch", (1, )) # FIXME: 2 is broken @pytest.mark.parametrize("stop_optimizer", (1, 2)) def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch, stop_optimizer, n_optimizers, tmpdir): stop_optimizer = stop_optimizer if stop_optimizer < n_optimizers else 0 @@ -259,6 +300,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch n_batches = 3 class TestModel(BoringModel): + def __init__(self): super().__init__() if n_optimizers > 1: @@ -304,57 +346,46 @@ def configure_optimizers_multiple(self): optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optim_progress scheduler_progress = trainer.fit_loop.epoch_loop.scheduler_progress - non_breaking_epoch_batches_completed = stop_epoch * n_batches - breaking_epoch_batches_completed = stop_batch - breaking_epoch_batches_ready = stop_batch + 1 + # `nb_`: non-breaking, as in, no exception will be raised. 
`b_`: breaking + nb_epoch_batches_completed = stop_epoch * n_batches + b_epoch_batches_completed = stop_batch + b_epoch_batches_ready = stop_batch + 1 # lightning applies leftover accumulated gradients when the epoch ends has_leftover_accumulation_batches = n_batches % accumulate_grad_batches != 0 - non_breaking_stepping_batches = non_breaking_epoch_batches_completed // accumulate_grad_batches - breaking_stepping_batches = breaking_epoch_batches_completed // accumulate_grad_batches - - non_breaking_total_optimizer_steps = ( - non_breaking_stepping_batches + has_leftover_accumulation_batches - ) * n_optimizers - should_last_batch_step = breaking_epoch_batches_ready % accumulate_grad_batches == 0 - breaking_total_optimizer_steps = breaking_stepping_batches * n_optimizers + should_last_batch_step * stop_optimizer - total_optimizer_steps = non_breaking_total_optimizer_steps + breaking_total_optimizer_steps - current_optimizer_steps = breaking_total_optimizer_steps - has_optimizer_step_in_breaking_epoch = accumulate_grad_batches == 1 or n_batches % accumulate_grad_batches != 0 - assert optim_progress.optimizer_steps == total_optimizer_steps - assert optim_progress.optimizer.step.current.completed == current_optimizer_steps - - non_breaking_total_zero_grad = (non_breaking_stepping_batches + has_leftover_accumulation_batches) * n_optimizers + nb_stepping_batches = nb_epoch_batches_completed // accumulate_grad_batches + b_stepping_batches = b_epoch_batches_completed // accumulate_grad_batches + + nb_total_optimizer_steps = (nb_stepping_batches + has_leftover_accumulation_batches) * n_optimizers + should_last_batch_step = b_epoch_batches_ready % accumulate_grad_batches == 0 + b_total_optimizer_steps = b_stepping_batches * n_optimizers + should_last_batch_step * stop_optimizer + has_optimizer_step_in_b_epoch = accumulate_grad_batches == 1 or n_batches % accumulate_grad_batches != 0 + assert optim_progress.optimizer_steps == nb_total_optimizer_steps + b_total_optimizer_steps + assert optim_progress.optimizer.step.current.completed == b_total_optimizer_steps + + nb_total_zero_grad = (nb_stepping_batches + has_leftover_accumulation_batches) * n_optimizers # FIXME: What the hell if accumulate_grad_batches > 1: # FIXME: ready or completed? 0 or stop_optimizer? 
- breaking_total_zero_grad = ( - n_optimizers - + (breaking_epoch_batches_ready // accumulate_grad_batches - (accumulate_grad_batches > 1)) - * (n_optimizers - 1) - + 0 + b_total_zero_grad = ( + n_optimizers + (b_epoch_batches_ready // accumulate_grad_batches - (accumulate_grad_batches > 1)) * + (n_optimizers - 1) + 0 ) - # breaking_total_zero_grad = breaking_epoch_batches_ready // accumulate_grad_batches * n_optimizers + 0 + # b_total_zero_grad = b_epoch_batches_ready // accumulate_grad_batches * n_optimizers + 0 else: - breaking_total_zero_grad = breaking_stepping_batches * n_optimizers + stop_optimizer - total_zero_grad = non_breaking_total_zero_grad + breaking_total_zero_grad - current_zero_grad = breaking_total_zero_grad - assert optim_progress.optimizer.zero_grad.total.completed == total_zero_grad - assert optim_progress.optimizer.zero_grad.current.completed == current_zero_grad - - non_breaking_scheduler_steps = stop_epoch - breaking_scheduler_steps = 0 # the current epoch did not complete + b_total_zero_grad = b_stepping_batches * n_optimizers + stop_optimizer + assert optim_progress.optimizer.zero_grad.total.completed == nb_total_zero_grad + b_total_zero_grad + assert optim_progress.optimizer.zero_grad.current.completed == b_total_zero_grad + + nb_scheduler_steps = stop_epoch + b_scheduler_steps = 0 # the current epoch did not complete if n_optimizers > 1: # assumes that the scheduler config is unchanged # `* 1` because there is only one step-level scheduler - non_breaking_scheduler_steps = ( - stop_epoch + non_breaking_stepping_batches + has_leftover_accumulation_batches * 1 - ) + nb_scheduler_steps = stop_epoch + nb_stepping_batches + has_leftover_accumulation_batches * 1 # `0 +` for the epoch-level scheduler - breaking_scheduler_steps = 0 + breaking_stepping_batches - total_scheduler_steps = non_breaking_scheduler_steps + breaking_scheduler_steps - current_scheduler_steps = breaking_scheduler_steps - assert scheduler_progress.total.completed == total_scheduler_steps - assert scheduler_progress.current.completed == current_scheduler_steps + b_scheduler_steps = 0 + b_stepping_batches + assert scheduler_progress.total.completed == nb_scheduler_steps + b_scheduler_steps + assert scheduler_progress.current.completed == b_scheduler_steps # yapf: disable expected = { @@ -376,10 +407,10 @@ def configure_optimizers_multiple(self): "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": { - "ready": non_breaking_epoch_batches_completed + breaking_epoch_batches_completed + 1, - "started": non_breaking_epoch_batches_completed + breaking_epoch_batches_completed + 1, - "processed": non_breaking_epoch_batches_completed + breaking_epoch_batches_completed, - "completed": non_breaking_epoch_batches_completed + breaking_epoch_batches_completed, + "ready": nb_epoch_batches_completed + b_epoch_batches_completed + 1, + "started": nb_epoch_batches_completed + b_epoch_batches_completed + 1, + "processed": nb_epoch_batches_completed + b_epoch_batches_completed, + "completed": nb_epoch_batches_completed + b_epoch_batches_completed, }, "current": { "ready": stop_batch + 1, @@ -390,16 +421,16 @@ def configure_optimizers_multiple(self): }, "epoch_loop.scheduler_progress": { "total": { - "ready": total_scheduler_steps, + "ready": nb_scheduler_steps + b_scheduler_steps, "started": None, "processed": None, - "completed": total_scheduler_steps, + "completed": nb_scheduler_steps + b_scheduler_steps, }, "current": { - "ready": current_scheduler_steps, + "ready": b_scheduler_steps, "started": None, 
"processed": None, - "completed": current_scheduler_steps, + "completed": b_scheduler_steps, }, }, "epoch_loop.batch_loop.state_dict": {}, @@ -408,30 +439,30 @@ def configure_optimizers_multiple(self): "optimizer": { "step": { "total": { - "ready": total_optimizer_steps + has_optimizer_step_in_breaking_epoch, + "ready": nb_total_optimizer_steps + b_total_optimizer_steps + has_optimizer_step_in_b_epoch, "started": None, "processed": None, - "completed": total_optimizer_steps, + "completed": nb_total_optimizer_steps + b_total_optimizer_steps, }, "current": { - "ready": current_optimizer_steps + has_optimizer_step_in_breaking_epoch, + "ready": b_total_optimizer_steps + has_optimizer_step_in_b_epoch, "started": None, "processed": None, - "completed": current_optimizer_steps, + "completed": b_total_optimizer_steps, }, }, "zero_grad": { "total": { - "ready": total_zero_grad, - "started": total_zero_grad, + "ready": nb_total_zero_grad + b_total_zero_grad, + "started": nb_total_zero_grad + b_total_zero_grad, "processed": None, - "completed": total_zero_grad, + "completed": nb_total_zero_grad + b_total_zero_grad, }, "current": { - "ready": current_zero_grad, - "started": current_zero_grad, + "ready": b_total_zero_grad, + "started": b_total_zero_grad, "processed": None, - "completed": current_zero_grad, + "completed": b_total_zero_grad, }, }, }, From 01a8a456969be25bec24a1b22f878ef5dba320a4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 20:03:50 +0200 Subject: [PATCH 135/157] Shorter names --- tests/loops/test_loops.py | 93 ++++++++++++++++++++------------------- 1 file changed, 47 insertions(+), 46 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 49a9570649027..d4e096641d74f 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -332,6 +332,7 @@ def configure_optimizers_multiple(self): accumulate_grad_batches=accumulate_grad_batches, progress_bar_refresh_rate=0, logger=False, + checkpoint_callback=False, ) # simulate a failure @@ -346,46 +347,46 @@ def configure_optimizers_multiple(self): optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optim_progress scheduler_progress = trainer.fit_loop.epoch_loop.scheduler_progress - # `nb_`: non-breaking, as in, no exception will be raised. `b_`: breaking - nb_epoch_batches_completed = stop_epoch * n_batches - b_epoch_batches_completed = stop_batch - b_epoch_batches_ready = stop_batch + 1 + # `nbe_`: non-breaking epoch, as in, no exception will be raised. 
`be_`: breaking epoch + nbe_batches_completed = stop_epoch * n_batches + be_batches_completed = stop_batch + be_batches_ready = stop_batch + 1 # lightning applies leftover accumulated gradients when the epoch ends has_leftover_accumulation_batches = n_batches % accumulate_grad_batches != 0 - nb_stepping_batches = nb_epoch_batches_completed // accumulate_grad_batches - b_stepping_batches = b_epoch_batches_completed // accumulate_grad_batches - - nb_total_optimizer_steps = (nb_stepping_batches + has_leftover_accumulation_batches) * n_optimizers - should_last_batch_step = b_epoch_batches_ready % accumulate_grad_batches == 0 - b_total_optimizer_steps = b_stepping_batches * n_optimizers + should_last_batch_step * stop_optimizer - has_optimizer_step_in_b_epoch = accumulate_grad_batches == 1 or n_batches % accumulate_grad_batches != 0 - assert optim_progress.optimizer_steps == nb_total_optimizer_steps + b_total_optimizer_steps - assert optim_progress.optimizer.step.current.completed == b_total_optimizer_steps - - nb_total_zero_grad = (nb_stepping_batches + has_leftover_accumulation_batches) * n_optimizers - # FIXME: What the hell + # number of batches that will call `optimizer.step()` during non-breaking and breaking epochs + nbe_stepping_batches = nbe_batches_completed // accumulate_grad_batches + be_stepping_batches = be_batches_completed // accumulate_grad_batches + + nbe_total_opt_steps = (nbe_stepping_batches + has_leftover_accumulation_batches) * n_optimizers + is_last_batch_stepping = be_batches_ready % accumulate_grad_batches == 0 + be_total_opt_steps = be_stepping_batches * n_optimizers + is_last_batch_stepping * stop_optimizer + assert optim_progress.optimizer_steps == nbe_total_opt_steps + be_total_opt_steps + assert optim_progress.optimizer.step.current.completed == be_total_opt_steps + has_opt_stepped_in_be = accumulate_grad_batches == 1 or n_batches % accumulate_grad_batches != 0 + + nbe_total_zero_grad = (nbe_stepping_batches + has_leftover_accumulation_batches) * n_optimizers if accumulate_grad_batches > 1: # FIXME: ready or completed? 0 or stop_optimizer? 
- b_total_zero_grad = ( - n_optimizers + (b_epoch_batches_ready // accumulate_grad_batches - (accumulate_grad_batches > 1)) * + be_total_zero_grad = ( + n_optimizers + (be_batches_ready // accumulate_grad_batches - (accumulate_grad_batches > 1)) * (n_optimizers - 1) + 0 ) - # b_total_zero_grad = b_epoch_batches_ready // accumulate_grad_batches * n_optimizers + 0 + # be_total_zero_grad = be_epoch_batches_ready // accumulate_grad_batches * n_optimizers + 0 else: - b_total_zero_grad = b_stepping_batches * n_optimizers + stop_optimizer - assert optim_progress.optimizer.zero_grad.total.completed == nb_total_zero_grad + b_total_zero_grad - assert optim_progress.optimizer.zero_grad.current.completed == b_total_zero_grad + be_total_zero_grad = be_stepping_batches * n_optimizers + stop_optimizer + assert optim_progress.optimizer.zero_grad.total.completed == nbe_total_zero_grad + be_total_zero_grad + assert optim_progress.optimizer.zero_grad.current.completed == be_total_zero_grad - nb_scheduler_steps = stop_epoch - b_scheduler_steps = 0 # the current epoch did not complete + nbe_scheduler_steps = stop_epoch + be_scheduler_steps = 0 # the current epoch did not complete if n_optimizers > 1: # assumes that the scheduler config is unchanged # `* 1` because there is only one step-level scheduler - nb_scheduler_steps = stop_epoch + nb_stepping_batches + has_leftover_accumulation_batches * 1 + nbe_scheduler_steps = stop_epoch + nbe_stepping_batches + has_leftover_accumulation_batches * 1 # `0 +` for the epoch-level scheduler - b_scheduler_steps = 0 + b_stepping_batches - assert scheduler_progress.total.completed == nb_scheduler_steps + b_scheduler_steps - assert scheduler_progress.current.completed == b_scheduler_steps + be_scheduler_steps = 0 + be_stepping_batches + assert scheduler_progress.total.completed == nbe_scheduler_steps + be_scheduler_steps + assert scheduler_progress.current.completed == be_scheduler_steps # yapf: disable expected = { @@ -407,10 +408,10 @@ def configure_optimizers_multiple(self): "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": { - "ready": nb_epoch_batches_completed + b_epoch_batches_completed + 1, - "started": nb_epoch_batches_completed + b_epoch_batches_completed + 1, - "processed": nb_epoch_batches_completed + b_epoch_batches_completed, - "completed": nb_epoch_batches_completed + b_epoch_batches_completed, + "ready": nbe_batches_completed + be_batches_completed + 1, + "started": nbe_batches_completed + be_batches_completed + 1, + "processed": nbe_batches_completed + be_batches_completed, + "completed": nbe_batches_completed + be_batches_completed, }, "current": { "ready": stop_batch + 1, @@ -421,16 +422,16 @@ def configure_optimizers_multiple(self): }, "epoch_loop.scheduler_progress": { "total": { - "ready": nb_scheduler_steps + b_scheduler_steps, + "ready": nbe_scheduler_steps + be_scheduler_steps, "started": None, "processed": None, - "completed": nb_scheduler_steps + b_scheduler_steps, + "completed": nbe_scheduler_steps + be_scheduler_steps, }, "current": { - "ready": b_scheduler_steps, + "ready": be_scheduler_steps, "started": None, "processed": None, - "completed": b_scheduler_steps, + "completed": be_scheduler_steps, }, }, "epoch_loop.batch_loop.state_dict": {}, @@ -439,30 +440,30 @@ def configure_optimizers_multiple(self): "optimizer": { "step": { "total": { - "ready": nb_total_optimizer_steps + b_total_optimizer_steps + has_optimizer_step_in_b_epoch, + "ready": nbe_total_opt_steps + be_total_opt_steps + has_opt_stepped_in_be, "started": None, 
"processed": None, - "completed": nb_total_optimizer_steps + b_total_optimizer_steps, + "completed": nbe_total_opt_steps + be_total_opt_steps, }, "current": { - "ready": b_total_optimizer_steps + has_optimizer_step_in_b_epoch, + "ready": be_total_opt_steps + has_opt_stepped_in_be, "started": None, "processed": None, - "completed": b_total_optimizer_steps, + "completed": be_total_opt_steps, }, }, "zero_grad": { "total": { - "ready": nb_total_zero_grad + b_total_zero_grad, - "started": nb_total_zero_grad + b_total_zero_grad, + "ready": nbe_total_zero_grad + be_total_zero_grad, + "started": nbe_total_zero_grad + be_total_zero_grad, "processed": None, - "completed": nb_total_zero_grad + b_total_zero_grad, + "completed": nbe_total_zero_grad + be_total_zero_grad, }, "current": { - "ready": b_total_zero_grad, - "started": b_total_zero_grad, + "ready": be_total_zero_grad, + "started": be_total_zero_grad, "processed": None, - "completed": b_total_zero_grad, + "completed": be_total_zero_grad, }, }, }, From 1a6c2a1d40380ad7d090866aa22c6b2a3cd62c63 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 20:07:15 +0200 Subject: [PATCH 136/157] Shorter scheduler name --- tests/loops/test_loops.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index d4e096641d74f..db46a974ce340 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -345,7 +345,7 @@ def configure_optimizers_multiple(self): checkpoint = torch.load(ckpt_path) optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optim_progress - scheduler_progress = trainer.fit_loop.epoch_loop.scheduler_progress + sch_progress = trainer.fit_loop.epoch_loop.scheduler_progress # `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch nbe_batches_completed = stop_epoch * n_batches @@ -368,8 +368,9 @@ def configure_optimizers_multiple(self): if accumulate_grad_batches > 1: # FIXME: ready or completed? 0 or stop_optimizer? 
be_total_zero_grad = ( - n_optimizers + (be_batches_ready // accumulate_grad_batches - (accumulate_grad_batches > 1)) * - (n_optimizers - 1) + 0 + n_optimizers + + (be_batches_ready // accumulate_grad_batches - (accumulate_grad_batches > 1)) * (n_optimizers - 1) + + 0 ) # be_total_zero_grad = be_epoch_batches_ready // accumulate_grad_batches * n_optimizers + 0 else: @@ -377,16 +378,16 @@ def configure_optimizers_multiple(self): assert optim_progress.optimizer.zero_grad.total.completed == nbe_total_zero_grad + be_total_zero_grad assert optim_progress.optimizer.zero_grad.current.completed == be_total_zero_grad - nbe_scheduler_steps = stop_epoch - be_scheduler_steps = 0 # the current epoch did not complete + nbe_sch_steps = stop_epoch + be_sch_steps = 0 # the current epoch did not complete if n_optimizers > 1: # assumes that the scheduler config is unchanged # `* 1` because there is only one step-level scheduler - nbe_scheduler_steps = stop_epoch + nbe_stepping_batches + has_leftover_accumulation_batches * 1 + nbe_sch_steps = stop_epoch + nbe_stepping_batches + has_leftover_accumulation_batches * 1 # `0 +` for the epoch-level scheduler - be_scheduler_steps = 0 + be_stepping_batches - assert scheduler_progress.total.completed == nbe_scheduler_steps + be_scheduler_steps - assert scheduler_progress.current.completed == be_scheduler_steps + be_sch_steps = 0 + be_stepping_batches + assert sch_progress.total.completed == nbe_sch_steps + be_sch_steps + assert sch_progress.current.completed == be_sch_steps # yapf: disable expected = { @@ -422,16 +423,16 @@ def configure_optimizers_multiple(self): }, "epoch_loop.scheduler_progress": { "total": { - "ready": nbe_scheduler_steps + be_scheduler_steps, + "ready": nbe_sch_steps + be_sch_steps, "started": None, "processed": None, - "completed": nbe_scheduler_steps + be_scheduler_steps, + "completed": nbe_sch_steps + be_sch_steps, }, "current": { - "ready": be_scheduler_steps, + "ready": be_sch_steps, "started": None, "processed": None, - "completed": be_scheduler_steps, + "completed": be_sch_steps, }, }, "epoch_loop.batch_loop.state_dict": {}, From e1906b75fce7767dcc5fd2574fbfca68eb39a76a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 20:27:26 +0200 Subject: [PATCH 137/157] Fix optimizer step calculation for stop_batch=2 --- tests/loops/test_loops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index db46a974ce340..42e84f502473a 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -358,8 +358,8 @@ def configure_optimizers_multiple(self): be_stepping_batches = be_batches_completed // accumulate_grad_batches nbe_total_opt_steps = (nbe_stepping_batches + has_leftover_accumulation_batches) * n_optimizers - is_last_batch_stepping = be_batches_ready % accumulate_grad_batches == 0 - be_total_opt_steps = be_stepping_batches * n_optimizers + is_last_batch_stepping * stop_optimizer + is_last_be_batch_stepping = be_batches_ready % accumulate_grad_batches == 0 or has_leftover_accumulation_batches + be_total_opt_steps = be_stepping_batches * n_optimizers + is_last_be_batch_stepping * stop_optimizer assert optim_progress.optimizer_steps == nbe_total_opt_steps + be_total_opt_steps assert optim_progress.optimizer.step.current.completed == be_total_opt_steps has_opt_stepped_in_be = accumulate_grad_batches == 1 or n_batches % accumulate_grad_batches != 0 From 5eaf5b3421d5aa6963eea21703d07d87fe574bd6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Jul 2021 18:35:20 +0000 Subject: [PATCH 138/157] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/loops/test_loops.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 42e84f502473a..f3beaf5332e44 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -368,9 +368,8 @@ def configure_optimizers_multiple(self): if accumulate_grad_batches > 1: # FIXME: ready or completed? 0 or stop_optimizer? be_total_zero_grad = ( - n_optimizers - + (be_batches_ready // accumulate_grad_batches - (accumulate_grad_batches > 1)) * (n_optimizers - 1) - + 0 + n_optimizers + (be_batches_ready // accumulate_grad_batches - (accumulate_grad_batches > 1)) * + (n_optimizers - 1) + 0 ) # be_total_zero_grad = be_epoch_batches_ready // accumulate_grad_batches * n_optimizers + 0 else: From 29ce5528d572d7f005b29b936ff05a48c595b72a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 20:37:11 +0200 Subject: [PATCH 139/157] Remove empty connects --- pytorch_lightning/loops/batch/training_batch_loop.py | 5 ----- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 5 ----- 2 files changed, 10 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index b7a5eceae916e..3e5a8081f9eca 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -23,7 +23,6 @@ from torch import Tensor from torch.optim import Optimizer -import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loops.base import Loop from pytorch_lightning.plugins import ParallelPlugin @@ -58,10 +57,6 @@ def __init__(self) -> None: self._remaining_splits: Optional[List[Any]] = None self._skip_backward: bool = False - def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: - """Connects the loop with necessary arguments like the trainer""" - super().connect(trainer, *args, **kwargs) - @property def done(self) -> bool: """Returns if all batch splits have been processed already""" diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 1a0b0f9c8bd9b..bd697d8cc8653 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -18,7 +18,6 @@ from deprecate import void from torch import Tensor -import pytorch_lightning as pl from pytorch_lightning.loops.base import Loop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import Progress @@ -42,10 +41,6 @@ def __init__(self) -> None: self.outputs: List[STEP_OUTPUT] = [] self.batch_progress = Progress() - def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: - """Connects the loop with necessary arguments like the trainer""" - super().connect(trainer, *args, **kwargs) - @property def done(self) -> bool: """Returns ``True`` if the current iteration count reaches the number of dataloader batches.""" From 398457896b75744b751f486aa80e5a9666b18929 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 15 Jul 2021 20:47:28 +0200 Subject: [PATCH 140/157] Update CHANGELOG --- CHANGELOG.md | 5 ++++- 1 file changed, 4 
insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fb1e20572eb01..e63657851449d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,7 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Progress tracking - * Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574), [#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140)) + * Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574), [#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140), [#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) * Add `{,load_}state_dict` to the progress tracking dataclasses ([#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140)) * Connect the progress tracking dataclasses to the loops ([#8244](https://github.com/PyTorchLightning/pytorch-lightning/pull/8244), [#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) @@ -92,6 +92,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fault-tolerant training * Added `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948)) * Added `{,load_}state_dict` to `Loops` ([#8197](https://github.com/PyTorchLightning/pytorch-lightning/pull/8197)) + * Set `Loop.restarting=False` at the end of the `run` in the first iteration ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) + * Save the loops state with the checkpoint (opt-in) ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) + * Save a checkpoint to restore the state on exception (opt-in) ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) - Added `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966)) From 70a9bcac224cdd12bd5193e3feeee22b87f55ca5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 16 Jul 2021 01:49:26 +0200 Subject: [PATCH 141/157] Holy shit finally got the formula right --- tests/loops/test_loops.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index f3beaf5332e44..0f8dd58cb364a 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -289,7 +289,7 @@ def val_dataloader(self): @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -@pytest.mark.parametrize("accumulate_grad_batches", (1, 2)) # FIXME: 3 is broken +@pytest.mark.parametrize("accumulate_grad_batches", (1, 2, 3)) @pytest.mark.parametrize("n_optimizers", (1, 3, 5)) @pytest.mark.parametrize("stop_epoch", (1, 2)) @pytest.mark.parametrize("stop_batch", (1, )) # FIXME: 2 is broken @@ -358,22 +358,16 @@ def configure_optimizers_multiple(self): be_stepping_batches = be_batches_completed // accumulate_grad_batches nbe_total_opt_steps = (nbe_stepping_batches + has_leftover_accumulation_batches) * n_optimizers - is_last_be_batch_stepping = be_batches_ready % accumulate_grad_batches == 0 or has_leftover_accumulation_batches - be_total_opt_steps = be_stepping_batches * n_optimizers + is_last_be_batch_stepping * stop_optimizer + does_last_be_batch_step = be_batches_ready % 
accumulate_grad_batches == 0 or has_leftover_accumulation_batches + be_total_opt_steps = be_stepping_batches * n_optimizers + does_last_be_batch_step * stop_optimizer assert optim_progress.optimizer_steps == nbe_total_opt_steps + be_total_opt_steps assert optim_progress.optimizer.step.current.completed == be_total_opt_steps has_opt_stepped_in_be = accumulate_grad_batches == 1 or n_batches % accumulate_grad_batches != 0 nbe_total_zero_grad = (nbe_stepping_batches + has_leftover_accumulation_batches) * n_optimizers - if accumulate_grad_batches > 1: - # FIXME: ready or completed? 0 or stop_optimizer? - be_total_zero_grad = ( - n_optimizers + (be_batches_ready // accumulate_grad_batches - (accumulate_grad_batches > 1)) * - (n_optimizers - 1) + 0 - ) - # be_total_zero_grad = be_epoch_batches_ready // accumulate_grad_batches * n_optimizers + 0 - else: - be_total_zero_grad = be_stepping_batches * n_optimizers + stop_optimizer + does_last_be_batch_zero_grad = be_batches_completed % accumulate_grad_batches == 0 + # `max` because the first batch always zero-grads + be_total_zero_grad = max(1, be_stepping_batches) * n_optimizers + stop_optimizer * does_last_be_batch_zero_grad assert optim_progress.optimizer.zero_grad.total.completed == nbe_total_zero_grad + be_total_zero_grad assert optim_progress.optimizer.zero_grad.current.completed == be_total_zero_grad From ae94d7a34c96e4e51be95397cc069bc8c32378df Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 16 Jul 2021 02:22:14 +0200 Subject: [PATCH 142/157] Fix final thing!!! --- tests/loops/test_loops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 0f8dd58cb364a..2af9d941b3b34 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -292,7 +292,7 @@ def val_dataloader(self): @pytest.mark.parametrize("accumulate_grad_batches", (1, 2, 3)) @pytest.mark.parametrize("n_optimizers", (1, 3, 5)) @pytest.mark.parametrize("stop_epoch", (1, 2)) -@pytest.mark.parametrize("stop_batch", (1, )) # FIXME: 2 is broken +@pytest.mark.parametrize("stop_batch", (1, 2)) @pytest.mark.parametrize("stop_optimizer", (1, 2)) def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch, stop_optimizer, n_optimizers, tmpdir): stop_optimizer = stop_optimizer if stop_optimizer < n_optimizers else 0 @@ -362,7 +362,7 @@ def configure_optimizers_multiple(self): be_total_opt_steps = be_stepping_batches * n_optimizers + does_last_be_batch_step * stop_optimizer assert optim_progress.optimizer_steps == nbe_total_opt_steps + be_total_opt_steps assert optim_progress.optimizer.step.current.completed == be_total_opt_steps - has_opt_stepped_in_be = accumulate_grad_batches == 1 or n_batches % accumulate_grad_batches != 0 + has_opt_stepped_in_be = stop_batch + 1 >= accumulate_grad_batches nbe_total_zero_grad = (nbe_stepping_batches + has_leftover_accumulation_batches) * n_optimizers does_last_be_batch_zero_grad = be_batches_completed % accumulate_grad_batches == 0 From 83b3dd60826707d594df7738ae62948df3426a44 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 16 Jul 2021 02:23:50 +0200 Subject: [PATCH 143/157] Do not check state dicts --- tests/loops/test_loops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 2af9d941b3b34..28edb52de055c 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -384,7 +384,7 @@ def configure_optimizers_multiple(self): # yapf: disable 
expected = { - "state_dict": {}, + "state_dict": ANY, "epoch_progress": { "total": { "ready": stop_epoch + 1, @@ -399,7 +399,7 @@ def configure_optimizers_multiple(self): "completed": stop_epoch, }, }, - "epoch_loop.state_dict": {}, + "epoch_loop.state_dict": ANY, "epoch_loop.batch_progress": { "total": { "ready": nbe_batches_completed + be_batches_completed + 1, @@ -428,7 +428,7 @@ def configure_optimizers_multiple(self): "completed": be_sch_steps, }, }, - "epoch_loop.batch_loop.state_dict": {}, + "epoch_loop.batch_loop.state_dict": ANY, "epoch_loop.batch_loop.optim_progress": { "optimizer_idx": stop_optimizer, "optimizer": { From 5af97306111be7652826581eb9fd36612373bdb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Jul 2021 04:11:30 +0200 Subject: [PATCH 144/157] parametrize multiple_dataloader progress test --- tests/loops/test_loops.py | 40 +++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 28edb52de055c..5173736812e08 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -195,11 +195,12 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -def test_loop_restart_progress_multiple_dataloaders(tmpdir): - stop_epoch = stop_batch = stop_dataloader = 1 - n_dataloaders = 3 - n_batches = 3 - n_epochs = 2 +@pytest.mark.parametrize("stop_epoch", (1, 2)) +@pytest.mark.parametrize("stop_batch", (1, 2)) +@pytest.mark.parametrize("n_dataloaders,stop_dataloader", [(2, 0), (2, 1), (3, 2)]) +def test_loop_restart_progress_multiple_dataloaders(tmpdir, n_dataloaders, stop_dataloader, stop_epoch, stop_batch): + n_batches = 5 + n_epochs = 3 class ValidationModel(BoringModel): @@ -222,7 +223,6 @@ def val_dataloader(self): max_epochs=n_epochs, limit_train_batches=1, limit_val_batches=n_batches, - callbacks=ModelCheckpoint(dirpath=tmpdir, save_last=True), num_sanity_val_steps=0, ) @@ -235,13 +235,13 @@ def val_dataloader(self): ckpt_path = str(tmpdir / '.pl_auto_save.ckpt') checkpoint = torch.load(ckpt_path)["loops"]["fit_loop"] - total = (n_epochs - 1) * n_dataloaders + stop_dataloader + total_dataloader = stop_epoch * n_dataloaders + stop_dataloader expected = { "total": { - "ready": total + 1, + "ready": total_dataloader + 1, "started": None, "processed": None, - "completed": total + "completed": total_dataloader }, "current": { "ready": stop_dataloader + 1, @@ -253,13 +253,17 @@ def val_dataloader(self): assert checkpoint["epoch_loop.val_loop.dataloader_progress"] == expected trainer.fit_loop.load_state_dict(checkpoint, restart_progress=False) - total = n_dataloaders * n_batches + n_batches + stop_epoch + + # `nbe_`: non-breaking epoch, as in, no exception will be raised. 
`be_`: breaking epoch + nbe_total_val_batch = stop_epoch * n_dataloaders * n_batches + be_total_val_batch = stop_dataloader * n_batches + stop_batch + total_val_batch = nbe_total_val_batch + be_total_val_batch expected = { "total": { - "ready": total + 1, - "started": total + 1, - "processed": total, - "completed": total + "ready": total_val_batch + 1, + "started": total_val_batch + 1, + "processed": total_val_batch, + "completed": total_val_batch }, "current": { "ready": stop_batch + 1, @@ -273,10 +277,10 @@ def val_dataloader(self): trainer.fit_loop.load_state_dict(checkpoint) expected = { "total": { - "ready": total, - "started": total, - "processed": total, - "completed": total + "ready": total_val_batch, + "started": total_val_batch, + "processed": total_val_batch, + "completed": total_val_batch }, "current": { "ready": stop_batch, From d1a8bc0e55503679e3dfc5c62a5d3401bfded7e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Jul 2021 04:16:45 +0200 Subject: [PATCH 145/157] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e63657851449d..484c1362f9d83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -92,7 +92,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fault-tolerant training * Added `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948)) * Added `{,load_}state_dict` to `Loops` ([#8197](https://github.com/PyTorchLightning/pytorch-lightning/pull/8197)) - * Set `Loop.restarting=False` at the end of the `run` in the first iteration ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) + * Set `Loop.restarting=False` at the end of the first iteration ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) * Save the loops state with the checkpoint (opt-in) ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) * Save a checkpoint to restore the state on exception (opt-in) ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) From a7a2781c1b90cd116c90d855d1e40d95daf89a84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Jul 2021 15:22:34 +0200 Subject: [PATCH 146/157] fix test --- tests/loops/test_loop_state_dict.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index f014f8c619b54..99b9dce1ec8ad 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from unittest.mock import Mock + import pytest from pytorch_lightning.loops import FitLoop @@ -21,9 +23,9 @@ def test_loops_state_dict(): fit_loop = FitLoop() with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"): - fit_loop.connect(object()) # noqa + fit_loop.trainer = object() - fit_loop.connect(Trainer()) + fit_loop.connect(Mock()) state_dict = fit_loop.state_dict() new_fit_loop = FitLoop() new_fit_loop.load_state_dict(state_dict) From 5d4cca7b3d6790fc3b362c4b95f9b6fe1d367ef7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Jul 2021 15:31:31 +0200 Subject: [PATCH 147/157] move setters and add docs --- pytorch_lightning/trainer/properties.py | 62 +++++++++++++++++++++++-- pytorch_lightning/trainer/trainer.py | 36 -------------- 2 files changed, 58 insertions(+), 40 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 54d0079b9255e..650ca0f6b3873 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -52,9 +52,13 @@ class TrainerProperties(ABC): _default_root_dir: str + _fit_loop: FitLoop _lightning_optimizers = None + _predict_loop: PredictionLoop _progress_bar_callback: ProgressBarBase + _validate_loop: EvaluationLoop _weights_save_path: str + _test_loop: EvaluationLoop accelerator_connector: AcceleratorConnector callbacks: List[Callback] @@ -64,10 +68,7 @@ class TrainerProperties(ABC): logger: LightningLoggerBase logger_connector: LoggerConnector state: TrainerState - fit_loop: FitLoop - validate_loop: EvaluationLoop - test_loop: EvaluationLoop - predict_loop: PredictionLoop + """ Accelerator properties """ @@ -529,6 +530,59 @@ def min_steps(self) -> Optional[int]: def is_last_batch(self) -> bool: return self.fit_loop.epoch_loop.is_last_batch + @property + def fit_loop(self): + return self._fit_loop + + @fit_loop.setter + def fit_loop(self, loop: FitLoop): + """ + Attach a custom fit loop to this Trainer. It will run with + :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`. + """ + loop.trainer = self + self._fit_loop = loop + + @property + def validate_loop(self): + return self._validate_loop + + @validate_loop.setter + def validate_loop(self, loop: EvaluationLoop): + """ + Attach a custom validation loop to this Trainer. It will run with + :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`. Note that this loop is different from the one + running during training inside the :meth:`pytorch_lightning.trainer.trainer.Trainer.fit` call. + """ + loop.trainer = self + self._validate_loop = loop + + @property + def test_loop(self): + return self._test_loop + + @test_loop.setter + def test_loop(self, loop: EvaluationLoop): + """ + Attach a custom test loop to this Trainer. It will run with + :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`. + """ + loop.trainer = self + self._test_loop = loop + + @property + def predict_loop(self): + return self._predict_loop + + @predict_loop.setter + def predict_loop(self, loop: PredictionLoop): + """ + Attach a custom prediction loop to this Trainer. It will run with + :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. 
+ """ + loop.trainer = self + self._predict_loop = loop + @property def _evaluation_loop(self) -> EvaluationLoop: if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7483f6cef48ff..2a716ebc3a4ac 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -453,42 +453,6 @@ def __init__( # Callback system self.on_init_end() - @property - def fit_loop(self): - return self._fit_loop - - @fit_loop.setter - def fit_loop(self, loop: FitLoop): - loop.trainer = self - self._fit_loop = loop - - @property - def validate_loop(self): - return self._validate_loop - - @validate_loop.setter - def validate_loop(self, loop: EvaluationLoop): - loop.trainer = self - self._validate_loop = loop - - @property - def test_loop(self): - return self._test_loop - - @test_loop.setter - def test_loop(self, loop: EvaluationLoop): - loop.trainer = self - self._test_loop = loop - - @property - def predict_loop(self): - return self._predict_loop - - @predict_loop.setter - def predict_loop(self, loop: PredictionLoop): - loop.trainer = self - self._predict_loop = loop - def _setup_on_init( self, num_sanity_val_steps: int, From 83592bfdaa323a4981a5b354983eb9b2842a77fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Jul 2021 15:32:50 +0200 Subject: [PATCH 148/157] remove the loop examples for now --- pl_examples/loop_examples/__init__.py | 0 pl_examples/loop_examples/example1.py | 57 ----------- pl_examples/loop_examples/example2.py | 41 -------- pl_examples/loop_examples/example3.py | 53 ----------- pl_examples/loop_examples/simple_loop.py | 116 ----------------------- 5 files changed, 267 deletions(-) delete mode 100644 pl_examples/loop_examples/__init__.py delete mode 100644 pl_examples/loop_examples/example1.py delete mode 100644 pl_examples/loop_examples/example2.py delete mode 100644 pl_examples/loop_examples/example3.py delete mode 100644 pl_examples/loop_examples/simple_loop.py diff --git a/pl_examples/loop_examples/__init__.py b/pl_examples/loop_examples/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/pl_examples/loop_examples/example1.py b/pl_examples/loop_examples/example1.py deleted file mode 100644 index 6da59586dad2a..0000000000000 --- a/pl_examples/loop_examples/example1.py +++ /dev/null @@ -1,57 +0,0 @@ -import os - -from torch.utils.data import DataLoader - -from pl_examples.bug_report_model import BoringModel, RandomDataset -from pytorch_lightning import Trainer -from pytorch_lightning.loops import EvaluationLoop, FitLoop, TrainingBatchLoop, TrainingEpochLoop - - -def run(): - """ - This example demonstrates how loops are linked together. 
- Here we form a simple tree structure of three basic loops that make up the FitLoop: - - - Trainer - - fit_loop: FitLoop - - epoch_loop: TrainingEpochLoop - - batch_loop: TrainingBatchLoop - - val_loop: EvaluationLoop - """ - train_data = DataLoader(RandomDataset(32, 64), batch_size=2) - val_data = DataLoader(RandomDataset(32, 64), batch_size=2) - test_data = DataLoader(RandomDataset(32, 64), batch_size=2) - - model = BoringModel() - - trainer = Trainer( - default_root_dir=os.getcwd(), - limit_train_batches=1, - limit_val_batches=1, - num_sanity_val_steps=0, - max_epochs=1, - weights_summary=None, - ) - - # construct loops - fit_loop = FitLoop(max_epochs=2) - train_epoch_loop = TrainingEpochLoop(min_steps=0, max_steps=2) - train_batch_loop = TrainingBatchLoop() - val_loop = EvaluationLoop() - - # connect loops together - train_epoch_loop.connect(batch_loop=train_batch_loop, val_loop=val_loop) - fit_loop.connect(epoch_loop=train_epoch_loop) - - # connect fit loop to trainer (main entry point for the call in trainer.fit()) - trainer.fit_loop = fit_loop - - # this will use the newly constructed loop! - trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) - - # this will still use the default test loop - trainer.test(model, dataloaders=test_data) - - -if __name__ == '__main__': - run() diff --git a/pl_examples/loop_examples/example2.py b/pl_examples/loop_examples/example2.py deleted file mode 100644 index 4a347a1777909..0000000000000 --- a/pl_examples/loop_examples/example2.py +++ /dev/null @@ -1,41 +0,0 @@ -import os - -from torch.utils.data import DataLoader - -from pl_examples.bug_report_model import RandomDataset -from pl_examples.loop_examples.example1 import BoringModel -from pl_examples.loop_examples.simple_loop import SimpleLoop -from pytorch_lightning import Trainer - - -def run(): - """ - This example shows how to replace the FitLoop on the Trainer with a very simple, custom iteration-based - training loop. - """ - train_data = DataLoader(RandomDataset(32, 64), batch_size=2) - - model = BoringModel() - trainer = Trainer( - default_root_dir=os.getcwd(), - limit_train_batches=1, - limit_val_batches=1, - num_sanity_val_steps=0, - max_epochs=1, - weights_summary=None, - progress_bar_refresh_rate=1, - ) - - # instantiate the new loop - simple_loop = SimpleLoop(num_iterations=1000) - - # replace the fit loop - # the trainer reference will be set internally - trainer.fit_loop = simple_loop - - # fit using the new loop! - trainer.fit(model, train_dataloader=train_data) - - -if __name__ == '__main__': - run() diff --git a/pl_examples/loop_examples/example3.py b/pl_examples/loop_examples/example3.py deleted file mode 100644 index a11f7aa84a2d5..0000000000000 --- a/pl_examples/loop_examples/example3.py +++ /dev/null @@ -1,53 +0,0 @@ -import os - -from torch.utils.data import DataLoader - -from pl_examples.bug_report_model import BoringModel, RandomDataset -from pytorch_lightning import Trainer -from pytorch_lightning.loops import EvaluationLoop, TrainingBatchLoop - - -def run(): - """ - This example shows how to switch out an individual loop. - Here, we want to take the default FitLoop from Lightning but switch out - - 1. the batch_loop inside the training epoch loop - 2. 
the val_loop inside the training epoch loop - - """ - train_data = DataLoader(RandomDataset(32, 64), batch_size=2) - val_data = DataLoader(RandomDataset(32, 64), batch_size=2) - - model = BoringModel() - - trainer = Trainer( - default_root_dir=os.getcwd(), - limit_train_batches=1, - limit_val_batches=1, - num_sanity_val_steps=0, - max_epochs=1, - weights_summary=None, - ) - - # instantiate the new batch- and validation loop - new_batch_loop = TrainingBatchLoop() - new_val_loop = EvaluationLoop() - - # call connect on the existing, default fit_loop.epoch_loop - trainer.fit_loop.epoch_loop.connect(batch_loop=new_batch_loop, val_loop=new_val_loop) - - # the new batch loop is registered - assert trainer.fit_loop.epoch_loop.batch_loop is new_batch_loop - - # the trainer is not yet registered, will be done by the trainer internally - assert trainer.fit_loop.epoch_loop.batch_loop.trainer is None - - # this uses the new custom batch loop - trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) - - assert trainer.fit_loop.epoch_loop.batch_loop.trainer is trainer - - -if __name__ == '__main__': - run() diff --git a/pl_examples/loop_examples/simple_loop.py b/pl_examples/loop_examples/simple_loop.py deleted file mode 100644 index 84ea0d6c3821c..0000000000000 --- a/pl_examples/loop_examples/simple_loop.py +++ /dev/null @@ -1,116 +0,0 @@ -from collections import OrderedDict -from typing import Any, Dict, Iterator, Optional, Tuple, Union - -import torch -from torch import Tensor -from torch.optim import Optimizer - -from pytorch_lightning.loops import Loop -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection - - -class SimpleLoop(Loop): - """ - This loop is for demonstration purposes only. - It implements a purely iteration-based loop with a bare miminum of functionality. - - 1 optimizer - - no logging - - no grad accumulation - - no epoch hooks calling - """ - - def __init__(self, num_iterations: int = float("inf")): - super().__init__() - self.num_iterations = num_iterations - self.train_dataloader: Optional[Iterator] = None - - # required for trainer and logger connector - self._results = ResultCollection(training=True) - - @property - def global_step(self) -> int: - return self.iteration_count - - @property - def batch_idx(self) -> int: - # required by progress bar - return self.iteration_count - - @property - def running_loss(self) -> Tensor: - # required by progress bar - return torch.tensor(123.) 
- - @property - def current_epoch(self) -> int: - return 0 - - @property - def skip(self) -> bool: - return self.done or self.trainer.num_training_batches == 0 - - @property - def done(self) -> bool: - return self.iteration_count >= self.num_iterations - - def reset(self) -> None: - self.iteration_count = 0 - - def on_run_start(self) -> None: - self.train_dataloader = iter(self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)) - self.trainer.call_hook("on_train_start") - - def advance(self) -> None: - batch = next(self.train_dataloader) - - opt_idx = 0 - optimizer = self.trainer.optimizers[opt_idx] - - self.trainer.call_hook("on_train_batch_start", batch, self.iteration_count, dataloader_idx=0) - - output = self._run_optimization(batch, self.iteration_count, optimizer) - - # hook - self.trainer.call_hook("on_train_batch_end", output, batch, self.iteration_count, dataloader_idx=0) - self.trainer.call_hook("on_batch_end") - - def on_run_end(self) -> None: - self.trainer.call_hook("on_train_end") - self.trainer.accelerator.on_train_end() - self.trainer._running_stage = None - - def _run_optimization(self, batch: Any, batch_idx: int, optimizer: Optimizer): - lightning_module = self.trainer.lightning_module - - # lightning module training_step - step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)]) - lightning_module._current_fx_name = "training_step" - training_step_output = self.trainer.accelerator.training_step(step_kwargs) - self.trainer.accelerator.post_training_step() - - training_step_output = self.trainer.call_hook("training_step_end", training_step_output) - loss, extra = self._process_training_step_output(training_step_output) - - # backward pass (single optimizer, no accumulation supported) - self.trainer.accelerator.backward(loss, optimizer, optimizer_idx=0, should_accumulate=False) - - # optimizer step (no closures supported) - lightning_module.optimizer_step(optimizer=optimizer) - - output = extra - output["loss"] = loss.detach() - return output - - @staticmethod - def _process_training_step_output(training_step_output: Union[Dict, Tensor]) -> Tuple[Tensor, Dict]: - loss = None - extra = {} - - if isinstance(training_step_output, dict): - loss = training_step_output.pop("loss") - extra = training_step_output - - elif isinstance(training_step_output, Tensor): - loss = training_step_output - - return loss, extra From 34664ecb5c264e06b59ee9dd78196be90d982c53 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 16 Jul 2021 13:34:22 +0000 Subject: [PATCH 149/157] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/properties.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 650ca0f6b3873..aba04596a0627 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -68,7 +68,6 @@ class TrainerProperties(ABC): logger: LightningLoggerBase logger_connector: LoggerConnector state: TrainerState - """ Accelerator properties """ From c3d7a4e02a57290a65535d334e7384bf6e8995f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Jul 2021 15:36:54 +0200 Subject: [PATCH 150/157] alphabetical ordering --- pytorch_lightning/trainer/properties.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/properties.py 
b/pytorch_lightning/trainer/properties.py index aba04596a0627..685ad979ee3fe 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -56,9 +56,9 @@ class TrainerProperties(ABC): _lightning_optimizers = None _predict_loop: PredictionLoop _progress_bar_callback: ProgressBarBase + _test_loop: EvaluationLoop _validate_loop: EvaluationLoop _weights_save_path: str - _test_loop: EvaluationLoop accelerator_connector: AcceleratorConnector callbacks: List[Callback] From db90f79a304bd62ea6a141c71820b05f4d3e2a5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 19 Jul 2021 11:30:51 +0200 Subject: [PATCH 151/157] test connect() method on loops --- tests/loops/test_loops.py | 59 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 695a0c7be16a0..0143bdc5a61e4 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -21,12 +21,67 @@ import pytest import torch -from pytorch_lightning.loops.base import Loop +from pytorch_lightning import Trainer +from pytorch_lightning.loops import Loop from pytorch_lightning.trainer.progress import BaseProgress -from pytorch_lightning.trainer.trainer import Trainer from tests.helpers import BoringModel +class NestedLoop(Loop): + + def __init__(self): + super().__init__() + self.child_loop0 = None + self.child_loop1 = None + + @property + def done(self) -> bool: + return False + + def connect(self, child0, child1): + self.child_loop0 = child0 + self.child_loop1 = child1 + + def reset(self) -> None: + pass + + def advance(self, *args, **kwargs): + pass + + +@pytest.mark.parametrize("loop_name", [ + "fit_loop", + "validate_loop", + "test_loop", + "predict_loop", +]) +def test_connect_loops_direct(loop_name): + """ Test Trainer references in loops on assignment. """ + loop = NestedLoop() + assert loop.trainer is None + + trainer = Trainer() + + # trainer.loop = loop + setattr(trainer, loop_name, loop) + assert loop.trainer is trainer + + +def test_connect_loops_recursive(): + """ Test Trainer references in a nested loop assigned to a Trainer. 
""" + main_loop = NestedLoop() + child0 = NestedLoop() + child1 = NestedLoop() + main_loop.connect(child0, child1) + assert main_loop.trainer is None + assert main_loop.child_loop0.trainer is None + + trainer = Trainer() + trainer.fit_loop = main_loop + assert child0.trainer is trainer + assert child1.trainer is trainer + + class CustomException(Exception): pass From 995d346fce0dc2510bfcd315ad49960ce39d6f07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 19 Jul 2021 11:33:42 +0200 Subject: [PATCH 152/157] update unused imports --- pytorch_lightning/loops/base.py | 1 - pytorch_lightning/loops/dataloader/evaluation_loop.py | 1 - pytorch_lightning/loops/epoch/training_epoch_loop.py | 1 - pytorch_lightning/loops/fit_loop.py | 3 +-- 4 files changed, 1 insertion(+), 5 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index ccef7b5843d31..d3b6ce8a03c02 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -16,7 +16,6 @@ from typing import Any, Dict, Optional from deprecate import void -from onnx.backend.test.case.node.loop import Loop import pytorch_lightning as pl from pytorch_lightning.trainer.progress import BaseProgress, Tracker diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 973d18405b737..8eacd73607665 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -17,7 +17,6 @@ from deprecate.utils import void from torch.utils.data.dataloader import DataLoader -import pytorch_lightning as pl from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index ce440c60a9fcf..21e8487ac22eb 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -16,7 +16,6 @@ import torch -import pytorch_lightning as pl from pytorch_lightning import loops # import as loops to avoid circular imports from pytorch_lightning.loops.batch import TrainingBatchLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 0686c70dfdd90..9df3707631285 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -14,9 +14,8 @@ import logging from contextlib import suppress -from typing import Any, Optional +from typing import Optional -import pytorch_lightning as pl from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection From c7c232b2549eb1a486dfd49fe22efc1692dcc437 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 19 Jul 2021 11:47:45 +0200 Subject: [PATCH 153/157] test connect subloop --- pytorch_lightning/loops/fit_loop.py | 2 +- tests/loops/test_loops.py | 20 +++++++++++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 9df3707631285..b637d4e3e3d4c 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ 
b/pytorch_lightning/loops/fit_loop.py @@ -39,7 +39,7 @@ def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = super().__init__() self.max_epochs = max_epochs self.min_epochs = min_epochs - self.epoch_loop = None + self.epoch_loop: Optional[TrainingEpochLoop] = None self.epoch_progress = Progress() @property diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 0143bdc5a61e4..cc18d38b56146 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -22,7 +22,7 @@ import torch from pytorch_lightning import Trainer -from pytorch_lightning.loops import Loop +from pytorch_lightning.loops import Loop, TrainingEpochLoop, TrainingBatchLoop from pytorch_lightning.trainer.progress import BaseProgress from tests.helpers import BoringModel @@ -82,6 +82,24 @@ def test_connect_loops_recursive(): assert child1.trainer is trainer +def test_connect_subloops(tmpdir): + """ Test connecting individual subloops by calling `trainer.x.y.connect()` """ + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + ) + + epoch_loop = trainer.fit_loop.epoch_loop + new_batch_loop = TrainingBatchLoop() + epoch_loop.connect(batch_loop=new_batch_loop) + assert epoch_loop.batch_loop is new_batch_loop + assert new_batch_loop.trainer is None + + trainer.fit(model) + assert new_batch_loop.trainer is trainer + + class CustomException(Exception): pass From 155f9be7f7ac2951dc1bfbbbad48a2f4d9dbcd10 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Jul 2021 09:49:11 +0000 Subject: [PATCH 154/157] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/loops/test_loops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index cc18d38b56146..93f0eecb2e2f0 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -22,7 +22,7 @@ import torch from pytorch_lightning import Trainer -from pytorch_lightning.loops import Loop, TrainingEpochLoop, TrainingBatchLoop +from pytorch_lightning.loops import Loop, TrainingBatchLoop, TrainingEpochLoop from pytorch_lightning.trainer.progress import BaseProgress from tests.helpers import BoringModel From 5735fe8ae75f4969d9a8ac0d8dca664b26defccc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 19 Jul 2021 11:59:12 +0200 Subject: [PATCH 155/157] udpate changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1bb7a3cad4ade..1e37814374e73 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -175,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Enabled traditional/manual launching of DDP processes through `LOCAL_RANK` and `NODE_RANK` environment variable assignments ([#7480](https://github.com/PyTorchLightning/pytorch-lightning/pull/7480)) +- Added experimental support for loop specialization ([#8226](https://github.com/PyTorchLightning/pytorch-lightning/pull/8226)) + + ### Changed From e2c76915bba2be9991b69973a8acef8d93b241ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 19 Jul 2021 12:01:05 +0200 Subject: [PATCH 156/157] update unused imports --- tests/loops/test_loops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index cc18d38b56146..ef8954f58087c 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -22,7 +22,7 @@ import torch from pytorch_lightning import Trainer -from pytorch_lightning.loops import Loop, TrainingEpochLoop, TrainingBatchLoop +from pytorch_lightning.loops import Loop, TrainingBatchLoop from pytorch_lightning.trainer.progress import BaseProgress from tests.helpers import BoringModel From 8a58e01b54596ee2ff33c2ad3130100c77583846 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 19 Jul 2021 13:54:59 +0200 Subject: [PATCH 157/157] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 21e8487ac22eb..a79b58efe9d31 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -52,8 +52,8 @@ def __init__(self, min_steps: int, max_steps: int): self.batch_progress = Progress() self.scheduler_progress = SchedulerProgress() - self.batch_loop = Optional[TrainingBatchLoop] - self.val_loop = Optional["loops.EvaluationLoop"] + self.batch_loop: Optional[TrainingBatchLoop] = None + self.val_loop: Optional["loops.EvaluationLoop"] = None self._results = ResultCollection(training=True) self._dataloader_idx: Optional[int] = None @@ -79,7 +79,7 @@ def connect( batch_loop: Optional[TrainingBatchLoop] = None, val_loop: Optional["loops.EvaluationLoop"] = None, ) -> None: - """Optionally connect a custom batch- or validation loop to this training epoch loop.""" + """Optionally connect a custom batch or validation loop to this training epoch loop.""" if batch_loop is not None: self.batch_loop = batch_loop if val_loop is not None:
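
Net effect of the series: custom loops attach to the Trainer either through the new loop property setters (PATCH 147) or through each loop's `connect()` method (PATCHES 151-153, 157). The following is a minimal usage sketch, not part of the patches themselves; it is adapted from the removed `pl_examples/loop_examples` files and the `test_connect_subloops` test above, assumes the code state after PATCH 157, and uses the `BoringModel` test helper as a stand-in for any LightningModule.

from pytorch_lightning import Trainer
from pytorch_lightning.loops import EvaluationLoop, TrainingBatchLoop
from tests.helpers import BoringModel  # stand-in model from the test suite

# Swap individual sub-loops of the default FitLoop via `connect()`,
# as exercised by `test_connect_subloops`.
trainer = Trainer(fast_dev_run=True)
new_batch_loop = TrainingBatchLoop()
new_val_loop = EvaluationLoop()
trainer.fit_loop.epoch_loop.connect(batch_loop=new_batch_loop, val_loop=new_val_loop)
assert trainer.fit_loop.epoch_loop.batch_loop is new_batch_loop

# The trainer reference is attached when the loop runs (or immediately when a
# whole loop is assigned through the new property setters, e.g.
# `trainer.fit_loop = my_fit_loop`).
assert new_batch_loop.trainer is None
trainer.fit(BoringModel())
assert new_batch_loop.trainer is trainer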